springboot分页查询并行优化实践

打印 上一主题 下一主题

主题 2009|帖子 2009|积分 6037

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

x
            ——基于异步优化与 MyBatis-Plus 分页插件思想的实践
适用场景


  • 数据量较大的单表分页查询
  • 较复杂的多表关联查询,包含group by等无法举行count优化较耗时的分页查询
技术栈


  • 核心框架:Spring Boot + MyBatis-Plus
  • 异步编程:JDK 8+ 的 CompletableFuture 
  • 数据库:MySQL 8.0
  • 线程池:自定义线程池管理并行任务(如 ThreadPoolTaskExecutor)
实现思绪

解决传统分页查询中 串行实行 COUNT 与数据查询 的性能瓶颈,通过 并行化 淘汰总耗时,同时兼容复杂查询场景(如多表关联、DISTINCT 等)
兼容mybatisPlus分页参数,复用 IPage 接口定义分页参数(当前页、每页条数),
借鉴 MyBatis-Plus 的 PaginationInnerInterceptor,通过实现 MyBatis 的 Interceptor 接口,拦截 Executor#query 方法,动态修改 SQL,
sql优化适配:COUNT 优化:自动移除 ORDER BY,保存 GROUP BY 和 DISTINCT(需包裹子查询),数据查询:保存完整 SQL 逻辑,仅追加 LIMIT 和 OFFSET。
直接上代码

使用简单
调用查询方法前赋值page对象属性total大于0数值则可进入自定义分页查询方案。
  1. //示例代码
  2. Page<User> page = new Page<>(1,10);
  3. page.setTotal(1L);
复制代码
线程池配置
  1. @Configuration
  2. public class ThreadPoolTaskExecutorConfig {
  3.     public static final Integer CORE_POOL_SIZE = 20;
  4.     public static final Integer MAX_POOL_SIZE = 40;
  5.     public static final Integer QUEUE_CAPACITY = 200;
  6.     public static final Integer KEEP_ALIVE_SECONDS = 60;
  7.     @Bean("threadPoolTaskExecutor")
  8.     public ThreadPoolTaskExecutor getThreadPoolTaskExecutor() {
  9.         ThreadPoolTaskExecutor threadPoolTaskExecutor = new ThreadPoolTaskExecutor();
  10.         //核心线程数
  11.         threadPoolTaskExecutor.setCorePoolSize(CORE_POOL_SIZE);
  12.         //线程池最大线程数
  13.         threadPoolTaskExecutor.setMaxPoolSize(MAX_POOL_SIZE);
  14.         //队列容量
  15.         threadPoolTaskExecutor.setQueueCapacity(QUEUE_CAPACITY);
  16.         //线程空闲存活时间
  17.         threadPoolTaskExecutor.setKeepAliveSeconds(KEEP_ALIVE_SECONDS);
  18.         //线程前缀
  19.         threadPoolTaskExecutor.setThreadNamePrefix("commonTask-");
  20.         //拒绝策略
  21.         threadPoolTaskExecutor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
  22.         //线程池初始化
  23.         threadPoolTaskExecutor.initialize();
  24.         return threadPoolTaskExecutor;
  25.     }
  26.     @Bean("countAsyncThreadPool")
  27.     public ThreadPoolTaskExecutor getCountAsyncThreadPool() {
  28.         ThreadPoolTaskExecutor threadPoolTaskExecutor = new ThreadPoolTaskExecutor();
  29.         //核心线程数,根据负载动态调整
  30.         threadPoolTaskExecutor.setCorePoolSize(6);
  31.         //线程池最大线程数,根据负载动态调整
  32.         threadPoolTaskExecutor.setMaxPoolSize(12);
  33.         //队列容量  队列容量不宜过多,根据负载动态调整
  34.         threadPoolTaskExecutor.setQueueCapacity(2);
  35.         //线程空闲存活时间
  36.         threadPoolTaskExecutor.setKeepAliveSeconds(KEEP_ALIVE_SECONDS);
  37.         //线程前缀
  38.         threadPoolTaskExecutor.setThreadNamePrefix("countAsync-");
  39.         //拒绝策略  队列满时由调用者主线程执行
  40.         threadPoolTaskExecutor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
  41.         //线程池初始化
  42.         threadPoolTaskExecutor.initialize();
  43.         return threadPoolTaskExecutor;
  44.     }
  45. }
复制代码
mybatis-plus配置类
  1. @Configuration
  2. @MapperScan("com.xxx.mapper")
  3. public class MybatisPlusConfig {
  4.     @Resource
  5.     ThreadPoolTaskExecutor countAsyncThreadPool;
  6.     @Resource
  7.     ApplicationContext applicationContext;
  8.     @Bean
  9.     public MybatisPlusInterceptor mybatisPlusInterceptor() {
  10.         MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();
  11.         interceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.MYSQL));
  12.         return interceptor;
  13.     }
  14.     @Bean
  15.     public PageParallelQueryInterceptor pageParallelQueryInterceptor() {
  16.         PageParallelQueryInterceptor pageParallelQueryInterceptor = new PageParallelQueryInterceptor();
  17.         pageParallelQueryInterceptor.setCountAsyncThreadPool(countAsyncThreadPool);
  18.         pageParallelQueryInterceptor.setApplicationContext(applicationContext);
  19.         return pageParallelQueryInterceptor;
  20.     }
  21. }
复制代码
自定义mybatis拦截器
  1. package com.example.dlock_demo.interceptor;
  2. import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
  3. import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
  4. import lombok.extern.slf4j.Slf4j;
  5. import net.sf.jsqlparser.JSQLParserException;
  6. import net.sf.jsqlparser.expression.Expression;
  7. import net.sf.jsqlparser.parser.CCJSqlParserUtil;
  8. import net.sf.jsqlparser.statement.select.*;
  9. import org.apache.ibatis.builder.StaticSqlSource;
  10. import org.apache.ibatis.cache.CacheKey;
  11. import org.apache.ibatis.executor.Executor;
  12. import org.apache.ibatis.mapping.BoundSql;
  13. import org.apache.ibatis.mapping.MappedStatement;
  14. import org.apache.ibatis.mapping.ResultMap;
  15. import org.apache.ibatis.plugin.Interceptor;
  16. import org.apache.ibatis.plugin.Intercepts;
  17. import org.apache.ibatis.plugin.Invocation;
  18. import org.apache.ibatis.plugin.Signature;
  19. import org.apache.ibatis.session.*;
  20. import org.springframework.context.ApplicationContext;
  21. import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
  22. import java.lang.reflect.Field;
  23. import java.lang.reflect.Method;
  24. import java.sql.SQLException;
  25. import java.util.*;
  26. import java.util.concurrent.CompletableFuture;
  27. import java.util.concurrent.CompletionException;
  28. import java.util.concurrent.ConcurrentHashMap;
  29. /**
  30. * Mybatis-分页并行查询拦截器
  31. *
  32. * @author shf
  33. */
  34. @Intercepts({
  35.         @Signature(type = Executor.class, method = "query",
  36.                 args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
  37.         @Signature(type = Executor.class, method = "query",
  38.                 args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class})
  39. })
  40. @Slf4j
  41. public class PageParallelQueryInterceptor implements Interceptor {
  42.     /**
  43.      * 用于数据库并行查询线程池
  44.      */
  45.     private ThreadPoolTaskExecutor countAsyncThreadPool;
  46.     /**
  47.      * 容器上下文
  48.      */
  49.     private ApplicationContext applicationContext;
  50.     private static final String LONG_RESULT_MAP_ID = "twoPhase-Long-ResultMap";
  51.     private static final Map<String, MappedStatement> twoPhaseMsCache = new ConcurrentHashMap();
  52.     public void setCountAsyncThreadPool(ThreadPoolTaskExecutor countAsyncThreadPool) {
  53.         this.countAsyncThreadPool = countAsyncThreadPool;
  54.     }
  55.     public void setApplicationContext(ApplicationContext applicationContext) {
  56.         this.applicationContext = applicationContext;
  57.     }
  58.     @Override
  59.     public Object intercept(Invocation invocation) throws Throwable {
  60.         Object[] args = invocation.getArgs();
  61.         MappedStatement ms = (MappedStatement) args[0];
  62.         Object parameter = args[1];
  63.         //获取分页参数
  64.         Page<?> page = getPageParameter(parameter);
  65.         if (page == null || page.getSize() <= 0 || !page.searchCount() || page.getTotal() == 0) {
  66.             return invocation.proceed();
  67.         }
  68.         //获取Mapper方法(注解形式 需利用反射且只能应用在mapper接口层,不推荐使用)
  69.         /*Method method = getMapperMethod(ms);
  70.         if (method == null || !method.isAnnotationPresent(PageParallelQuery.class)) {
  71.             return invocation.proceed();
  72.         }*/
  73.         BoundSql boundSql = ms.getBoundSql(parameter);
  74.         String originalSql = boundSql.getSql();
  75.         //禁用mybatis plus PaginationInnerInterceptor count查询
  76.         page.setSearchCount(false);
  77.         page.setTotal(0);
  78.         args[2] = RowBounds.DEFAULT;
  79.         CompletableFuture<Long> countFuture = resolveCountCompletableFuture(invocation, originalSql);
  80.         //limit查询
  81.         long startTime = System.currentTimeMillis();
  82.         Object proceed = invocation.proceed();
  83.         log.info("原SQL数据查询-耗时={}", System.currentTimeMillis() - startTime);
  84.         page.setTotal(countFuture.get());
  85.         return proceed;
  86.     }
  87.     private CompletableFuture<Long> resolveCountCompletableFuture(Invocation invocation, String originalSql) {
  88.         return CompletableFuture.supplyAsync(() -> {
  89.             try {
  90.                 //查询总条数
  91.                 long startTime = System.currentTimeMillis();
  92.                 long total = executeCountQuery(originalSql, invocation);
  93.                 log.info("分页并行查询COUNT总条数[{}]-耗时={}", total, System.currentTimeMillis() - startTime);
  94.                 return total;
  95.             } catch (Throwable e) {
  96.                 log.error("page parallel query exception:", e);
  97.                 throw new CompletionException(e);
  98.             }
  99.         }, countAsyncThreadPool).exceptionally(throwable -> {
  100.             log.error("page parallel query exception:", throwable);
  101.             return 0L;
  102.         });
  103.     }
  104.     private CompletableFuture<Object> resolveOriginalProceedCompletableFuture(Invocation invocation) {
  105.         return CompletableFuture.supplyAsync(() -> {
  106.             try {
  107.                 long startTime = System.currentTimeMillis();
  108.                 Object proceed = invocation.proceed();
  109.                 log.info("原SQL数据查询-耗时={}", System.currentTimeMillis() - startTime);
  110.                 return proceed;
  111.             } catch (Throwable e) {
  112.                 throw new CompletionException(e);
  113.             }
  114.         }, countAsyncThreadPool).exceptionally(throwable -> {
  115.             log.error("page parallel query original proceed exception:", throwable);
  116.             return null;
  117.         });
  118.     }
  119.     /**
  120.      * 执行count查询
  121.      */
  122.     private long executeCountQuery(String originalSql, Invocation invocation)
  123.             throws JSQLParserException, SQLException {
  124.         //解析并修改SQL为count查询
  125.         Select countSelect = (Select) CCJSqlParserUtil.parse(originalSql);
  126.         PlainSelect plainSelect = (PlainSelect) countSelect.getSelectBody();
  127.         //修改select为count(*)
  128.         /*plainSelect.setSelectItems(Collections.singletonList(
  129.                 new SelectExpressionItem(new Function("COUNT", new Column("*")))
  130.         );*/
  131.         // 移除排序和分页
  132.         Distinct distinct = plainSelect.getDistinct();
  133.         GroupByElement groupBy = plainSelect.getGroupBy();
  134.         String countSql = "";
  135.         if (groupBy == null && distinct == null) {
  136.             Expression countFuncExpression = CCJSqlParserUtil.parseExpression("COUNT(*)");
  137.             plainSelect.setSelectItems(Collections.singletonList(
  138.                     new SelectExpressionItem(countFuncExpression)));
  139.             plainSelect.setOrderByElements(null);
  140.             countSql = plainSelect.toString();
  141.         } else if (groupBy != null) {
  142.             plainSelect.setLimit(null);
  143.             plainSelect.setOffset(null);
  144.             countSql = "SELECT COUNT(*) FROM (" + plainSelect + ") TOTAL";
  145.         } else {
  146.             plainSelect.setOrderByElements(null);
  147.             plainSelect.setLimit(null);
  148.             plainSelect.setOffset(null);
  149.             countSql = "SELECT COUNT(*) FROM (" + plainSelect + ") TOTAL";
  150.         }
  151.         //执行count查询
  152.         return doCountQuery(invocation, countSql);
  153.     }
  154.     /**
  155.      * 执行修改后的COUNT(*)-SQL查询
  156.      */
  157.     @SuppressWarnings("unchecked")
  158.     private Long doCountQuery(Invocation invocation, String modifiedSql) {
  159.         //Executor executor = (Executor) invocation.getTarget();
  160.         //创建新会话(自动获取新连接)
  161.         Executor executor;
  162.         SqlSessionFactory sqlSessionFactory = applicationContext.getBean(SqlSessionFactory.class);
  163.         try (SqlSession sqlSession = sqlSessionFactory.openSession(ExecutorType.SIMPLE)) {
  164.             //com.alibaba.druid.pool.DruidPooledConnection
  165.             System.out.println("新会话Connection class: " + sqlSession.getConnection().getClass().getName());
  166.             Field executorField = sqlSession.getClass().getDeclaredField("executor");
  167.             executorField.setAccessible(true);
  168.             executor = (Executor) executorField.get(sqlSession);
  169.             Object[] args = invocation.getArgs();
  170.             MappedStatement originalMs = (MappedStatement) args[0];
  171.             Object parameter = args[1];
  172.             //创建新的查询参数
  173.             Map<String, Object> newParameter = new HashMap<>();
  174.             if (parameter instanceof Map) {
  175.                 // 复制原始参数但移除分页参数
  176.                 Map<?, ?> originalParams = (Map<?, ?>) parameter;
  177.                 originalParams.forEach((k, v) -> {
  178.                     if (!(v instanceof Page)) {
  179.                         newParameter.put(k.toString(), v);
  180.                     }
  181.                 });
  182.             }
  183.             //创建新的BoundSql
  184.             BoundSql originalBoundSql = originalMs.getBoundSql(newParameter);
  185.             BoundSql newBoundSql = new BoundSql(originalMs.getConfiguration(), modifiedSql, originalBoundSql.getParameterMappings(), newParameter);
  186.             //复制原始参数值
  187.             originalBoundSql.getParameterMappings().forEach(mapping -> {
  188.                 String prop = mapping.getProperty();
  189.                 if (mapping.getJavaType().isInstance(newParameter)) {
  190.                     newBoundSql.setAdditionalParameter(prop, newParameter);
  191.                 } else if (newParameter instanceof Map) {
  192.                     Object value = ((Map<?, ?>) newParameter).get(prop);
  193.                     newBoundSql.setAdditionalParameter(prop, value);
  194.                 }
  195.             });
  196.             //创建新的BoundSql
  197.             /*BoundSql originalBoundSql = originalMs.getBoundSql(parameter);
  198.             BoundSql newBoundSql = new BoundSql(originalMs.getConfiguration(), modifiedSql,
  199.                     originalBoundSql.getParameterMappings(), parameter);*/
  200.             Configuration configuration = originalMs.getConfiguration();
  201.             //创建临时ResultMap
  202.             ResultMap resultMap = new ResultMap.Builder(
  203.                     configuration,
  204.                     LONG_RESULT_MAP_ID,
  205.                     //强制指定结果类型
  206.                     Long.class,
  207.                     //自动映射列到简单类型
  208.                     Collections.emptyList()
  209.             ).build();
  210.             if (!configuration.hasResultMap(LONG_RESULT_MAP_ID)) {
  211.                 configuration.addResultMap(resultMap);
  212.             }
  213.             String countMsId = originalMs.getId() + "_countMsId";
  214.             MappedStatement mappedStatement = twoPhaseMsCache.computeIfAbsent(countMsId, (key) ->
  215.                     this.getNewMappedStatement(modifiedSql, originalMs, newBoundSql, resultMap, countMsId));
  216.             //执行查询
  217.             List<Object> result = executor.query(mappedStatement, newParameter, RowBounds.DEFAULT, (ResultHandler<?>) args[3]);
  218.             long total = 0L;
  219.             if (CollectionUtils.isNotEmpty(result)) {
  220.                 Object o = result.get(0);
  221.                 if (o != null) {
  222.                     total = Long.parseLong(o.toString());
  223.                 }
  224.             }
  225.             return total;
  226.         } catch (Throwable e) {
  227.             log.error("分页并行查询-executeCountQuery异常:", e);
  228.         }
  229.         return 0L;
  230.     }
  231.     private MappedStatement getNewMappedStatement(String modifiedSql, MappedStatement originalMs, BoundSql newBoundSql,
  232.                                                   ResultMap resultMap, String msId) {
  233.         //创建新的MappedStatement
  234.         MappedStatement.Builder builder = new MappedStatement.Builder(
  235.                 originalMs.getConfiguration(),
  236.                 msId,
  237.                 new StaticSqlSource(originalMs.getConfiguration(), modifiedSql, newBoundSql.getParameterMappings()),
  238.                 originalMs.getSqlCommandType()
  239.         );
  240.         //复制重要属性
  241.         builder.resource(originalMs.getResource())
  242.                 .fetchSize(originalMs.getFetchSize())
  243.                 .timeout(originalMs.getTimeout())
  244.                 .statementType(originalMs.getStatementType())
  245.                 .keyGenerator(originalMs.getKeyGenerator())
  246.                 .keyProperty(originalMs.getKeyProperties() == null ? null : String.join(",", originalMs.getKeyProperties()))
  247.                 .resultMaps(resultMap == null ? originalMs.getResultMaps() : Collections.singletonList(resultMap))
  248.                 .parameterMap(originalMs.getParameterMap())
  249.                 .resultSetType(originalMs.getResultSetType())
  250.                 .cache(originalMs.getCache())
  251.                 .flushCacheRequired(originalMs.isFlushCacheRequired())
  252.                 .useCache(originalMs.isUseCache());
  253.         return builder.build();
  254.     }
  255.     /**
  256.      * 获取分页参数
  257.      */
  258.     private Page<?> getPageParameter(Object parameter) {
  259.         if (parameter instanceof Map) {
  260.             Map<?, ?> paramMap = (Map<?, ?>) parameter;
  261.             return (Page<?>) paramMap.values().stream()
  262.                     .filter(p -> p instanceof Page)
  263.                     .findFirst()
  264.                     .orElse(null);
  265.         }
  266.         return parameter instanceof Page ? (Page<?>) parameter : null;
  267.     }
  268.     /**
  269.      * 获取Mapper方法
  270.      */
  271.     private Method getMapperMethod(MappedStatement ms) {
  272.         try {
  273.             String methodName = ms.getId().substring(ms.getId().lastIndexOf(".") + 1);
  274.             Class<?> mapperClass = Class.forName(ms.getId().substring(0, ms.getId().lastIndexOf(".")));
  275.             return Arrays.stream(mapperClass.getMethods())
  276.                     .filter(m -> m.getName().equals(methodName))
  277.                     .findFirst()
  278.                     .orElse(null);
  279.         } catch (ClassNotFoundException e) {
  280.             return null;
  281.         }
  282.     }
  283. }
复制代码
注意事项

有人可能会担心并行查询,在高并发场景可能会导致count查询与limit数据查询不一致,但其实只要没有锁,只要是分开的两条sql查询,原mybatisplus分页插件也一样面临这个题目。
count优化没有举行join语句判定优化,相当于主动关闭了page.setOptimizeJoinOfCountSql(false);在一对多等场景可能会造成count查询有误,Mybatisplus官网也有干系提示,所以这里干脆舍弃了。
mybatisplus版本差别,可能会导致JsqlParser所使用的api有所差别,需要本身对应版本修改下。本篇版本使用的3.5.1
关于线程池的线程数设置顺便提一下:
网上流行一个说法:
1. CPU 密集型任务
特点:任务主要消耗 CPU 资源(如复杂计算、图像处理)。
线程数建议:

  • 核心线程数:CPU 核心数 + 1(或等于CPU核心数,避免上下文切换过多)。
  • 最大线程数:与核心线程数相同(防止过多线程竞争 CPU)。
2. I/O 密集型任务
特点:任务涉及大量等待(如网络哀求、数据库读写)。
线程数建议:

  • 核心线程数:2 * CPU 核心数(确保正常负载下的高效处理)。
  • 最大线程数:根据体系资源调整(用于应对突发高并发)。
其实这个说法来源于一个经验公式推导而来:
threads = CPU核心数 * (1 + 平均等待时间 / 平均计算时间)
《Java 虚拟机并发编程》中先容


 
另一篇:《Java Concurrency in Practice》即《java并发编程实践》,给出的线程池巨细的估算公式:
 

Nthreads=Ncpu*Ucpu*(1+w/c),其中 Ncpu=CPU核心数,Ucpu=cpu使用率,0~1;W/C=等待时间与计算时间的比率
细致推导两个公式,其实类似,在cpu使用率达100%时,其实结论是一致的,这时间计算线程数的公式就成了,Nthreads=Ncpu*100%*(1+w/c) =Ncpu*(1+w/c)。
那么在实践应用中计算的公式就出来了,【以下推算,不思量内存消耗等方面】,如下:
1、针对IO密集型,阻塞耗时w一般都是计算耗时几倍c,假设阻塞耗时=计算耗时的情况下,Nthreads=Ncpu*(1+1)=2Ncpu,所以这种情况下,建议思量2倍的CPU核心数做为线程数
2、对于计算密集型,阻塞耗时趋于0,即w/c趋于0,公式Nthreads = Ncpu。
现实应用时要思量同时设置了几个隔离线程池,另外tomcat自带的线程池也会共享宿主机公共资源。
 

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

篮之新喜

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表