简介

利用 MyBatis Plugin 插件技术实现分页功能。

分页插件实现思路如下:

  • 业务代码在 ThreadLocal 中保存分页信息;
  • MyBatis Interceptor 拦截查询请求,获取分页信息,实现分页操作,封装分页列表数据返回;

测试类:com.yjw.demo.PageTest

插件开发过程

确定需要拦截的签名

MyBatis 插件可以拦截四大对象中的任意一个,从 Plugin 源码中可以看到它需要注册签名才能够运行插件,签名需要确定一些要素。

确定需要拦截的对象

  • Executor 是执行 SQL 的全过程,包括组装参数,组装结果集返回和执行 SQL 过程,都可以拦截。
  • StatementHandler 是执行 SQL 的过程,我们可以重写执行 SQL 的过程。
  • ParameterHandler 是拦截执行 SQL 的参数组装,我们可以重写组装参数规则。
  • ResultSetHandler 用于拦截执行结果的组装,我们可以重写组装结果的规则。

拦截方法和参数

当确定了需要拦截什么对象,接下来就要确定需要拦截什么方法和方法的参数。比如分页插件需要拦截 Executor 的 query 方法,我们先看看 Executor 接口的定义,代码清单如下:

  1. public interface Executor {
  2. ResultHandler NO_RESULT_HANDLER = null;
  3. int update(MappedStatement ms, Object parameter) throws SQLException;
  4. <E> List<E> query(MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, CacheKey cacheKey, BoundSql boundSql) throws SQLException;
  5. <E> List<E> query(MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler) throws SQLException;
  6. <E> Cursor<E> queryCursor(MappedStatement ms, Object parameter, RowBounds rowBounds) throws SQLException;
  7. List<BatchResult> flushStatements() throws SQLException;
  8. void commit(boolean required) throws SQLException;
  9. void rollback(boolean required) throws SQLException;
  10. CacheKey createCacheKey(MappedStatement ms, Object parameterObject, RowBounds rowBounds, BoundSql boundSql);
  11. boolean isCached(MappedStatement ms, CacheKey key);
  12. void clearLocalCache();
  13. void deferLoad(MappedStatement ms, MetaObject resultObject, String property, CacheKey key, Class<?> targetType);
  14. Transaction getTransaction();
  15. void close(boolean forceRollback);
  16. boolean isClosed();
  17. void setExecutorWrapper(Executor executor);
  18. }

以上的任何方法都可以拦截,从接口定义而言,query 方法有两个,我们可以按照代码清单来定义签名。

  1. @Intercepts({
  2. @Signature(type = Executor.class, method = "query",
  3. args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
  4. @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class,
  5. ResultHandler.class, CacheKey.class, BoundSql.class})})

其中,@Intercepts 说明它是一个拦截器。@Signature 是注册拦截器签名的地方,type 是四大对象中的一个,method 是需要拦截的方法,args 是方法的参数。

插件接口定义

在 MyBatis 中开发插件,需要实现 Interceptor 接口,接口的定义如下:

  1. public interface Interceptor {
  2. Object intercept(Invocation invocation) throws Throwable;
  3. Object plugin(Object target);
  4. void setProperties(Properties properties);
  5. }
  • intercept 方法:它将直接覆盖你所拦截对象原有的方法,因此它是插件的核心方法。通过 invocation 参数可以反射调度原来对象的方法。
  • plugin 方法:target 是被拦截对象,它的作用是给被拦截对象生成一个代理对象,并返回它。为了方便 MyBatis 使用 org.apache.ibatis.plugin.Plugin 中的 wrap 静态方法提供生成代理对象。
  • setProperties 方法:允许在 plugin 元素中配置所需参数,方法在插件初始化的时候就被调用了一次,然后把插件对象存入到配置中,以便后面再取出。

实现类

根据分页插件的实现思路,定义了三个类。

Page 类

Page 类继承了 ArrayList 类,用来封装分页信息和列表数据。

  1. /**
  2. * 分页返回对象
  3. *
  4. * @author yinjianwei
  5. * @date 2018/11/05
  6. */
  7. public class Page<E> extends ArrayList<E> {
  8. private static final long serialVersionUID = 1L;
  9. /**
  10. * 页码,从1开始
  11. */
  12. private int pageNum;
  13. /**
  14. * 页面大小
  15. */
  16. private int pageSize;
  17. /**
  18. * 起始行
  19. */
  20. private int startRow;
  21. /**
  22. * 末行
  23. */
  24. private int endRow;
  25. /**
  26. * 总数
  27. */
  28. private long total;
  29. /**
  30. * 总页数
  31. */
  32. private int pages;
  33. public int getPageNum() {
  34. return pageNum;
  35. }
  36. public void setPageNum(int pageNum) {
  37. this.pageNum = pageNum;
  38. }
  39. public int getPageSize() {
  40. return pageSize;
  41. }
  42. public void setPageSize(int pageSize) {
  43. this.pageSize = pageSize;
  44. }
  45. public int getStartRow() {
  46. return startRow;
  47. }
  48. public void setStartRow(int startRow) {
  49. this.startRow = startRow;
  50. }
  51. public int getEndRow() {
  52. return endRow;
  53. }
  54. public void setEndRow(int endRow) {
  55. this.endRow = endRow;
  56. }
  57. public long getTotal() {
  58. return total;
  59. }
  60. public void setTotal(long total) {
  61. this.total = total;
  62. this.pages = (int)(total / pageSize + (total % pageSize == 0 ? 0 : 1));
  63. if (pageNum > pages) {
  64. pageNum = pages;
  65. }
  66. this.startRow = this.pageNum > 0 ? (this.pageNum - 1) * this.pageSize : 0;
  67. this.endRow = this.startRow + this.pageSize * (this.pageNum > 0 ? 1 : 0);
  68. }
  69. public int getPages() {
  70. return pages;
  71. }
  72. public void setPages(int pages) {
  73. this.pages = pages;
  74. }
  75. /**
  76. * 返回当前对象
  77. *
  78. * @return
  79. */
  80. public List<E> getResult() {
  81. return this;
  82. }
  83. }

PageHelper 类

PageHelper 类是分页的帮助类,主要利用 ThreadLocal 线程变量存储分页信息。代码清单如下:

  1. /**
  2. * 分页帮助类
  3. *
  4. * @author yinjianwei
  5. * @date 2018/11/05
  6. */
  7. @SuppressWarnings("rawtypes")
  8. public class PageHelper {
  9. private static final ThreadLocal<Page> PAGE_THREADLOCAT = new ThreadLocal<Page>();
  10. /**
  11. * 设置线程局部变量分页信息
  12. *
  13. * @param page
  14. */
  15. public static void setPageThreadLocal(Page page) {
  16. PAGE_THREADLOCAT.set(page);
  17. }
  18. /**
  19. * 获取线程局部变量分页信息
  20. *
  21. * @return
  22. */
  23. public static Page getPageThreadLocal() {
  24. return PAGE_THREADLOCAT.get();
  25. }
  26. /**
  27. * 清空线程局部变量分页信息
  28. */
  29. public static void pageThreadLocalClear() {
  30. PAGE_THREADLOCAT.remove();
  31. }
  32. /**
  33. * 设置分页参数
  34. *
  35. * @param pageNum
  36. * @param pageSize
  37. */
  38. public static void startPage(Integer pageNum, Integer pageSize) {
  39. Page page = new Page();
  40. page.setPageNum(pageNum);
  41. page.setPageSize(pageSize);
  42. setPageThreadLocal(page);
  43. }
  44. }

PageInterceptor 类

PageInterceptor 类实现了 Interceptor 接口,是分页插件的核心类。代码清单如下:

  1. /**
  2. * 分页拦截器
  3. *
  4. * @author yinjianwei
  5. * @date 2018/11/05
  6. */
  7. @Intercepts({
  8. @Signature(type = Executor.class, method = "query",
  9. args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
  10. @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class,
  11. ResultHandler.class, CacheKey.class, BoundSql.class})})
  12. public class PageInterceptor implements Interceptor {
  13. private Field additionalParametersField;
  14. @SuppressWarnings({"rawtypes", "unchecked"})
  15. @Override
  16. public Object intercept(Invocation invocation) throws Throwable {
  17. Executor executor = (Executor)invocation.getTarget();
  18. Object[] args = invocation.getArgs();
  19. MappedStatement ms = (MappedStatement)args[0];
  20. Object parameter = args[1];
  21. RowBounds rowBounds = (RowBounds)args[2];
  22. ResultHandler resultHandler = (ResultHandler)args[3];
  23. CacheKey cacheKey;
  24. BoundSql boundSql;
  25. // 4个参数
  26. if (args.length == 4) {
  27. boundSql = ms.getBoundSql(parameter);
  28. cacheKey = executor.createCacheKey(ms, parameter, rowBounds, boundSql);
  29. }
  30. // 6个参数
  31. else {
  32. cacheKey = (CacheKey)args[4];
  33. boundSql = (BoundSql)args[5];
  34. }
  35. // 判断是否需要分页
  36. Page page = PageHelper.getPageThreadLocal();
  37. // 不执行分页
  38. if (page.getPageNum() <= 0) {
  39. return executor.query(ms, parameter, rowBounds, resultHandler, cacheKey, boundSql);
  40. }
  41. // count查询
  42. MappedStatement countMs = newCountMappedStatement(ms);
  43. String sql = boundSql.getSql();
  44. String countSql = "select count(1) from (" + sql + ") _count";
  45. BoundSql countBoundSql =
  46. new BoundSql(ms.getConfiguration(), countSql, boundSql.getParameterMappings(), parameter);
  47. Map<String, Object> additionalParameters = (Map<String, Object>)additionalParametersField.get(boundSql);
  48. for (Entry<String, Object> additionalParameter : additionalParameters.entrySet()) {
  49. countBoundSql.setAdditionalParameter(additionalParameter.getKey(), additionalParameter.getValue());
  50. }
  51. CacheKey countCacheKey = executor.createCacheKey(countMs, parameter, rowBounds, countBoundSql);
  52. Object countResult =
  53. executor.query(countMs, parameter, RowBounds.DEFAULT, resultHandler, countCacheKey, countBoundSql);
  54. Long count = (Long)((List)countResult).get(0);
  55. page.setTotal(count);
  56. // 分页查询
  57. String pageSql = sql + " limit " + page.getStartRow() + "," + page.getPageSize();
  58. BoundSql pageBoundSql =
  59. new BoundSql(ms.getConfiguration(), pageSql, boundSql.getParameterMappings(), parameter);
  60. for (Entry<String, Object> additionalParameter : additionalParameters.entrySet()) {
  61. pageBoundSql.setAdditionalParameter(additionalParameter.getKey(), additionalParameter.getValue());
  62. }
  63. CacheKey pageCacheKey = executor.createCacheKey(ms, parameter, rowBounds, pageBoundSql);
  64. List listResult = executor.query(ms, parameter, RowBounds.DEFAULT, resultHandler, pageCacheKey, pageBoundSql);
  65. page.addAll(listResult);
  66. // 清空线程局部变量分页信息
  67. PageHelper.pageThreadLocalClear();
  68. return page;
  69. }
  70. @Override
  71. public Object plugin(Object target) {
  72. return Plugin.wrap(target, this);
  73. }
  74. @Override
  75. public void setProperties(Properties properties) {
  76. try {
  77. additionalParametersField = BoundSql.class.getDeclaredField("additionalParameters");
  78. additionalParametersField.setAccessible(true);
  79. } catch (NoSuchFieldException | SecurityException e) {
  80. e.printStackTrace();
  81. }
  82. }
  83. /**
  84. * 创建count的MappedStatement
  85. *
  86. * @param ms
  87. * @return
  88. */
  89. private MappedStatement newCountMappedStatement(MappedStatement ms) {
  90. MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(), ms.getId() + "_count",
  91. ms.getSqlSource(), ms.getSqlCommandType());
  92. builder.resource(ms.getResource());
  93. builder.fetchSize(ms.getFetchSize());
  94. builder.statementType(ms.getStatementType());
  95. builder.keyGenerator(ms.getKeyGenerator());
  96. if (ms.getKeyProperties() != null && ms.getKeyProperties().length != 0) {
  97. StringBuilder keyProperties = new StringBuilder();
  98. for (String keyProperty : ms.getKeyProperties()) {
  99. keyProperties.append(keyProperty).append(",");
  100. }
  101. keyProperties.delete(keyProperties.length() - 1, keyProperties.length());
  102. builder.keyProperty(keyProperties.toString());
  103. }
  104. builder.timeout(ms.getTimeout());
  105. builder.parameterMap(ms.getParameterMap());
  106. // count查询返回值int
  107. List<ResultMap> resultMaps = new ArrayList<ResultMap>();
  108. ResultMap resultMap = new ResultMap.Builder(ms.getConfiguration(), ms.getId() + "_count", Long.class,
  109. new ArrayList<ResultMapping>(0)).build();
  110. resultMaps.add(resultMap);
  111. builder.resultMaps(resultMaps);
  112. builder.resultSetType(ms.getResultSetType());
  113. builder.cache(ms.getCache());
  114. builder.flushCacheRequired(ms.isFlushCacheRequired());
  115. builder.useCache(ms.isUseCache());
  116. return builder.build();
  117. }
  118. }

配置

MyBatis 配置文件增加 plugin 配置项。

  1. <?xml version="1.0" encoding="UTF-8" ?>
  2. <!DOCTYPE configuration PUBLIC "-//mybatis.org//DTD Config 3.0//EN"
  3. "http://mybatis.org/dtd/mybatis-3-config.dtd">
  4. <configuration>
  5. <settings>
  6. <setting name="lazyLoadingEnabled" value="true"/>
  7. <setting name="aggressiveLazyLoading" value="false"/>
  8. </settings>
  9. <typeHandlers>
  10. <typeHandler javaType="com.yjw.demo.mybatis.common.constant.Sex"
  11. jdbcType="TINYINT"
  12. handler="com.yjw.demo.mybatis.common.type.SexEnumTypeHandler"/>
  13. </typeHandlers>
  14. <plugins>
  15. <plugin interceptor="com.yjw.demo.mybatis.common.page.PageInterceptor">
  16. </plugin>
  17. </plugins>
  18. </configuration>

作者:殷建卫 链接:https://www.yuque.com/yinjianwei/vyrvkf/qh6cfu 来源:殷建卫 - 架构笔记 著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。