多线程事务?拿捏!
场景:有一批1万或者10万数据,插入数据库,怎么做
事务中进行批量提交
publList<List<OrderPo>> partition = Lists.partition(list, 450);StopWatch stopWatch = new StopWatch();stopWatch.start();// 顺序插入for (List<OrderPo> sub : partition) {orderMapper.batchSave(sub);}stopWatch.stop();log.info("耗时:" + stopWatch.getTotalTimeSeconds());
}
得出来的结果是 1万数据大概在5-6秒,10万数据在53-58秒
线程池并行插入
// 线程池
private static final ThreadPoolExecutor THREAD_POOL_EXECUTOR = new ThreadPoolExecutor(16, 16, 60, TimeUnit.SECONDS, new ArrayBlockingQueue<>(1024), new ThreadFactory() {// 线程名字private final String PREFIX = "BATCH_INSERT_";// 计数器private AtomicLong atomicLong = new AtomicLong();@Overridepublic Thread newThread(Runnable r) {Thread thread = new Thread(null, r, PREFIX + atomicLong.incrementAndGet());return thread;}
});
@SneakyThrows
@Transactional(rollbackFor = Exception.class)
public void batchSave() {// 分批// 至于分多少批:PgSQL 的占位符个数是有限制的 不能超过 Short.MAX(32767)// 所以一批最多 = 32767 / 你的一行字段个数// 比如我这里 = 32767 / 66个字段 = 496 也就是一批最多496个数据List<List<OrderPo>> partition = Lists.partition(list, 450);CountDownLatch countDownLatch = new CountDownLatch(partition.size());StopWatch stopWatch = new StopWatch();stopWatch.start();// 顺序插入for (List<OrderPo> sub : partition) {THREAD_POOL_EXECUTOR.execute(() -> {try {log.info("线程:{}开始处理", Thread.currentThread().getName());orderMapper.batchSave(sub);} finally {countDownLatch.countDown();}});}// 等待插入完毕countDownLatch.await();stopWatch.stop();log.info("耗时:" + stopWatch.getTotalTimeSeconds());
}
这种方式会导致事务失效从而导致部分数据的丢失,因为内部通过线程池提交了多个子任务,这些子任务是异步执行的,事务的传播机制和线程的隔离性导致事务上下文不会传播到这些异步线程中
原因刨析:
- @Transactional是作用在当前线程的,事务上下文不会传播到其他线程去
- 子线程的批量保存操作是独立执行的,不受主线程事务控制
线程池并行插入但共用一个事务
实际上就是通过编程式事务来解决事务上下文不传播的问题,这种方式灵活性就高很多了,毕竟是在代码里直接编码,但是可扩展性一般
// 线程池
private static final ThreadPoolExecutor THREAD_POOL_EXECUTOR = new ThreadPoolExecutor(16, 16, 60, TimeUnit.SECONDS, new ArrayBlockingQueue<>(1024), new ThreadFactory() {// 线程名字private final String PREFIX = "BATCH_INSERT_";// 计数器private AtomicLong atomicLong = new AtomicLong();@Overridepublic Thread newThread(Runnable r) {Thread thread = new Thread(null, r, PREFIX + atomicLong.incrementAndGet());return thread;}
});
@SneakyThrows
public void batchSave() {// 分批// 至于分多少批:PgSQL 的占位符个数是有限制的 不能超过 Short.MAX(32767)// 所以一批最多 = 32767 / 你的一行字段个数// 比如我这里 = 32767 / 66个字段 = 496 也就是一批最多496个数据List<List<OrderPo>> partition = Lists.partition(list, 450);// 手动事务提前创建出来DefaultTransactionDefinition transactionDefinition = new DefaultTransactionDefinition();transactionDefinition.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRES_NEW);// 提前获取连接TransactionStatus transactionStatus = dataSourceTransactionManager.getTransaction(transactionDefinition);// 获取数据源以及连接 供多线程使用DataSource dataSource = dataSourceTransactionManager.getDataSource();Object resource = TransactionSynchronizationManager.getResource(dataSource);// 异常标志AtomicBoolean exceptionFlag = new AtomicBoolean(false);boolean poolExceptionFlag = false;// 计数器等待执行完毕CountDownLatch countDownLatch = new CountDownLatch(partition.size());StopWatch stopWatch = new StopWatch();stopWatch.start();// 顺序插入for (List<OrderPo> sub : partition) {try {THREAD_POOL_EXECUTOR.execute(() -> {try {// 如果没有发生异常if (exceptionFlag.get()) {log.info("有其他线程执行失败,后续无需执行,因为最终会回滚");return;}// 释放上次绑定的数据源连接try {TransactionSynchronizationManager.unbindResource(dataSource);} catch (Exception ignored){}// 装上本次使用的连接TransactionSynchronizationManager.bindResource(dataSource, resource);log.info("线程:{}开始处理", Thread.currentThread().getName());// 执行插入orderMapper.batchSave(sub);// 模拟异常if (ThreadLocalRandom.current().nextInt(3) == 1) {int i = 1/0;}} catch (Exception e) {// 发生异常设置异常标志log.error(String.format("线程:%s我发生了异常,e:%s", Thread.currentThread().getName(), e.getMessage()), e);exceptionFlag.set(true);} finally {// 不管是成功还是失败 都要计数器 -1countDownLatch.countDown();}});} catch (Exception e) {// 提交任务失败 那就是失败了exceptionFlag.set(true);log.info("当前线程池繁忙,请稍后重试");dataSourceTransactionManager.rollback(transactionStatus);poolExceptionFlag = true;break;}}// 等待执行完毕 这里有个隐患 等待多长时间呢? 线程池任务过多的话最严重的情况 就是一直要在这里阻塞// 因为事务的提交还是回滚都交给了 主任务线程// 如果提交到线程池都成功了的话 就等待都执行完if (!poolExceptionFlag) {countDownLatch.await();}// 异常标志来做提交还是回滚if (exceptionFlag.get()) {// 发生异常 回滚dataSourceTransactionManager.rollback(transactionStatus);} else {// 未发生异常 可以提交dataSourceTransactionManager.commit(transactionStatus);}stopWatch.stop();log.info("耗时:" + stopWatch.getTotalTimeSeconds());
}
@Async + @Transactional结合实现
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Async;
import org.springframework.scheduling.annotation.AsyncResult;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;@Service
public class OrderService {@Autowiredprivate OrderMapper orderMapper;// 主方法,负责分批并启动异步任务@Transactional(rollbackFor = Exception.class)public void batchSave(List<OrderPo> list) throws InterruptedException {// 分批List<List<OrderPo>> partition = Lists.partition(list, 450);CountDownLatch countDownLatch = new CountDownLatch(partition.size());StopWatch stopWatch = new StopWatch();stopWatch.start();// 启动异步任务List<CompletableFuture<Void>> futures = partition.stream().map(subList -> CompletableFuture.runAsync(() -> batchSaveAsync(subList, countDownLatch))).collect(Collectors.toList());// 等待所有异步任务完成CompletableFuture<Void> allFutures = CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]));try {// 等待所有异步任务完成,设置超时时间避免无限等待allFutures.get(60, TimeUnit.SECONDS); // 设置合理的超时时间} catch (Exception e) {throw new RuntimeException("Batch save failed", e);}stopWatch.stop();System.out.println("Total time taken: " + stopWatch.getTotalTimeSeconds());}// 异步执行的子任务@Async@Transactional(rollbackFor = Exception.class, propagation = Propagation.REQUIRES_NEW)public void batchSaveAsync(List<OrderPo> subList, CountDownLatch countDownLatch) {try {System.out.println("Thread: " + Thread.currentThread().getName() + " is processing batch");orderMapper.batchSave(subList);// 模拟异常if (ThreadLocalRandom.current().nextInt(3) == 1) {throw new RuntimeException("Simulated exception in batch save");}} catch (Exception e) {System.err.println("Exception occurred in thread: " + Thread.currentThread().getName() + ", " + e.getMessage());throw e; // 异常会触发事务回滚} finally {countDownLatch.countDown();}}
}
- 事务传播机制:是指当一个事务方法被另一个事务方法调用时,这个事务方法应该如何进行事务控制。例如,常见的事务传播行为有 REQUIRED(如果当前没有事务,就新建一个事务;如果已经存在一个事务,就加入到这个事务中)、REQUIRES_NEW(新建事务,如果当前存在事务,就把当前事务挂起)等。