拦截器-代码审计

  1. /**
  2. * 由于开发人员水平参差不齐,即使订了开发规范很多人也不遵守
  3. * <p>SQL是影响系统性能最重要的因素,所以拦截掉垃圾SQL语句</p>
  4. * <br>
  5. * <p>拦截SQL类型的场景</p>
  6. * <p>1.必须使用到索引,包含left join连接字段,符合索引最左原则</p>
  7. * <p>必须使用索引好处,</p>
  8. * <p>1.1 如果因为动态SQL,bug导致update的where条件没有带上,全表更新上万条数据</p>
  9. * <p>1.2 如果检查到使用了索引,SQL性能基本不会太差</p>
  10. * <br>
  11. * <p>2.SQL尽量单表执行,有查询left join的语句,必须在注释里面允许该SQL运行,否则会被拦截,有left join的语句,如果不能拆成单表执行的SQL,请leader商量在做</p>
  12. * <p>https://gaoxianglong.github.io/shark</p>
  13. * <p>SQL尽量单表执行的好处</p>
  14. * <p>2.1 查询条件简单、易于开理解和维护;</p>
  15. * <p>2.2 扩展性极强;(可为分库分表做准备)</p>
  16. * <p>2.3 缓存利用率高;</p>
  17. * <p>2.在字段上使用函数</p>
  18. * <br>
  19. * <p>3.where条件为空</p>
  20. * <p>4.where条件使用了 !=</p>
  21. * <p>5.where条件使用了 not 关键字</p>
  22. * <p>6.where条件使用了 or 关键字</p>
  23. * <p>7.where条件使用了 使用子查询</p>
  24. *
  25. * @author willenfoo
  26. * @date 2021年7月22日
  27. * @since 3.4.0
  28. */
  29. public class IllegalSQLInterceptor extends JsqlParserSupport implements InnerInterceptor {
  30. /**
  31. * 缓存验证结果,提高性能
  32. */
  33. private static final Set<String> CACHE_VALID_RESULT = new HashSet<>();
  34. /**
  35. * 缓存表的索引信息
  36. */
  37. private static final Map<String, List<IndexInfo>> INDEX_INFO_MAP = new ConcurrentHashMap<>();
  38. @Override
  39. public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
  40. PluginUtils.MPStatementHandler mpStatementHandler = PluginUtils.mpStatementHandler(sh);
  41. MappedStatement ms = mpStatementHandler.mappedStatement();
  42. SqlCommandType sct = ms.getSqlCommandType();
  43. if (sct == SqlCommandType.INSERT || InterceptorIgnoreHelper.willIgnoreIllegalSql(ms.getId())) {
  44. return;
  45. }
  46. BoundSql boundSql = mpStatementHandler.boundSql();
  47. String originalSql = boundSql.getSql();
  48. logger.debug("检查SQL是否合规,SQL:" + originalSql);
  49. String md5Base64 = EncryptUtils.md5Base64(originalSql);
  50. if (CACHE_VALID_RESULT.contains(md5Base64)) {
  51. logger.debug("该SQL已验证,无需再次验证,,SQL:" + originalSql);
  52. return;
  53. }
  54. parserSingle(originalSql, connection);
  55. //缓存验证结果
  56. CACHE_VALID_RESULT.add(md5Base64);
  57. }
  58. @Override
  59. protected void processSelect(Select select, int index, String sql, Object obj) {
  60. PlainSelect plainSelect = (PlainSelect) select.getSelectBody();
  61. // TODO jialei 2021年9月26日19点52分 这个先过滤分页插件对特殊语句包装子查询检测
  62. if (sql.endsWith("TOTAL")) {
  63. return;
  64. }
  65. Expression where = plainSelect.getWhere();
  66. Assert.notNull(where, "非法SQL,必须要有where条件");
  67. Table table = (Table) plainSelect.getFromItem();
  68. List<Join> joins = plainSelect.getJoins();
  69. validWhere(where, table, (Connection) obj);
  70. validJoins(joins, table, (Connection) obj);
  71. }
  72. @Override
  73. protected void processUpdate(Update update, int index, String sql, Object obj) {
  74. Expression where = update.getWhere();
  75. Assert.notNull(where, "非法SQL,必须要有where条件");
  76. Table table = update.getTable();
  77. List<Join> joins = update.getJoins();
  78. validWhere(where, table, (Connection) obj);
  79. validJoins(joins, table, (Connection) obj);
  80. }
  81. @Override
  82. protected void processDelete(Delete delete, int index, String sql, Object obj) {
  83. Expression where = delete.getWhere();
  84. Assert.notNull(where, "非法SQL,必须要有where条件");
  85. Table table = delete.getTable();
  86. List<Join> joins = delete.getJoins();
  87. validWhere(where, table, (Connection) obj);
  88. validJoins(joins, table, (Connection) obj);
  89. }
  90. /**
  91. * 验证expression对象是不是 or、not等等
  92. *
  93. * @param expression ignore
  94. */
  95. private void validExpression(Expression expression) {
  96. //where条件使用了 or 关键字
  97. if (expression instanceof OrExpression) {
  98. OrExpression orExpression = (OrExpression) expression;
  99. throw new MybatisPlusException("非法SQL,where条件中不能使用【or】关键字,错误or信息:" + orExpression.toString());
  100. } else if (expression instanceof NotEqualsTo) {
  101. NotEqualsTo notEqualsTo = (NotEqualsTo) expression;
  102. throw new MybatisPlusException("非法SQL,where条件中不能使用【!=】关键字,错误!=信息:" + notEqualsTo.toString());
  103. } else if (expression instanceof BinaryExpression) {
  104. BinaryExpression binaryExpression = (BinaryExpression) expression;
  105. if (binaryExpression.getLeftExpression() instanceof Function) {
  106. Function function = (Function) binaryExpression.getLeftExpression();
  107. throw new MybatisPlusException("非法SQL,where条件中不能使用数据库函数,错误函数信息:" + function.toString());
  108. }
  109. if (binaryExpression.getRightExpression() instanceof SubSelect) {
  110. SubSelect subSelect = (SubSelect) binaryExpression.getRightExpression();
  111. // throw new MybatisPlusException("非法SQL,where条件中不能使用子查询,错误子查询SQL信息:" + subSelect.toString());
  112. }
  113. } else if (expression instanceof InExpression) {
  114. InExpression inExpression = (InExpression) expression;
  115. if (inExpression.getRightItemsList() instanceof SubSelect) {
  116. SubSelect subSelect = (SubSelect) inExpression.getRightItemsList();
  117. // throw new MybatisPlusException("非法SQL,where条件中不能使用子查询,错误子查询SQL信息:" + subSelect.toString());
  118. }
  119. }
  120. }
  121. /**
  122. * 如果SQL用了 left Join,验证是否有or、not等等,并且验证是否使用了索引
  123. *
  124. * @param joins ignore
  125. * @param table ignore
  126. * @param connection ignore
  127. */
  128. private void validJoins(List<Join> joins, Table table, Connection connection) {
  129. //允许执行join,验证jion是否使用索引等等
  130. if (joins != null) {
  131. if (joins.size() > 1) {
  132. throw new MybatisPlusException("非法SQL,超过 2 个表禁止join");
  133. }
  134. for (Join join : joins) {
  135. Table rightTable = (Table) join.getRightItem();
  136. Expression expression = join.getOnExpression();
  137. validWhere(expression, table, rightTable, connection);
  138. }
  139. }
  140. }
  141. /**
  142. * 检查是否使用索引
  143. *
  144. * @param table ignore
  145. * @param columnName ignore
  146. * @param connection ignore
  147. */
  148. private void validUseIndex(Table table, String columnName, Connection connection) {
  149. //是否使用索引
  150. boolean useIndexFlag = false;
  151. String tableInfo = table.getName();
  152. //表存在的索引
  153. String dbName = null;
  154. String tableName;
  155. String[] tableArray = tableInfo.split("\\.");
  156. if (tableArray.length == 1) {
  157. tableName = tableArray[0];
  158. } else {
  159. dbName = tableArray[0];
  160. tableName = tableArray[1];
  161. }
  162. List<IndexInfo> indexInfos = getIndexInfos(dbName, tableName, connection);
  163. for (IndexInfo indexInfo : indexInfos) {
  164. if (null != columnName && columnName.equalsIgnoreCase(indexInfo.getColumnName())) {
  165. useIndexFlag = true;
  166. break;
  167. }
  168. }
  169. if (!useIndexFlag) {
  170. throw new MybatisPlusException("非法SQL,SQL未使用到索引, table:" + table + ", columnName:" + columnName);
  171. }
  172. }
  173. /**
  174. * 验证where条件的字段,是否有not、or等等,并且where的第一个字段,必须使用索引
  175. *
  176. * @param expression ignore
  177. * @param table ignore
  178. * @param connection ignore
  179. */
  180. private void validWhere(Expression expression, Table table, Connection connection) {
  181. validWhere(expression, table, null, connection);
  182. }
  183. /**
  184. * 验证where条件的字段,是否有not、or等等,并且where的第一个字段,必须使用索引
  185. *
  186. * @param expression ignore
  187. * @param table ignore
  188. * @param joinTable ignore
  189. * @param connection ignore
  190. */
  191. private void validWhere(Expression expression, Table table, Table joinTable, Connection connection) {
  192. validExpression(expression);
  193. if (expression instanceof BinaryExpression) {
  194. //获得左边表达式
  195. Expression leftExpression = ((BinaryExpression) expression).getLeftExpression();
  196. validExpression(leftExpression);
  197. //如果左边表达式为Column对象,则直接获得列名
  198. if (leftExpression instanceof Column) {
  199. Expression rightExpression = ((BinaryExpression) expression).getRightExpression();
  200. if (joinTable != null && rightExpression instanceof Column) {
  201. if (Objects.equals(((Column) rightExpression).getTable().getName(), table.getAlias().getName())) {
  202. // validUseIndex(table, ((Column) rightExpression).getColumnName(), connection);
  203. // validUseIndex(joinTable, ((Column) leftExpression).getColumnName(), connection);
  204. } else {
  205. // validUseIndex(joinTable, ((Column) rightExpression).getColumnName(), connection);
  206. // validUseIndex(table, ((Column) leftExpression).getColumnName(), connection);
  207. }
  208. } else {
  209. //获得列名
  210. // validUseIndex(table, ((Column) leftExpression).getColumnName(), connection);
  211. }
  212. }
  213. // 如果BinaryExpression,进行迭代
  214. else if (leftExpression instanceof BinaryExpression) {
  215. validWhere(leftExpression, table, joinTable, connection);
  216. }
  217. //获得右边表达式,并分解
  218. Expression rightExpression = ((BinaryExpression) expression).getRightExpression();
  219. validExpression(rightExpression);
  220. }
  221. }
  222. /**
  223. * 得到表的索引信息
  224. *
  225. * @param dbName ignore
  226. * @param tableName ignore
  227. * @param conn ignore
  228. * @return ignore
  229. */
  230. public List<IndexInfo> getIndexInfos(String dbName, String tableName, Connection conn) {
  231. return getIndexInfos(null, dbName, tableName, conn);
  232. }
  233. /**
  234. * 得到表的索引信息
  235. *
  236. * @param key ignore
  237. * @param dbName ignore
  238. * @param tableName ignore
  239. * @param conn ignore
  240. * @return ignore
  241. */
  242. public List<IndexInfo> getIndexInfos(String key, String dbName, String tableName, Connection conn) {
  243. List<IndexInfo> indexInfos = null;
  244. if (StringUtils.isNotBlank(key)) {
  245. indexInfos = INDEX_INFO_MAP.get(key);
  246. }
  247. if (indexInfos == null || indexInfos.isEmpty()) {
  248. ResultSet rs;
  249. try {
  250. DatabaseMetaData metadata = conn.getMetaData();
  251. String catalog = StringUtils.isBlank(dbName) ? conn.getCatalog() : dbName;
  252. String schema = StringUtils.isBlank(dbName) ? conn.getSchema() : dbName;
  253. rs = metadata.getIndexInfo(catalog, schema, tableName, false, true);
  254. indexInfos = new ArrayList<>();
  255. while (rs.next()) {
  256. //索引中的列序列号等于1,才有效
  257. if (Objects.equals(rs.getString(8), "1")) {
  258. IndexInfo indexInfo = new IndexInfo();
  259. indexInfo.setDbName(rs.getString(1));
  260. indexInfo.setTableName(rs.getString(3));
  261. indexInfo.setColumnName(rs.getString(9));
  262. indexInfos.add(indexInfo);
  263. }
  264. }
  265. if (StringUtils.isNotBlank(key)) {
  266. INDEX_INFO_MAP.put(key, indexInfos);
  267. }
  268. } catch (SQLException e) {
  269. e.printStackTrace();
  270. }
  271. }
  272. return indexInfos;
  273. }
  274. /**
  275. * 索引对象
  276. */
  277. @Data
  278. private static class IndexInfo {
  279. /**
  280. * dbName
  281. */
  282. private String dbName;
  283. /**
  284. * tableName
  285. */
  286. private String tableName;
  287. /**
  288. * columnName
  289. */
  290. private String columnName;
  291. }
  292. }

字段自动填充

  1. public class MyMetaObjectHandler implements MetaObjectHandler {
  2. @Override
  3. public void insertFill(MetaObject metaObject) {
  4. // 判断添加/更新的时候是否给他赋值
  5. Object createDate = getFieldValByName("createDate", metaObject);
  6. LocalDateTime now = LocalDateTime.now();
  7. if (createDate == null) {
  8. this.strictInsertFill(metaObject, "createDate", LocalDateTime.class, now);
  9. }
  10. Object updateDate = getFieldValByName("updateDate", metaObject);
  11. if (updateDate == null) {
  12. this.strictInsertFill(metaObject, "updateDate", LocalDateTime.class, now);
  13. }
  14. Object invalid = getFieldValByName("invalid", metaObject);
  15. if (invalid == null) {
  16. this.strictInsertFill(metaObject, "invalid", Integer.class, 0);
  17. }
  18. this.strictInsertFill(metaObject, "del", Long.class, 0L);
  19. }
  20. @Override
  21. public void updateFill(MetaObject metaObject) {
  22. Object updateDate = getFieldValByName("updateDate", metaObject);
  23. if (updateDate == null) {
  24. this.strictUpdateFill(metaObject, "updateDate", LocalDateTime.class, LocalDateTime.now());
  25. }
  26. }

添加分库键

  /**
   * SQL templates for the project's custom injected mapper methods.
   * Each constant carries the mapper method name, a short description, and a
   * {@code String.format} template.
   */
  public enum CustomSqlMethod {

      /** Select one row by primary key plus shard (distribution) key. */
      SELECT_BY_ID_WITH_DISTRIBUTE_KEY(
              "selectByIdWithDistributeKey",
              "根据ID和分库键查询一条数据",
              "SELECT %s FROM %s WHERE %s=#{%s} %s %s"),

      /** Update one row by primary key plus shard key; the template is a MyBatis script. */
      UPDATE_BY_ID_WITH_DISTRIBUTE_KEY(
              "updateByIdWithDistributeKey",
              "根据ID 选择修改数据",
              "<script>\nUPDATE %s %s WHERE %s=#{%s} %s\n</script>");

      private final String method;
      private final String desc;
      private final String sql;

      CustomSqlMethod(String methodName, String description, String sqlTemplate) {
          method = methodName;
          desc = description;
          sql = sqlTemplate;
      }

      /** @return the mapper method name this template backs */
      public String getMethod() {
          return method;
      }

      /** @return a short human-readable description */
      public String getDesc() {
          return desc;
      }

      /** @return the {@code String.format} SQL template */
      public String getSql() {
          return sql;
      }
  }
  26. /**
  27. * UpdateByIdWithDistributeKey简介
  28. *
  29. * @author jialei25
  30. * @date 2021-12-09 09:38
  31. */
  32. public class UpdateByIdWithDistributeKey extends AbstractMethod {
  33. public UpdateByIdWithDistributeKey() {
  34. }
  35. @Override
  36. public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {
  37. CustomSqlMethod sqlMethod = CustomSqlMethod.UPDATE_BY_ID_WITH_DISTRIBUTE_KEY;
  38. String additional = this.optlockVersion(tableInfo) + tableInfo.getLogicDeleteSql(true, true);
  39. String distributionCodeWhereSql = "";
  40. for (TableFieldInfo tableFieldInfo : tableInfo.getFieldList()) {
  41. if ("distribution_code".equals(tableFieldInfo.getColumn())) {
  42. distributionCodeWhereSql = "distribution_code=#{distributeKey}";
  43. }
  44. }
  45. if (!additional.isEmpty()) {
  46. distributionCodeWhereSql = additional + " AND " + distributionCodeWhereSql;
  47. }
  48. String sql = String.format(sqlMethod.getSql(),
  49. tableInfo.getTableName(),
  50. this.sqlSet(tableInfo.isWithLogicDelete(), false, tableInfo, false, "et", "et."),
  51. tableInfo.getKeyColumn(),
  52. tableInfo.getKeyProperty(),
  53. distributionCodeWhereSql);
  54. SqlSource sqlSource = this.languageDriver.createSqlSource(this.configuration, sql, modelClass);
  55. return this.addUpdateMappedStatement(mapperClass, modelClass, sqlMethod.getMethod(), sqlSource);
  56. }
  57. }

自定义 SQL 注入

  1. /**
  2. * 自定义Sql注入
  3. *
  4. * @author nieqiurong 2018/8/11 20:23.
  5. */
  6. public class MySqlInjector extends AbstractSqlInjector {
  7. @Override
  8. public List<AbstractMethod> getMethodList(Class<?> mapperClass, TableInfo tableInfo) {
  9. return Stream.of(new Insert(), new Delete(), new DeleteByMap(), new DeleteById(), new DeleteBatchByIds(), new Update(),
  10. new UpdateById(), new UpdateByIdWithDistributeKey(), new SelectById(), new SelectByIdWithDistributeKey(), new SelectBatchByIds(), new SelectByMap(), new SelectOne(), new SelectCount(),
  11. new SelectMaps(), new SelectMapsPage(), new SelectObjs(), new SelectList(), new SelectPage(), new LogicDeleteByIdWithFill())
  12. .collect(Collectors.toList());
  13. }
  14. }

根据分库键删除

  1. package com.jd.irt.store.core.mybatis;
  2. import com.baomidou.mybatisplus.core.mapper.BaseMapper;
  3. import org.apache.ibatis.annotations.Param;
  4. import java.io.Serializable;
  5. /**
  6. * MyBaseMapper <br>
  7. *
  8. * @author jialei25 <br>
  9. * @date 2021-09-27 21:03 <br>
  10. */
  11. public interface MyBaseMapper<T> extends BaseMapper<T> {
  12. /**
  13. * 删除
  14. * @param entity
  15. * @return
  16. */
  17. int deleteByIdWithFill(T entity);
  18. /**
  19. * 根据ID和分库键获取对象
  20. * @param id
  21. * @param distributeKey
  22. * @return
  23. */
  24. T selectByIdWithDistributeKey(@Param("id") Serializable id, @Param("distributeKey") Serializable distributeKey);
  25. /**
  26. * 根据ID分库键更新
  27. * @param entity
  28. * @param id
  29. * @param distributeKey
  30. * @return
  31. */
  32. int updateByIdWithDistributeKey(@Param("et") T entity, @Param("id") Serializable id, @Param("distributeKey") Serializable distributeKey);
  33. }
  34. /*
  35. * Copyright (c) 2011-2021, baomidou (jobob@qq.com).
  36. *
  37. /**
  38. * 根据 id 逻辑删除数据,并带字段填充功能
  39. * <p>注意入参是 entity !!! ,如果字段没有自动填充,就只是单纯的逻辑删除</p>
  40. * <p>
  41. * 自己的通用 mapper 如下使用:
  42. * <pre>
  43. * int deleteByIdWithFill(T entity);
  44. * </pre>
  45. * </p>
  46. *
  47. * @author miemie
  48. * @since 2018-11-09
  49. */
  50. @SuppressWarnings("serial")
  51. public class LogicDeleteByIdWithFill extends AbstractMethod {
  52. @Override
  53. public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {
  54. String sql;
  55. SqlMethod sqlMethod = SqlMethod.LOGIC_DELETE_BY_ID;
  56. String distributionCodeWhereSql = "";
  57. String soleCodeSetSql = "";
  58. List<TableFieldInfo> updateFieldInfoList = new ArrayList<>();
  59. for (TableFieldInfo tableFieldInfo : tableInfo.getFieldList()) {
  60. if (tableFieldInfo.isWithUpdateFill()) {
  61. updateFieldInfoList.add(tableFieldInfo);
  62. }
  63. if ("distribution_code".equals(tableFieldInfo.getColumn())) {
  64. distributionCodeWhereSql = " and distribution_code = #{distributionCode}";
  65. }
  66. if ("sole_code".equals(tableFieldInfo.getColumn())) {
  67. soleCodeSetSql = "sole_code = uuid(), ";
  68. }
  69. }
  70. if (tableInfo.isWithLogicDelete()) {
  71. if (CollectionUtils.isNotEmpty(updateFieldInfoList)) {
  72. String sqlSet = "SET " + soleCodeSetSql + updateFieldInfoList.stream().map(i -> i.getSqlSet(EMPTY)).collect(joining(EMPTY))
  73. + tableInfo.getLogicDeleteSql(false, false);
  74. sql = String.format(sqlMethod.getSql(), tableInfo.getTableName(), sqlSet, tableInfo.getKeyColumn(),
  75. tableInfo.getKeyProperty(), distributionCodeWhereSql + tableInfo.getLogicDeleteSql(true, true));
  76. } else {
  77. String sqlSet = "SET " + soleCodeSetSql
  78. + tableInfo.getLogicDeleteSql(false, false);
  79. sql = String.format(sqlMethod.getSql(), tableInfo.getTableName(), sqlSet,
  80. tableInfo.getKeyColumn(), tableInfo.getKeyProperty(),
  81. distributionCodeWhereSql + tableInfo.getLogicDeleteSql(true, true));
  82. }
  83. } else {
  84. sqlMethod = SqlMethod.DELETE_BY_ID;
  85. sql = String.format("<script>\nDELETE FROM %s WHERE %s=#{%s} %s\n</script>", tableInfo.getTableName(), tableInfo.getKeyColumn(),
  86. tableInfo.getKeyProperty(), distributionCodeWhereSql);
  87. }
  88. SqlSource sqlSource = languageDriver.createSqlSource(configuration, sql, modelClass);
  89. return addUpdateMappedStatement(mapperClass, modelClass, getMethod(sqlMethod), sqlSource);
  90. }
  91. @Override
  92. public String getMethod(SqlMethod sqlMethod) {
  93. // 自定义 mapper 方法名
  94. return "deleteByIdWithFill";
  95. }
  96. }