拦截器-代码审计
/**
* 由于开发人员水平参差不齐,即使订了开发规范很多人也不遵守
* <p>SQL是影响系统性能最重要的因素,所以拦截掉垃圾SQL语句</p>
* <br>
* <p>拦截SQL类型的场景</p>
* <p>1.必须使用到索引,包含left join连接字段,符合索引最左原则</p>
* <p>必须使用索引好处,</p>
* <p>1.1 如果因为动态SQL,bug导致update的where条件没有带上,全表更新上万条数据</p>
* <p>1.2 如果检查到使用了索引,SQL性能基本不会太差</p>
* <br>
* <p>2.SQL尽量单表执行,有查询left join的语句,必须在注释里面允许该SQL运行,否则会被拦截,有left join的语句,如果不能拆成单表执行的SQL,请leader商量在做</p>
* <p>https://gaoxianglong.github.io/shark</p>
* <p>SQL尽量单表执行的好处</p>
* <p>2.1 查询条件简单、易于开理解和维护;</p>
* <p>2.2 扩展性极强;(可为分库分表做准备)</p>
* <p>2.3 缓存利用率高;</p>
* <p>2.在字段上使用函数</p>
* <br>
* <p>3.where条件为空</p>
* <p>4.where条件使用了 !=</p>
* <p>5.where条件使用了 not 关键字</p>
* <p>6.where条件使用了 or 关键字</p>
* <p>7.where条件使用了 使用子查询</p>
*
* @author willenfoo
* @date 2021年7月22日
* @since 3.4.0
*/
public class IllegalSQLInterceptor extends JsqlParserSupport implements InnerInterceptor {
/**
* 缓存验证结果,提高性能
*/
private static final Set<String> CACHE_VALID_RESULT = new HashSet<>();
/**
* 缓存表的索引信息
*/
private static final Map<String, List<IndexInfo>> INDEX_INFO_MAP = new ConcurrentHashMap<>();
@Override
public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
PluginUtils.MPStatementHandler mpStatementHandler = PluginUtils.mpStatementHandler(sh);
MappedStatement ms = mpStatementHandler.mappedStatement();
SqlCommandType sct = ms.getSqlCommandType();
if (sct == SqlCommandType.INSERT || InterceptorIgnoreHelper.willIgnoreIllegalSql(ms.getId())) {
return;
}
BoundSql boundSql = mpStatementHandler.boundSql();
String originalSql = boundSql.getSql();
logger.debug("检查SQL是否合规,SQL:" + originalSql);
String md5Base64 = EncryptUtils.md5Base64(originalSql);
if (CACHE_VALID_RESULT.contains(md5Base64)) {
logger.debug("该SQL已验证,无需再次验证,,SQL:" + originalSql);
return;
}
parserSingle(originalSql, connection);
//缓存验证结果
CACHE_VALID_RESULT.add(md5Base64);
}
@Override
protected void processSelect(Select select, int index, String sql, Object obj) {
PlainSelect plainSelect = (PlainSelect) select.getSelectBody();
// TODO jialei 2021年9月26日19点52分 这个先过滤分页插件对特殊语句包装子查询检测
if (sql.endsWith("TOTAL")) {
return;
}
Expression where = plainSelect.getWhere();
Assert.notNull(where, "非法SQL,必须要有where条件");
Table table = (Table) plainSelect.getFromItem();
List<Join> joins = plainSelect.getJoins();
validWhere(where, table, (Connection) obj);
validJoins(joins, table, (Connection) obj);
}
@Override
protected void processUpdate(Update update, int index, String sql, Object obj) {
Expression where = update.getWhere();
Assert.notNull(where, "非法SQL,必须要有where条件");
Table table = update.getTable();
List<Join> joins = update.getJoins();
validWhere(where, table, (Connection) obj);
validJoins(joins, table, (Connection) obj);
}
@Override
protected void processDelete(Delete delete, int index, String sql, Object obj) {
Expression where = delete.getWhere();
Assert.notNull(where, "非法SQL,必须要有where条件");
Table table = delete.getTable();
List<Join> joins = delete.getJoins();
validWhere(where, table, (Connection) obj);
validJoins(joins, table, (Connection) obj);
}
/**
* 验证expression对象是不是 or、not等等
*
* @param expression ignore
*/
private void validExpression(Expression expression) {
//where条件使用了 or 关键字
if (expression instanceof OrExpression) {
OrExpression orExpression = (OrExpression) expression;
throw new MybatisPlusException("非法SQL,where条件中不能使用【or】关键字,错误or信息:" + orExpression.toString());
} else if (expression instanceof NotEqualsTo) {
NotEqualsTo notEqualsTo = (NotEqualsTo) expression;
throw new MybatisPlusException("非法SQL,where条件中不能使用【!=】关键字,错误!=信息:" + notEqualsTo.toString());
} else if (expression instanceof BinaryExpression) {
BinaryExpression binaryExpression = (BinaryExpression) expression;
if (binaryExpression.getLeftExpression() instanceof Function) {
Function function = (Function) binaryExpression.getLeftExpression();
throw new MybatisPlusException("非法SQL,where条件中不能使用数据库函数,错误函数信息:" + function.toString());
}
if (binaryExpression.getRightExpression() instanceof SubSelect) {
SubSelect subSelect = (SubSelect) binaryExpression.getRightExpression();
// throw new MybatisPlusException("非法SQL,where条件中不能使用子查询,错误子查询SQL信息:" + subSelect.toString());
}
} else if (expression instanceof InExpression) {
InExpression inExpression = (InExpression) expression;
if (inExpression.getRightItemsList() instanceof SubSelect) {
SubSelect subSelect = (SubSelect) inExpression.getRightItemsList();
// throw new MybatisPlusException("非法SQL,where条件中不能使用子查询,错误子查询SQL信息:" + subSelect.toString());
}
}
}
/**
* 如果SQL用了 left Join,验证是否有or、not等等,并且验证是否使用了索引
*
* @param joins ignore
* @param table ignore
* @param connection ignore
*/
private void validJoins(List<Join> joins, Table table, Connection connection) {
//允许执行join,验证jion是否使用索引等等
if (joins != null) {
if (joins.size() > 1) {
throw new MybatisPlusException("非法SQL,超过 2 个表禁止join");
}
for (Join join : joins) {
Table rightTable = (Table) join.getRightItem();
Expression expression = join.getOnExpression();
validWhere(expression, table, rightTable, connection);
}
}
}
/**
* 检查是否使用索引
*
* @param table ignore
* @param columnName ignore
* @param connection ignore
*/
private void validUseIndex(Table table, String columnName, Connection connection) {
//是否使用索引
boolean useIndexFlag = false;
String tableInfo = table.getName();
//表存在的索引
String dbName = null;
String tableName;
String[] tableArray = tableInfo.split("\\.");
if (tableArray.length == 1) {
tableName = tableArray[0];
} else {
dbName = tableArray[0];
tableName = tableArray[1];
}
List<IndexInfo> indexInfos = getIndexInfos(dbName, tableName, connection);
for (IndexInfo indexInfo : indexInfos) {
if (null != columnName && columnName.equalsIgnoreCase(indexInfo.getColumnName())) {
useIndexFlag = true;
break;
}
}
if (!useIndexFlag) {
throw new MybatisPlusException("非法SQL,SQL未使用到索引, table:" + table + ", columnName:" + columnName);
}
}
/**
* 验证where条件的字段,是否有not、or等等,并且where的第一个字段,必须使用索引
*
* @param expression ignore
* @param table ignore
* @param connection ignore
*/
private void validWhere(Expression expression, Table table, Connection connection) {
validWhere(expression, table, null, connection);
}
/**
* 验证where条件的字段,是否有not、or等等,并且where的第一个字段,必须使用索引
*
* @param expression ignore
* @param table ignore
* @param joinTable ignore
* @param connection ignore
*/
private void validWhere(Expression expression, Table table, Table joinTable, Connection connection) {
validExpression(expression);
if (expression instanceof BinaryExpression) {
//获得左边表达式
Expression leftExpression = ((BinaryExpression) expression).getLeftExpression();
validExpression(leftExpression);
//如果左边表达式为Column对象,则直接获得列名
if (leftExpression instanceof Column) {
Expression rightExpression = ((BinaryExpression) expression).getRightExpression();
if (joinTable != null && rightExpression instanceof Column) {
if (Objects.equals(((Column) rightExpression).getTable().getName(), table.getAlias().getName())) {
// validUseIndex(table, ((Column) rightExpression).getColumnName(), connection);
// validUseIndex(joinTable, ((Column) leftExpression).getColumnName(), connection);
} else {
// validUseIndex(joinTable, ((Column) rightExpression).getColumnName(), connection);
// validUseIndex(table, ((Column) leftExpression).getColumnName(), connection);
}
} else {
//获得列名
// validUseIndex(table, ((Column) leftExpression).getColumnName(), connection);
}
}
// 如果BinaryExpression,进行迭代
else if (leftExpression instanceof BinaryExpression) {
validWhere(leftExpression, table, joinTable, connection);
}
//获得右边表达式,并分解
Expression rightExpression = ((BinaryExpression) expression).getRightExpression();
validExpression(rightExpression);
}
}
/**
* 得到表的索引信息
*
* @param dbName ignore
* @param tableName ignore
* @param conn ignore
* @return ignore
*/
public List<IndexInfo> getIndexInfos(String dbName, String tableName, Connection conn) {
return getIndexInfos(null, dbName, tableName, conn);
}
/**
* 得到表的索引信息
*
* @param key ignore
* @param dbName ignore
* @param tableName ignore
* @param conn ignore
* @return ignore
*/
public List<IndexInfo> getIndexInfos(String key, String dbName, String tableName, Connection conn) {
List<IndexInfo> indexInfos = null;
if (StringUtils.isNotBlank(key)) {
indexInfos = INDEX_INFO_MAP.get(key);
}
if (indexInfos == null || indexInfos.isEmpty()) {
ResultSet rs;
try {
DatabaseMetaData metadata = conn.getMetaData();
String catalog = StringUtils.isBlank(dbName) ? conn.getCatalog() : dbName;
String schema = StringUtils.isBlank(dbName) ? conn.getSchema() : dbName;
rs = metadata.getIndexInfo(catalog, schema, tableName, false, true);
indexInfos = new ArrayList<>();
while (rs.next()) {
//索引中的列序列号等于1,才有效
if (Objects.equals(rs.getString(8), "1")) {
IndexInfo indexInfo = new IndexInfo();
indexInfo.setDbName(rs.getString(1));
indexInfo.setTableName(rs.getString(3));
indexInfo.setColumnName(rs.getString(9));
indexInfos.add(indexInfo);
}
}
if (StringUtils.isNotBlank(key)) {
INDEX_INFO_MAP.put(key, indexInfos);
}
} catch (SQLException e) {
e.printStackTrace();
}
}
return indexInfos;
}
/**
* 索引对象
*/
@Data
private static class IndexInfo {
/**
* dbName
*/
private String dbName;
/**
* tableName
*/
private String tableName;
/**
* columnName
*/
private String columnName;
}
}
字段自动填充
public class MyMetaObjectHandler implements MetaObjectHandler {
@Override
public void insertFill(MetaObject metaObject) {
// 判断添加/更新的时候是否给他赋值
Object createDate = getFieldValByName("createDate", metaObject);
LocalDateTime now = LocalDateTime.now();
if (createDate == null) {
this.strictInsertFill(metaObject, "createDate", LocalDateTime.class, now);
}
Object updateDate = getFieldValByName("updateDate", metaObject);
if (updateDate == null) {
this.strictInsertFill(metaObject, "updateDate", LocalDateTime.class, now);
}
Object invalid = getFieldValByName("invalid", metaObject);
if (invalid == null) {
this.strictInsertFill(metaObject, "invalid", Integer.class, 0);
}
this.strictInsertFill(metaObject, "del", Long.class, 0L);
}
@Override
public void updateFill(MetaObject metaObject) {
Object updateDate = getFieldValByName("updateDate", metaObject);
if (updateDate == null) {
this.strictUpdateFill(metaObject, "updateDate", LocalDateTime.class, LocalDateTime.now());
}
}
添加分库键
/**
 * Custom MyBatis-Plus SQL method definitions that carry the sharding (distribute) key.
 *
 * <p>Each constant bundles the mapper method name it is registered under, a human-readable
 * description, and a {@link String#format} template completed by the matching
 * {@code AbstractMethod} implementation.</p>
 */
public enum CustomSqlMethod {

    /** Select one row by primary key and sharding key; the template takes six {@code %s}. */
    SELECT_BY_ID_WITH_DISTRIBUTE_KEY(
            "selectByIdWithDistributeKey",
            "根据ID和分库键查询一条数据",
            "SELECT %s FROM %s WHERE %s=#{%s} %s %s"),

    /** Selective update by primary key; the template takes five {@code %s}. */
    UPDATE_BY_ID_WITH_DISTRIBUTE_KEY(
            "updateByIdWithDistributeKey",
            "根据ID 选择修改数据",
            "<script>\nUPDATE %s %s WHERE %s=#{%s} %s\n</script>");

    /** SQL template with {@code %s} placeholders. */
    private final String sql;
    /** Mapper method name the statement is registered under. */
    private final String method;
    /** Human-readable description. */
    private final String desc;

    CustomSqlMethod(String methodName, String description, String sqlTemplate) {
        sql = sqlTemplate;
        desc = description;
        method = methodName;
    }

    /** @return mapper method name */
    public String getMethod() {
        return method;
    }

    /** @return human-readable description */
    public String getDesc() {
        return desc;
    }

    /** @return {@link String#format} SQL template */
    public String getSql() {
        return sql;
    }
}
/**
 * Injects the {@code updateByIdWithDistributeKey} mapper method: a selective UPDATE by primary
 * key that also filters on the sharding column {@code distribution_code} when the table has it.
 *
 * <p>Generated WHERE clause: {@code <keyColumn>=#{id}}, then the optimistic-lock and
 * logic-delete conditions (when configured), then {@code AND distribution_code=#{distributeKey}}
 * (when the column exists; otherwise {@code distributeKey} is ignored). The bound names
 * {@code et}, {@code id} and {@code distributeKey} match the {@code @Param} names of the mapper
 * method.</p>
 *
 * @author jialei25
 * @date 2021-12-09 09:38
 */
public class UpdateByIdWithDistributeKey extends AbstractMethod {

    /** Sharding (distribute) key column. */
    private static final String DISTRIBUTION_CODE_COLUMN = "distribution_code";

    public UpdateByIdWithDistributeKey() {
    }

    @Override
    public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {
        CustomSqlMethod sqlMethod = CustomSqlMethod.UPDATE_BY_ID_WITH_DISTRIBUTE_KEY;
        // optimistic-lock + logic-delete conditions. Assumes (as in MyBatis-Plus's own UpdateById)
        // that each non-empty fragment carries its own leading AND.
        String additional = this.optlockVersion(tableInfo) + tableInfo.getLogicDeleteSql(true, true);
        boolean hasDistributionCode = tableInfo.getFieldList().stream()
                .anyMatch(field -> DISTRIBUTION_CODE_COLUMN.equals(field.getColumn()));
        // The sharding condition also carries its own " AND ", so the suffix is valid whether or
        // not `additional` is empty. Previously this produced "id=#{id} distribution_code=..."
        // (missing AND) or a dangling " AND ".
        String whereSuffix = additional
                + (hasDistributionCode ? " AND distribution_code=#{distributeKey}" : "");
        String sql = String.format(sqlMethod.getSql(),
                tableInfo.getTableName(),
                this.sqlSet(tableInfo.isWithLogicDelete(), false, tableInfo, false, "et", "et."),
                tableInfo.getKeyColumn(),
                // bind to the mapper's @Param("id"); the entity's key property name (e.g. "orderId")
                // is not a parameter of this statement
                "id",
                whereSuffix);
        SqlSource sqlSource = this.languageDriver.createSqlSource(this.configuration, sql, modelClass);
        return this.addUpdateMappedStatement(mapperClass, modelClass, sqlMethod.getMethod(), sqlSource);
    }
}
自定义 SQL 注入器
/**
 * SQL injector registering the standard MyBatis-Plus CRUD methods plus this project's
 * sharding-key aware and fill-on-delete methods.
 *
 * @author nieqiurong 2018/8/11 20:23.
 */
public class MySqlInjector extends AbstractSqlInjector {

    /**
     * Builds the list of methods injected into every mapper, in registration order.
     *
     * @param mapperClass mapper interface being processed
     * @param tableInfo   table metadata of the mapper's entity
     * @return a new mutable list of method injectors
     */
    @Override
    public List<AbstractMethod> getMethodList(Class<?> mapperClass, TableInfo tableInfo) {
        Stream<AbstractMethod> methods = Stream.of(
                // insert
                new Insert(),
                // delete
                new Delete(),
                new DeleteByMap(),
                new DeleteById(),
                new DeleteBatchByIds(),
                // update
                new Update(),
                new UpdateById(),
                new UpdateByIdWithDistributeKey(),
                // select
                new SelectById(),
                new SelectByIdWithDistributeKey(),
                new SelectBatchByIds(),
                new SelectByMap(),
                new SelectOne(),
                new SelectCount(),
                new SelectMaps(),
                new SelectMapsPage(),
                new SelectObjs(),
                new SelectList(),
                new SelectPage(),
                // delete by id with update auto-fill
                new LogicDeleteByIdWithFill());
        return methods.collect(Collectors.toList());
    }
}
根据分库键删除
package com.jd.irt.store.core.mybatis;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import org.apache.ibatis.annotations.Param;
import java.io.Serializable;
/**
 * Project base mapper adding sharding-key aware methods on top of MyBatis-Plus
 * {@link BaseMapper}.
 *
 * <p>The statements for the extra methods are registered by the custom SQL injector
 * ({@code MySqlInjector}).</p>
 *
 * @param <T> entity type
 * @author jialei25
 * @date 2021-09-27 21:03
 */
public interface MyBaseMapper<T> extends BaseMapper<T> {

    /**
     * Deletes the row identified by the entity's primary key, applying update auto-fill.
     *
     * <p>Built by {@code LogicDeleteByIdWithFill}. For tables with a logic-delete field this is an
     * UPDATE that sets the update-fill columns and the logic-delete flag. Otherwise it is a
     * physical DELETE. When the table has a {@code distribution_code} column, the entity's
     * sharding value is added to the WHERE clause.</p>
     *
     * @param entity entity carrying the primary key (and sharding value, if applicable)
     * @return number of affected rows
     */
    int deleteByIdWithFill(T entity);

    /**
     * Selects one row by primary key and sharding key.
     *
     * <p>NOTE(review): the statement is built by {@code SelectByIdWithDistributeKey}, which is not
     * shown here. It presumably matches {@code distributeKey} against {@code distribution_code};
     * confirm.</p>
     *
     * @param id            primary key value
     * @param distributeKey sharding key value
     * @return the matching entity, or {@code null} when none matches
     */
    T selectByIdWithDistributeKey(@Param("id") Serializable id, @Param("distributeKey") Serializable distributeKey);

    /**
     * Selectively updates a row by primary key, additionally filtering on the sharding column.
     *
     * <p>Built by {@code UpdateByIdWithDistributeKey}. The SET clause comes from {@code entity}
     * (MyBatis-Plus field update strategies apply). The {@code distribution_code} condition is
     * only added when the table has that column; otherwise {@code distributeKey} is ignored.</p>
     *
     * @param entity        entity holding the new values (bound as {@code et})
     * @param id            primary key of the row to update
     * @param distributeKey sharding key value
     * @return number of affected rows
     */
    int updateByIdWithDistributeKey(@Param("et") T entity, @Param("id") Serializable id, @Param("distributeKey") Serializable distributeKey);
}
/*
* Copyright (c) 2011-2021, baomidou (jobob@qq.com).
*/
/**
 * Injects {@code deleteByIdWithFill}: delete by primary key with update auto-fill, adapted to
 * this project's sharding ({@code distribution_code}) and unique-code ({@code sole_code}) columns.
 *
 * <p>Note that the parameter is the entity itself. If no field is configured for update
 * auto-fill, this is a plain logical delete. Declare it in your own base mapper as:</p>
 * <pre>
 * int deleteByIdWithFill(T entity);
 * </pre>
 *
 * <ul>
 *   <li>Table with a logic-delete field: {@code UPDATE ... SET [sole_code = uuid(), ] <update-fill
 *       columns> <logic-delete flag> WHERE <key>=#{<keyProperty>} [and distribution_code = #{..}]
 *       <not-deleted condition>}</li>
 *   <li>Otherwise: {@code DELETE FROM ... WHERE <key>=#{<keyProperty>} [and distribution_code = #{..}]}</li>
 * </ul>
 *
 * <p>NOTE(review): {@code sole_code} is reset to {@code uuid()} (MySQL function), presumably so a
 * unique index on it does not block re-inserting the same business record; confirm.</p>
 *
 * @author miemie
 * @since 2018-11-09
 */
@SuppressWarnings("serial")
public class LogicDeleteByIdWithFill extends AbstractMethod {

    @Override
    public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {
        String distributionCodeWhereSql = "";
        String soleCodeSetSql = "";
        List<TableFieldInfo> updateFieldInfoList = new ArrayList<>();
        for (TableFieldInfo tableFieldInfo : tableInfo.getFieldList()) {
            if (tableFieldInfo.isWithUpdateFill()) {
                updateFieldInfoList.add(tableFieldInfo);
            }
            if ("distribution_code".equals(tableFieldInfo.getColumn())) {
                // bind to the entity property actually mapped to the column instead of assuming
                // it is named "distributionCode"
                distributionCodeWhereSql = " and distribution_code = #{" + tableFieldInfo.getProperty() + "}";
            }
            if ("sole_code".equals(tableFieldInfo.getColumn())) {
                soleCodeSetSql = "sole_code = uuid(), ";
            }
        }
        SqlMethod sqlMethod;
        String sql;
        if (tableInfo.isWithLogicDelete()) {
            sqlMethod = SqlMethod.LOGIC_DELETE_BY_ID;
            // joining an empty list yields "", so tables without update-fill columns need no
            // separate branch
            String fillSetSql = updateFieldInfoList.stream()
                    .map(i -> i.getSqlSet(EMPTY))
                    .collect(joining(EMPTY));
            String sqlSet = "SET " + soleCodeSetSql + fillSetSql + tableInfo.getLogicDeleteSql(false, false);
            sql = String.format(sqlMethod.getSql(), tableInfo.getTableName(), sqlSet,
                    tableInfo.getKeyColumn(), tableInfo.getKeyProperty(),
                    distributionCodeWhereSql + tableInfo.getLogicDeleteSql(true, true));
        } else {
            sqlMethod = SqlMethod.DELETE_BY_ID;
            sql = String.format("<script>\nDELETE FROM %s WHERE %s=#{%s} %s\n</script>", tableInfo.getTableName(),
                    tableInfo.getKeyColumn(), tableInfo.getKeyProperty(), distributionCodeWhereSql);
        }
        SqlSource sqlSource = languageDriver.createSqlSource(configuration, sql, modelClass);
        return addUpdateMappedStatement(mapperClass, modelClass, getMethod(sqlMethod), sqlSource);
    }

    /**
     * Always registers the statement as {@code deleteByIdWithFill}, regardless of which
     * built-in {@link SqlMethod} template was used.
     */
    @Override
    public String getMethod(SqlMethod sqlMethod) {
        return "deleteByIdWithFill";
    }
}