其中包含以下几个要点：获取 MappedStatement、获取 BoundSql，以及重置 MappedStatement。

    1. /*
    2. * Copyright (c) 2018-2028, Chill Zhuang All rights reserved.
    3. *
    4. * Redistribution and use in source and binary forms, with or without
    5. * modification, are permitted provided that the following conditions are met:
    6. *
    7. * Redistributions of source code must retain the above copyright notice,
    8. * this list of conditions and the following disclaimer.
    9. * Redistributions in binary form must reproduce the above copyright
    10. * notice, this list of conditions and the following disclaimer in the
    11. * documentation and/or other materials provided with the distribution.
    12. * Neither the name of the dreamlu.net developer nor the names of its
    13. * contributors may be used to endorse or promote products derived from
    14. * this software without specific prior written permission.
    15. * Author: Chill 庄骞 (smallchill@163.com)
    16. */
    17. package org.springblade.core.mp.plugins;
    18. import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
    19. import lombok.extern.slf4j.Slf4j;
    20. import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
    21. import net.sf.jsqlparser.parser.CCJSqlParserUtil;
    22. import net.sf.jsqlparser.schema.Column;
    23. import net.sf.jsqlparser.schema.Table;
    24. import net.sf.jsqlparser.statement.Statement;
    25. import net.sf.jsqlparser.statement.select.Join;
    26. import net.sf.jsqlparser.statement.select.PlainSelect;
    27. import net.sf.jsqlparser.statement.select.Select;
    28. import net.sf.jsqlparser.statement.select.SelectBody;
    29. import org.apache.ibatis.cache.CacheKey;
    30. import org.apache.ibatis.executor.Executor;
    31. import org.apache.ibatis.mapping.BoundSql;
    32. import org.apache.ibatis.mapping.MappedStatement;
    33. import org.apache.ibatis.mapping.SqlCommandType;
    34. import org.apache.ibatis.mapping.SqlSource;
    35. import org.apache.ibatis.plugin.Interceptor;
    36. import org.apache.ibatis.plugin.Intercepts;
    37. import org.apache.ibatis.plugin.Invocation;
    38. import org.apache.ibatis.plugin.Signature;
    39. import org.apache.ibatis.reflection.DefaultReflectorFactory;
    40. import org.apache.ibatis.reflection.MetaObject;
    41. import org.apache.ibatis.reflection.factory.DefaultObjectFactory;
    42. import org.apache.ibatis.reflection.wrapper.DefaultObjectWrapperFactory;
    43. import org.apache.ibatis.session.ResultHandler;
    44. import org.apache.ibatis.session.RowBounds;
    45. import org.springblade.core.mp.annotation.SqlAssets;
    46. import org.springframework.stereotype.Component;
    47. import java.lang.reflect.Field;
    48. import java.lang.reflect.Method;
    49. import java.sql.SQLException;
    50. import java.util.Collections;
    51. import java.util.Properties;
    52. /**
    53. * 租户拦截器
    54. *
    55. * @author Chill
    56. */
    57. @Slf4j
    58. @Intercepts({
    59. @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}),
    60. @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class})
    61. })
    62. @Component
    63. public class CdosAssetInterceptor extends JsqlParserSupport implements Interceptor {
    64. @Override
    65. public Object intercept(Invocation invocation) throws Throwable {
    66. // 获取拦截的参数
    67. Object[] args = invocation.getArgs();
    68. MappedStatement mappedStatement = (MappedStatement) args[0];
    69. // 如果不是select方法,则继续执行
    70. if (!SqlCommandType.SELECT.equals(mappedStatement.getSqlCommandType())) {
    71. return invocation.proceed();
    72. }
    73. // 获取执行方法的包名和方法名
    74. String namespace = mappedStatement.getId();
    75. String className = namespace.substring(0, namespace.lastIndexOf("."));
    76. String methodName = namespace.substring(namespace.lastIndexOf(".") + 1, namespace.length());
    77. Method[] methods = Class.forName(className).getMethods();
    78. SqlAssets annotation = null;
    79. for (Method method : methods) {
    80. if (method.getName().equals(methodName)) {
    81. //获取注解 来判断是不是要储存sql
    82. annotation = method.getAnnotation(SqlAssets.class);
    83. break;
    84. }
    85. }
    86. if (annotation != null) {
    87. // 获得关联的参数
    88. String tableName = annotation.tableName();
    89. String fieldName = annotation.fieldName();
    90. // 获取调用方法的参数列表
    91. Object parameter = null;
    92. if (args.length > 1) {
    93. parameter = args[1];
    94. }
    95. BoundSql boundSql = mappedStatement.getBoundSql(parameter);
    96. // PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
    97. String originalSql = boundSql.getSql();
    98. Object parameterObject = boundSql.getParameterObject();
    99. // mpBs.sql(parserSingle(mpBs.sql(), null));
    100. log.info("");
    101. if (parameterObject != null) {
    102. // 解析Sql
    103. Statement statement = CCJSqlParserUtil.parse(originalSql);
    104. SelectBody selectBody = ((Select) statement).getSelectBody();
    105. PlainSelect plainSelect = (PlainSelect) selectBody;
    106. // todo: 字段冲突,新加的tb_asset表如果和旧表字段冲突怎么办
    107. // 修改joins
    108. Table assetTable = new Table("tb_asset");
    109. Table oldTable = new Table(tableName);
    110. Join join = new Join();
    111. join.setInner(true);
    112. join.setRightItem(assetTable);
    113. EqualsTo equalsTo = new EqualsTo();
    114. equalsTo.setLeftExpression(new Column(assetTable, "id"));
    115. equalsTo.setRightExpression(new Column(oldTable, fieldName));
    116. join.setOnExpression(equalsTo);
    117. if (plainSelect.getJoins() != null) {
    118. plainSelect.getJoins().add(join);
    119. } else {
    120. plainSelect.setJoins(Collections.singletonList(join));
    121. }
    122. String newSql = statement.toString();
    123. resetSql2Invocation(invocation, newSql);
    124. }
    125. }
    126. // List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
    127. // Class<?>[] parameterClasses = parameterMappings.parallelStream().map(ParameterMapping::getJavaType).toArray(Class<?>[]::new);
    128. // // 获取方法
    129. // Class<?> aClass = Class.forName(className);
    130. // Method method = aClass.getMethod(methodName, parameterClasses);
    131. return invocation.proceed();
    132. }
    133. @Override
    134. public Object plugin(Object target) {
    135. return Interceptor.super.plugin(target);
    136. }
    137. @Override
    138. public void setProperties(Properties properties) {
    139. Interceptor.super.setProperties(properties);
    140. }
    141. /**
    142. * 包装sql后,重置到invocation中
    143. *
    144. * @param invocation
    145. * @param sql
    146. * @throws SQLException
    147. */
    148. private void resetSql2Invocation(Invocation invocation, String sql) throws SQLException {
    149. final Object[] args = invocation.getArgs();
    150. MappedStatement statement = (MappedStatement) args[0];
    151. Object parameterObject = args[1];
    152. BoundSql boundSql = statement.getBoundSql(parameterObject);
    153. MappedStatement newStatement = newMappedStatement(statement, new BoundSqlSqlSource(boundSql));
    154. MetaObject msObject = MetaObject.forObject(newStatement, new DefaultObjectFactory(), new DefaultObjectWrapperFactory(), new DefaultReflectorFactory());
    155. msObject.setValue("sqlSource.boundSql.sql", sql);
    156. args[0] = newStatement;
    157. }
    158. private MappedStatement newMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
    159. MappedStatement.Builder builder =
    160. new MappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());
    161. builder.resource(ms.getResource());
    162. builder.fetchSize(ms.getFetchSize());
    163. builder.statementType(ms.getStatementType());
    164. builder.keyGenerator(ms.getKeyGenerator());
    165. if (ms.getKeyProperties() != null && ms.getKeyProperties().length != 0) {
    166. StringBuilder keyProperties = new StringBuilder();
    167. for (String keyProperty : ms.getKeyProperties()) {
    168. keyProperties.append(keyProperty).append(",");
    169. }
    170. keyProperties.delete(keyProperties.length() - 1, keyProperties.length());
    171. builder.keyProperty(keyProperties.toString());
    172. }
    173. builder.timeout(ms.getTimeout());
    174. builder.parameterMap(ms.getParameterMap());
    175. builder.resultMaps(ms.getResultMaps());
    176. builder.resultSetType(ms.getResultSetType());
    177. builder.cache(ms.getCache());
    178. builder.flushCacheRequired(ms.isFlushCacheRequired());
    179. builder.useCache(ms.isUseCache());
    180. return builder.build();
    181. }
    182. // 定义一个内部辅助类,作用是包装sql
    183. class BoundSqlSqlSource implements SqlSource {
    184. private BoundSql boundSql;
    185. public BoundSqlSqlSource(BoundSql boundSql) {
    186. this.boundSql = boundSql;
    187. }
    188. @Override
    189. public BoundSql getBoundSql(Object parameterObject) {
    190. return boundSql;
    191. }
    192. }
    193. }