BaseDao 通用更新方法设计与实现
一、BaseDao 通用更新方法设计
1. 核心功能需求

| 功能 | 说明 | 必要性 |
| --- | --- | --- |
| 通用插入 | 插入任意实体对象 | ⭐⭐⭐⭐⭐ |
| 通用更新 | 根据主键更新实体 | ⭐⭐⭐⭐⭐ |
| 通用删除 | 根据主键删除记录 | ⭐⭐⭐⭐ |
| 批量操作 | 批量插入/更新/删除 | ⭐⭐⭐⭐ |
| 条件更新 | 根据条件更新字段 | ⭐⭐⭐ |
| 动态SQL | 支持非空字段更新 | ⭐⭐⭐ |
2. 类结构设计
操作
«abstract»
BaseDao<T, ID>
-DataSource dataSource
+BaseDao(DataSource dataSource)
+insert(T entity) : ID
+update(T entity) : int
+deleteById(ID id) : int
+batchInsert(List<T> entities) : int[]
+batchUpdate(List<T> entities) : int[]
+updateSelective(T entity) : int
+executeUpdate(String sql, Object... params) : int
UserDao
+UserDao(DataSource dataSource)
+findByEmail(String email) : List<User>
User
二、完整实现代码
1. 反射工具类 (ReflectionUtils)
import java. lang. reflect. Field ;
import java. util. ArrayList ;
import java.util.List;

/**
 * Reflection helpers used by the DAO layer to enumerate and read entity fields.
 *
 * <p>All lookups walk the class hierarchy (excluding {@code Object}), so fields
 * declared on a superclass are handled the same way as fields declared on the
 * entity class itself.
 */
public class ReflectionUtils {

    /**
     * Collects the fields declared by {@code clazz} and by each of its superclasses.
     *
     * @param clazz the class to inspect; {@code null} yields an empty list
     * @return a new mutable list, most-derived class first, {@code Object} excluded
     */
    public static List<Field> getAllFields(Class<?> clazz) {
        List<Field> fields = new ArrayList<>();
        for (Class<?> type = clazz; type != null && type != Object.class; type = type.getSuperclass()) {
            for (Field field : type.getDeclaredFields()) {
                fields.add(field);
            }
        }
        return fields;
    }

    /**
     * Reads the value of the named field from {@code obj}, regardless of visibility.
     *
     * <p>The field is searched on the runtime class of {@code obj} and then on its
     * superclasses, so inherited fields are readable, consistently with the
     * hierarchy walk done by {@link #getAllFields(Class)}.
     *
     * @param obj       the object to read from
     * @param fieldName the declared name of the field
     * @return the field value, possibly {@code null}
     * @throws RuntimeException if {@code obj} is {@code null}, the field does not
     *                          exist, or it cannot be accessed; the cause is preserved
     */
    public static Object getFieldValue(Object obj, String fieldName) {
        try {
            Field field = findField(obj.getClass(), fieldName);
            field.setAccessible(true);
            return field.get(obj);
        } catch (Exception e) {
            // Broad catch kept on purpose: callers rely on a single RuntimeException
            // for every failure mode (null target, missing field, access denied).
            throw new RuntimeException("Failed to get field value", e);
        }
    }

    /**
     * Tells whether the named field of {@code obj} currently holds {@code null}.
     *
     * @param obj       the object to inspect
     * @param fieldName the declared name of the field
     * @return {@code true} when the field value is {@code null}
     * @throws RuntimeException under the same conditions as {@link #getFieldValue(Object, String)}
     */
    public static boolean isFieldNull(Object obj, String fieldName) {
        return getFieldValue(obj, fieldName) == null;
    }

    /** Finds a field by name on {@code clazz} or the nearest superclass that declares it. */
    private static Field findField(Class<?> clazz, String fieldName) throws NoSuchFieldException {
        for (Class<?> type = clazz; type != null && type != Object.class; type = type.getSuperclass()) {
            for (Field candidate : type.getDeclaredFields()) {
                if (candidate.getName().equals(fieldName)) {
                    return candidate;
                }
            }
        }
        throw new NoSuchFieldException(fieldName);
    }
}
2. BaseDao 基础实现
import javax. sql. DataSource ;
import java. lang. reflect. Field ;
import java. sql. * ;
import java. util. * ; public abstract class BaseDao < T , ID> { protected final DataSource dataSource; protected final Class < T > entityClass; @SuppressWarnings ( "unchecked" ) public BaseDao ( DataSource dataSource) { this . dataSource = dataSource; this . entityClass = ( Class < T > ) ( ( java. lang. reflect. ParameterizedType) getClass ( ) . getGenericSuperclass ( ) ) . getActualTypeArguments ( ) [ 0 ] ; } public ID insert ( T entity) { String tableName = getTableName ( ) ; List < Field > fields = getInsertableFields ( entity) ; String sql = generateInsertSql ( tableName, fields) ; try ( Connection conn = dataSource. getConnection ( ) ; PreparedStatement pstmt = conn. prepareStatement ( sql, Statement . RETURN_GENERATED_KEYS) ) { setParameters ( pstmt, entity, fields, 1 ) ; int affectedRows = pstmt. executeUpdate ( ) ; if ( affectedRows == 0 ) { throw new SQLException ( "Insert failed, no rows affected." ) ; } try ( ResultSet generatedKeys = pstmt. getGeneratedKeys ( ) ) { if ( generatedKeys. next ( ) ) { @SuppressWarnings ( "unchecked" ) ID id = ( ID) generatedKeys. getObject ( 1 ) ; setPrimaryKeyValue ( entity, id) ; return id; } else { throw new SQLException ( "Insert failed, no ID obtained." ) ; } } } catch ( SQLException e) { throw new RuntimeException ( "Insert operation failed" , e) ; } } public int update ( T entity) { String tableName = getTableName ( ) ; String primaryKey = getPrimaryKeyName ( ) ; List < Field > fields = ReflectionUtils . getAllFields ( entityClass) ; String sql = generateUpdateSql ( tableName, primaryKey, fields) ; try ( Connection conn = dataSource. getConnection ( ) ; PreparedStatement pstmt = conn. prepareStatement ( sql) ) { int paramIndex = setParameters ( pstmt, entity, fields, 1 ) ; Object idValue = ReflectionUtils . getFieldValue ( entity, primaryKey) ; pstmt. setObject ( paramIndex, idValue) ; return pstmt. 
executeUpdate ( ) ; } catch ( SQLException e) { throw new RuntimeException ( "Update operation failed" , e) ; } } public int deleteById ( ID id) { String tableName = getTableName ( ) ; String primaryKey = getPrimaryKeyName ( ) ; String sql = "DELETE FROM " + tableName + " WHERE " + primaryKey + " = ?" ; try ( Connection conn = dataSource. getConnection ( ) ; PreparedStatement pstmt = conn. prepareStatement ( sql) ) { pstmt. setObject ( 1 , id) ; return pstmt. executeUpdate ( ) ; } catch ( SQLException e) { throw new RuntimeException ( "Delete operation failed" , e) ; } } public int updateSelective ( T entity) { String tableName = getTableName ( ) ; String primaryKey = getPrimaryKeyName ( ) ; List < Field > fields = getNonEmptyFields ( entity) ; if ( fields. isEmpty ( ) ) { throw new IllegalArgumentException ( "No non-empty fields to update" ) ; } String sql = generateUpdateSql ( tableName, primaryKey, fields) ; try ( Connection conn = dataSource. getConnection ( ) ; PreparedStatement pstmt = conn. prepareStatement ( sql) ) { int paramIndex = setParameters ( pstmt, entity, fields, 1 ) ; Object idValue = ReflectionUtils . getFieldValue ( entity, primaryKey) ; pstmt. setObject ( paramIndex, idValue) ; return pstmt. executeUpdate ( ) ; } catch ( SQLException e) { throw new RuntimeException ( "Selective update failed" , e) ; } } public int [ ] batchInsert ( List < T > entities) { if ( entities == null || entities. isEmpty ( ) ) { return new int [ 0 ] ; } String tableName = getTableName ( ) ; List < Field > fields = getInsertableFields ( entities. get ( 0 ) ) ; String sql = generateInsertSql ( tableName, fields) ; try ( Connection conn = dataSource. getConnection ( ) ; PreparedStatement pstmt = conn. prepareStatement ( sql) ) { for ( T entity : entities) { setParameters ( pstmt, entity, fields, 1 ) ; pstmt. addBatch ( ) ; } return pstmt. 
executeBatch ( ) ; } catch ( SQLException e) { throw new RuntimeException ( "Batch insert failed" , e) ; } } public int executeUpdate ( String sql, Object . . . params) { try ( Connection conn = dataSource. getConnection ( ) ; PreparedStatement pstmt = conn. prepareStatement ( sql) ) { for ( int i = 0 ; i < params. length; i++ ) { pstmt. setObject ( i + 1 , params[ i] ) ; } return pstmt. executeUpdate ( ) ; } catch ( SQLException e) { throw new RuntimeException ( "Execute update failed" , e) ; } } protected String getTableName ( ) { String className = entityClass. getSimpleName ( ) ; return camelToSnake ( className) ; } protected String getPrimaryKeyName ( ) { return "id" ; } private String camelToSnake ( String str) { return str. replaceAll ( "([a-z])([A-Z])" , "$1_$2" ) . toLowerCase ( ) ; } private String generateInsertSql ( String tableName, List < Field > fields) { StringBuilder columns = new StringBuilder ( ) ; StringBuilder placeholders = new StringBuilder ( ) ; for ( Field field : fields) { String columnName = camelToSnake ( field. getName ( ) ) ; columns. append ( columnName) . append ( "," ) ; placeholders. append ( "?," ) ; } columns. setLength ( columns. length ( ) - 1 ) ; placeholders. setLength ( placeholders. length ( ) - 1 ) ; return "INSERT INTO " + tableName + " (" + columns + ") VALUES (" + placeholders + ")" ; } private String generateUpdateSql ( String tableName, String primaryKey, List < Field > fields) { StringBuilder setClause = new StringBuilder ( "UPDATE " ) . append ( tableName) . append ( " SET " ) ; for ( Field field : fields) { String columnName = camelToSnake ( field. getName ( ) ) ; setClause. append ( columnName) . append ( " = ?," ) ; } setClause. setLength ( setClause. length ( ) - 1 ) ; setClause. append ( " WHERE " ) . append ( primaryKey) . append ( " = ?" ) ; return setClause. 
toString ( ) ; } private int setParameters ( PreparedStatement pstmt, T entity, List < Field > fields, int startIndex) throws SQLException { int paramIndex = startIndex; for ( Field field : fields) { field. setAccessible ( true ) ; try { Object value = field. get ( entity) ; pstmt. setObject ( paramIndex++ , value) ; } catch ( IllegalAccessException e) { throw new SQLException ( "Failed to get field value" , e) ; } } return paramIndex; } private List < Field > getNonEmptyFields ( T entity) { List < Field > allFields = ReflectionUtils . getAllFields ( entityClass) ; List < Field > nonEmptyFields = new ArrayList < > ( ) ; for ( Field field : allFields) { if ( ! ReflectionUtils . isFieldNull ( entity, field. getName ( ) ) && ! field. getName ( ) . equals ( getPrimaryKeyName ( ) ) ) { nonEmptyFields. add ( field) ; } } return nonEmptyFields; } private List < Field > getInsertableFields ( T entity) { List < Field > allFields = ReflectionUtils . getAllFields ( entityClass) ; List < Field > insertableFields = new ArrayList < > ( ) ; for ( Field field : allFields) { if ( ! field. getName ( ) . equals ( getPrimaryKeyName ( ) ) ) { insertableFields. add ( field) ; } } return insertableFields; } private void setPrimaryKeyValue ( T entity, ID id) { try { Field primaryKeyField = entityClass. getDeclaredField ( getPrimaryKeyName ( ) ) ; primaryKeyField. setAccessible ( true ) ; primaryKeyField. set ( entity, id) ; } catch ( Exception e) { throw new RuntimeException ( "Failed to set primary key value" , e) ; } }
}
3. 具体DAO实现示例 (UserDao)
/**
 * DAO for {@link User}: inherits the generic write operations from
 * {@code BaseDao} and adds user-specific queries.
 */
public class UserDao extends BaseDao<User, Long> {

    /**
     * @param dataSource source of JDBC connections
     */
    public UserDao(DataSource dataSource) {
        super(dataSource);
    }

    /**
     * Finds all users whose email equals the given value.
     *
     * @param email the email to match exactly
     * @return the matching users, empty when none match
     * @throws RuntimeException wrapping the {@link SQLException} when the query fails
     */
    public List<User> findByEmail(String email) {
        // Use the same table name as the inherited write operations.
        String sql = "SELECT * FROM " + getTableName() + " WHERE email = ?";
        try (Connection conn = dataSource.getConnection();
             PreparedStatement pstmt = conn.prepareStatement(sql)) {
            pstmt.setString(1, email);
            try (ResultSet rs = pstmt.executeQuery()) {
                List<User> users = new ArrayList<>();
                while (rs.next()) {
                    users.add(mapRow(rs));
                }
                return users;
            }
        } catch (SQLException e) {
            throw new RuntimeException("Query failed", e);
        }
    }

    /** Maps the current row of {@code rs} to a {@link User}. */
    private User mapRow(ResultSet rs) throws SQLException {
        User user = new User();
        user.setId(rs.getLong("id"));
        user.setUsername(rs.getString("username"));
        user.setEmail(rs.getString("email"));
        // getTimestamp returns null for SQL NULL; converting blindly would throw.
        Timestamp createdAt = rs.getTimestamp("create_time");
        user.setCreateTime(createdAt == null ? null : createdAt.toLocalDateTime());
        return user;
    }
}
4. 实体类示例 (User)
import java.time.LocalDateTime;

/**
 * User entity. Field names are what the DAO layer maps to columns by
 * reflection ({@code createTime} becomes {@code create_time}).
 */
public class User {

    private Long id;
    private String username;
    private String email;
    private LocalDateTime createTime;

    /** Creates an empty user; every field is {@code null}. */
    public User() {
    }

    /**
     * Creates a new, not yet persisted user with the creation time set to now.
     *
     * @param username login name
     * @param email    email address
     */
    public User(String username, String email) {
        this.username = username;
        this.email = email;
        this.createTime = LocalDateTime.now();
    }

    /** @return the primary key, or {@code null} if not assigned */
    public Long getId() {
        return id;
    }

    /** @param id the primary key */
    public void setId(Long id) {
        this.id = id;
    }

    /** @return the login name */
    public String getUsername() {
        return username;
    }

    /** @param username the login name */
    public void setUsername(String username) {
        this.username = username;
    }

    /** @return the email address */
    public String getEmail() {
        return email;
    }

    /** @param email the email address */
    public void setEmail(String email) {
        this.email = email;
    }

    /** @return the creation time, or {@code null} if not set */
    public LocalDateTime getCreateTime() {
        return createTime;
    }

    /** @param createTime the creation time */
    public void setCreateTime(LocalDateTime createTime) {
        this.createTime = createTime;
    }
}
三、使用示例
1. 基础CRUD操作
// Basic CRUD walkthrough: insert -> full update -> selective update -> delete.
// Build a pooled DataSource pointing at a local MySQL database.
HikariConfig config = new HikariConfig ( ) ;
config. setJdbcUrl ( "jdbc:mysql://localhost:3306/mydb" ) ;
config. setUsername ( "user" ) ;
config. setPassword ( "password" ) ;
DataSource dataSource = new HikariDataSource ( config) ;
// NOTE(review): the pool is never closed in this snippet — presumably it lives for the whole application; confirm.
UserDao userDao = new UserDao ( dataSource) ;
// insert returns the generated key and also writes it back into newUser.
User newUser = new User ( "john_doe" , "john@example.com" ) ;
Long userId = userDao. insert ( newUser) ;
System . out. println ( "Inserted user ID: " + userId) ;
// Full update: writes the entity's current field values for the row with newUser's id.
newUser. setEmail ( "john.new@example.com" ) ;
int updatedRows = userDao. update ( newUser) ;
System . out. println ( "Updated rows: " + updatedRows) ;
// Selective update: only id and email are set, so only the email column is written;
// username and createTime stay null on this object and are skipped.
User partialUpdate = new User ( ) ;
partialUpdate. setId ( userId) ;
partialUpdate. setEmail ( "john.partial@example.com" ) ;
int selectiveUpdated = userDao. updateSelective ( partialUpdate) ;
// Remove the row again by primary key.
int deletedRows = userDao. deleteById ( userId) ;
System . out. println ( "Deleted rows: " + deletedRows) ;
2. 批量操作
// Batch walkthrough: insert three users in one JDBC batch, then update them in one batch.
List < User > users = Arrays . asList ( new User ( "user1" , "user1@example.com" ) , new User ( "user2" , "user2@example.com" ) , new User ( "user3" , "user3@example.com" )
) ; int [ ] insertResults = userDao. batchInsert ( users) ;
// Each array element is the update count the driver reported for one batched statement.
System . out. println ( "Batch insert results: " + Arrays . toString ( insertResults) ) ;
// NOTE(review): batchInsert does not request generated keys or assign ids, so these
// objects presumably still have a null id here and the batch update below would match
// no rows unless ids are loaded first — confirm.
users. forEach ( user -> user. setEmail ( user. getUsername ( ) + "@newdomain.com" ) ) ;
int [ ] updateResults = userDao. batchUpdate ( users) ;
System . out. println ( "Batch update results: " + Arrays . toString ( updateResults) ) ;
3. 自定义SQL执行
// Custom SQL: mark users created more than one year ago as INACTIVE.
// Both values are bound as statement parameters, in placeholder order.
// NOTE(review): the User entity shown has no status field — presumably the column
// exists only in the table; confirm.
int rowsAffected = userDao. executeUpdate ( "UPDATE user SET status = ? WHERE create_time < ?" , "INACTIVE" , LocalDateTime . now ( ) . minusYears ( 1 )
) ;
System . out. println ( "Custom update affected: " + rowsAffected + " rows" ) ;
四、设计注意事项
1. 性能优化要点

| 优化点 | 实现策略 | 注意事项 |
| --- | --- | --- |
| 连接管理 | 使用连接池 (HikariCP) | 配置合适连接数 |
| 批量操作 | JDBC批处理 | 控制批处理大小 |
| 反射缓存 | 缓存Field元数据 | 避免重复获取 |
| SQL构建 | 预编译SQL模板 | 防止SQL注入 |
| 资源释放 | try-with-resources | 确保关闭资源 |
2. 安全注意事项

| 风险 | 防护措施 | 实现方式 |
| --- | --- | --- |
| SQL注入 | 参数化查询 | 使用PreparedStatement |
| 敏感数据 | 字段过滤 | 在getInsertableFields中过滤 |
| 过度更新 | 更新字段限制 | updateSelective方法 |
| 权限控制 | DAO方法级权限 | 业务层控制访问 |
| 日志泄露 | 避免记录参数值 | 关闭敏感日志 |
3. 扩展性设计
/**
 * Resolves the table name for the entity.
 *
 * <p>An explicit, non-blank {@code @Table} name wins; otherwise the name is
 * derived from the entity's simple class name in snake_case. A blank annotation
 * name is treated as "not specified" so it can never produce an empty table name
 * in generated SQL.
 *
 * @return the table name used in generated SQL
 */
protected String getTableName() {
    Table table = entityClass.getAnnotation(Table.class);
    if (table != null) {
        String declaredName = table.name();
        if (declaredName != null && !declaredName.trim().isEmpty()) {
            return declaredName;
        }
    }
    return camelToSnake(entityClass.getSimpleName());
}
/**
 * Resolves the name of the primary-key field.
 *
 * <p>Returns the first field annotated with {@code @Id}, searching the entity
 * class first and then its superclasses, so a key declared on a shared base
 * entity is found as well. Falls back to {@code "id"} when no field is annotated.
 *
 * @return the primary-key field name
 */
protected String getPrimaryKeyName() {
    for (Class<?> type = entityClass; type != null && type != Object.class; type = type.getSuperclass()) {
        for (Field field : type.getDeclaredFields()) {
            if (field.isAnnotationPresent(Id.class)) {
                return field.getName();
            }
        }
    }
    return "id";
}
/**
 * Converts a field value into the form handed to JDBC.
 *
 * <p>{@link LocalDateTime} values become {@link Timestamp}; every other value,
 * including {@code null}, is returned unchanged.
 *
 * @param field the entity field the value was read from (not consulted here)
 * @param value the raw field value
 * @return the value to bind to the statement
 */
protected Object convertFieldValue(Field field, Object value) {
    if (!(value instanceof LocalDateTime)) {
        return value;
    }
    LocalDateTime dateTime = (LocalDateTime) value;
    return Timestamp.valueOf(dateTime);
}
五、高级功能实现
1. 分页查询支持
/**
 * Loads one page of entities, ordered by primary key so that consecutive pages
 * are stable and do not overlap.
 *
 * @param pageNum  1-based page number
 * @param pageSize number of rows per page, at least 1
 * @return the page with its rows and the total row count of the table
 * @throws IllegalArgumentException if {@code pageNum} or {@code pageSize} is less than 1
 * @throws RuntimeException         wrapping the {@link SQLException} when a query fails
 */
public Page<T> findPage(int pageNum, int pageSize) {
    if (pageNum < 1) {
        throw new IllegalArgumentException("pageNum must be >= 1, got " + pageNum);
    }
    if (pageSize < 1) {
        throw new IllegalArgumentException("pageSize must be >= 1, got " + pageSize);
    }
    String sql = "SELECT * FROM " + getTableName()
            + " ORDER BY " + camelToSnake(getPrimaryKeyName())
            + " LIMIT ? OFFSET ?";
    // Computed in long arithmetic: (pageNum - 1) * pageSize can exceed the int range.
    long offset = (pageNum - 1L) * pageSize;
    List<T> result = new ArrayList<>();
    try (Connection conn = dataSource.getConnection();
         PreparedStatement pstmt = conn.prepareStatement(sql)) {
        pstmt.setInt(1, pageSize);
        pstmt.setLong(2, offset);
        try (ResultSet rs = pstmt.executeQuery()) {
            while (rs.next()) {
                result.add(mapRowToEntity(rs));
            }
        }
    } catch (SQLException e) {
        throw new RuntimeException("Paged query failed", e);
    }
    // Counted only after the page connection is released, so this call never needs
    // two pooled connections at once.
    int total = getTotalCount();
    return new Page<>(pageNum, pageSize, total, result);
}

/**
 * Counts all rows of the entity table.
 *
 * @return the row count, or 0 if the query returns no row
 * @throws RuntimeException wrapping the {@link SQLException} when the query fails
 */
private int getTotalCount() {
    String sql = "SELECT COUNT(*) FROM " + getTableName();
    try (Connection conn = dataSource.getConnection();
         Statement stmt = conn.createStatement();
         ResultSet rs = stmt.executeQuery(sql)) {
        return rs.next() ? rs.getInt(1) : 0;
    } catch (SQLException e) {
        throw new RuntimeException("Count query failed", e);
    }
}
2. 乐观锁支持
/**
 * Base class for entities that take part in optimistic locking: it carries the
 * primary key and the {@code version} value read by {@code updateWithVersion}.
 */
public class BaseEntity {

    /** Primary key. */
    private Long id;

    /** Optimistic-lock version of the row this entity was loaded from. */
    private Integer version;
}
/**
 * Updates the entity's non-null fields under optimistic locking.
 *
 * <p>The statement only matches when the stored {@code version} equals the
 * entity's version, and it increments the stored version in the same statement.
 * On success the entity's in-memory version is incremented to match the row.
 *
 * @param entity entity carrying the primary key, the expected version and the values to write
 * @return the number of rows updated (at least 1)
 * @throws IllegalStateException    if the entity type has no primary-key or {@code version} field
 * @throws IllegalArgumentException if the entity's version is {@code null} or not an {@code Integer}
 * @throws OptimisticLockException  if no row matched, i.e. the row was changed or removed concurrently
 * @throws RuntimeException         wrapping the underlying failure when the update cannot be executed
 */
public int updateWithVersion(T entity) {
    String primaryKey = getPrimaryKeyName();

    // Locate id and version on the whole hierarchy: both usually live on a shared base entity.
    Field idField = null;
    Field versionField = null;
    for (Field candidate : ReflectionUtils.getAllFields(entityClass)) {
        if (idField == null && candidate.getName().equals(primaryKey)) {
            idField = candidate;
        } else if (versionField == null && candidate.getName().equals("version")) {
            versionField = candidate;
        }
    }
    if (idField == null || versionField == null) {
        throw new IllegalStateException(
                entityClass.getName() + " needs '" + primaryKey + "' and 'version' fields for optimistic locking");
    }

    // The version column is maintained by the statement itself, never bound as a plain value.
    List<Field> fields = new ArrayList<>();
    for (Field field : getNonEmptyFields(entity)) {
        if (!"version".equals(field.getName())) {
            fields.add(field);
        }
    }

    StringBuilder sql = new StringBuilder("UPDATE ").append(getTableName()).append(" SET ");
    for (Field field : fields) {
        sql.append(camelToSnake(field.getName())).append(" = ?, ");
    }
    sql.append("version = version + 1 WHERE ")
            .append(camelToSnake(primaryKey)).append(" = ? AND version = ?");

    try {
        idField.setAccessible(true);
        versionField.setAccessible(true);
        Object idValue = idField.get(entity);
        Object rawVersion = versionField.get(entity);
        if (!(rawVersion instanceof Integer)) {
            throw new IllegalArgumentException("Entity version must be a non-null Integer, got " + rawVersion);
        }
        int currentVersion = (Integer) rawVersion;

        int updated;
        try (Connection conn = dataSource.getConnection();
             PreparedStatement pstmt = conn.prepareStatement(sql.toString())) {
            int paramIndex = setParameters(pstmt, entity, fields, 1);
            pstmt.setObject(paramIndex++, idValue);
            pstmt.setObject(paramIndex, currentVersion);
            updated = pstmt.executeUpdate();
        }
        if (updated == 0) {
            throw new OptimisticLockException("Concurrent modification detected");
        }
        // Mirror the increment the database just applied.
        versionField.set(entity, currentVersion + 1);
        return updated;
    } catch (SQLException | IllegalAccessException e) {
        throw new RuntimeException("Update with version failed", e);
    }
}
3. 多数据源支持
/**
 * Base DAO that obtains its connections through a {@link DataSourceSelector},
 * so the target database can be chosen per call.
 *
 * @param <T>  entity type
 * @param <ID> primary-key type
 */
public abstract class AbstractBaseDao<T, ID> {

    private DataSourceSelector dataSourceSelector;

    /**
     * Creates a DAO without a selector; one must be supplied through
     * {@link #setDataSourceSelector(DataSourceSelector)} before the first connection is requested.
     */
    protected AbstractBaseDao() {
    }

    /**
     * Creates a DAO bound to the given selector.
     *
     * @param dataSourceSelector strategy choosing the DataSource for each call
     * @throws IllegalArgumentException if {@code dataSourceSelector} is {@code null}
     */
    protected AbstractBaseDao(DataSourceSelector dataSourceSelector) {
        setDataSourceSelector(dataSourceSelector);
    }

    /**
     * Sets the strategy used to choose the DataSource.
     *
     * @param dataSourceSelector strategy choosing the DataSource for each call
     * @throws IllegalArgumentException if {@code dataSourceSelector} is {@code null}
     */
    public void setDataSourceSelector(DataSourceSelector dataSourceSelector) {
        if (dataSourceSelector == null) {
            throw new IllegalArgumentException("dataSourceSelector must not be null");
        }
        this.dataSourceSelector = dataSourceSelector;
    }

    /**
     * Opens a connection on the DataSource chosen by the selector.
     *
     * @return a new connection; the caller is responsible for closing it
     * @throws IllegalStateException if no selector has been configured
     * @throws SQLException          if the selector yields no DataSource or the connection cannot be opened
     */
    protected Connection getConnection() throws SQLException {
        if (dataSourceSelector == null) {
            throw new IllegalStateException("No DataSourceSelector configured");
        }
        DataSource selected = dataSourceSelector.determineDataSource();
        if (selected == null) {
            throw new SQLException("DataSourceSelector returned no DataSource");
        }
        return selected.getConnection();
    }
}
/**
 * Strategy that decides which {@link DataSource} a DAO should use for the
 * current call.
 */
public interface DataSourceSelector {

    /**
     * Chooses the DataSource to use right now.
     *
     * @return the selected DataSource
     */
    DataSource determineDataSource();
}
/**
 * {@link DataSourceSelector} that lets each thread bind its own DataSource and
 * falls back to a default when the current thread has bound none.
 *
 * <p>The binding is held in a static {@link ThreadLocal}, so it is shared by all
 * instances of this class within the same thread.
 */
public class ThreadLocalDataSourceSelector implements DataSourceSelector {

    /** DataSource bound to the current thread, if any. */
    private static final ThreadLocal<DataSource> CURRENT = new ThreadLocal<>();

    private final DataSource defaultDataSource;

    /**
     * @param defaultDataSource DataSource returned when the current thread has no binding
     */
    public ThreadLocalDataSourceSelector(DataSource defaultDataSource) {
        this.defaultDataSource = defaultDataSource;
    }

    /**
     * Returns the DataSource bound to the current thread, or the default when none is bound.
     */
    @Override
    public DataSource determineDataSource() {
        DataSource bound = CURRENT.get();
        if (bound == null) {
            return defaultDataSource;
        }
        return bound;
    }

    /**
     * Binds a DataSource to the current thread.
     *
     * @param dataSource the DataSource subsequent calls on this thread should use
     */
    public static void setDataSource(DataSource dataSource) {
        CURRENT.set(dataSource);
    }

    /** Removes the current thread's binding so the default is used again. */
    public static void clear() {
        CURRENT.remove();
    }
}
六、总结与最佳实践
1. BaseDao 使用场景

| 场景 | 适用性 | 建议 |
| --- | --- | --- |
| 小型项目 | ⭐⭐⭐⭐⭐ | 推荐使用 |
| 中型项目 | ⭐⭐⭐⭐ | 配合MyBatis使用 |
| 大型项目 | ⭐⭐ | 使用JPA/Hibernate |
| 微服务架构 | ⭐⭐⭐ | 作为仓储层基础 |
2. 性能优化矩阵

| 操作类型 | 数据量 | 优化策略 |
| --- | --- | --- |
| 单条插入 | <100 | 直接插入 |
| 批量插入 | >100 | 批处理 |
| 全字段更新 | 任意 | update方法 |
| 部分更新 | 任意 | updateSelective |
| 高频查询 | 任意 | 增加缓存层 |
3. 实施建议
- 遵循单一职责原则：BaseDao只负责通用CRUD，自定义查询在子类实现
- 异常处理：封装统一的DaoException
- 事务控制：在Service层管理事务，DAO不处理事务
- 连接管理：使用连接池并正确配置
- 安全防护：永远不要把外部输入拼接进SQL字符串，参数值一律通过占位符绑定
- 版本控制：为实体添加乐观锁支持
- 日志记录：记录操作摘要而非参数详情
- 单元测试：覆盖所有基础CRUD操作
单条CRUD
批量操作
复杂查询
业务层
调用
操作类型
BaseDao
BaseDao批量方法
子类自定义方法
数据库
最佳实践总结 :BaseDao是数据访问层的强大抽象,正确实现可以极大减少重复代码。但在实际项目中,建议优先考虑成熟的ORM框架(如MyBatis、JPA),它们提供了更完善的解决方案和更好的性能优化。BaseDao模式最适合作为学习JDBC原理或小型项目的解决方案。