Spring boot+Mybatis动态多数据源切换


前言

记录自己开发过程的解决方法和实现思路:
最近由于关系的原因需要使用多数据源切换进行不同数据库的数据展示和操作,所以进行了研究,并记录整个过程
我所使用的是Mysql5.7版本Spring boot版本为2.1.6。

实现的原理:

使用注解标识要使用的数据源,然后继承AbstractRoutingDataSource再利用AOP在执行方法前切换数据源,数据源是使用的spring提供的多数据源类(org.springframework.jdbc.datasource.lookup.AbstractRoutingDataSource)


对多数据源实现的核心类AbstractRoutingDataSource部分内容解析

DataSource1.png
DataSource2.png

以上为AbstractRoutingDataSource部分源代码,
其中主要使用为以下方法

  1. setTargetDataSources(注入目标数据源)
  2. setDefaultTargetDataSource(注入默认的数据源)
  3. determineCurrentLookupKey(是一个抽象方法,它由具体的子类实现。这个方法的目的是确定当前线程应该使用的数据源的标识。在实际应用中,这个方法通常通过访问线程本地变量或其他上下文信息来获取标识)
  4. setTargetDataSources方法–需要切换的数据源,参数为Map,key是determineCurrentLookupKey方法的返回值,value是数据源实例
  5. setDefaultTargetDataSource方法设置默认使用的数据源,就是没有指定数据源的情况下使用的数据源,参数是一个数据源实例
  6. 需要进行实现的抽象方法determineCurrentLookupKey(),该方法返回需要使用的DataSource的key值,然后根据这个key从resolvedDataSources这个map里取出对应的DataSource,如果找不到,则用默认的resolvedDefaultDataSource。

目录结构如图所示:

![datasource.png]/blog/java/datasource.png)

主要类的说明:

配置文件以及相关类的说明

  1. 数据库配置文件:application.yml;
  2. 数据源注解类:DataSource.java;
  3. 数据源AOP类:DataSourceAspect.java;
  4. 数据源Bean配置类:DataSourceBean.java;
  5. 数据库配置类:DataSourceConfig.java;
  6. 数据源工具:DataSourceUtil.java;
  7. 多数据源类型:DataSourceType.java;
  8. 实现AOP动态切换:DynamicDataSource.java;
  9. 定义上下文数据源:JdbcContextHolder.java;

    启动项目之前需要将自带的数据源进行排除:

    Application.pag

数据库配置文件:application.yml

image-application.png

具体的代码如下:

数据源注解类:DataSource.java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
* 数据源注解类
*/
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD})
public @interface DataSource {
DataSourceType value() default DataSourceType.Master;
}


数据源AOP类:DataSourceAspect.java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74

import org.aspectj.lang.JoinPoint;

import org.aspectj.lang.annotation.After;

import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;

import java.lang.reflect.Method;

/**
* AOP根据注解给上下文赋值
*/
//声明此类为aop
@Aspect
//加载顺序:动态切换数据源的时候,如果事务在前,数据源切换在后,会导致数据源切换失效,所以就用到了Order(排序)这个关键字
// Order值越小,优先级越高!
@Order(1)
//注解表明一个类会作为组件类,并告知Spring要为这个类创建bean
@Component
public class DataSourceAspect {
private Logger logger = LoggerFactory.getLogger(this.getClass());

//切点
@Pointcut("@annotation(com.xiaoming.demo.config.datasource.DataSource)")
public void aspect(){
System.out.println("aspect");
}

/**
* 前置通知, 在方法执行之前执行
* @param joinPoint
*/
@Before("aspect()")
private void before(JoinPoint joinPoint){
Object target = joinPoint.getTarget();
String method = joinPoint.getSignature().getName();
Class<?> classz = target.getClass();
Class<?>[] parameterTypes = ((MethodSignature) joinPoint.getSignature()).getMethod().getParameterTypes();
try {
Method m = classz.getMethod(method,parameterTypes);
if (m != null && m.isAnnotationPresent(DataSource.class)){
DataSource data = m.getAnnotation(DataSource.class);
JdbcContextHolder.putDataSource(data.value().getName());
logger.debug("===============上下文赋值完成:"+data.value().getName());
}else{
JdbcContextHolder.putDataSource(DataSourceType.Master.getName());
logger.debug("===============使用默认数据源:"+DataSourceType.Master.getName());
}
}catch (Exception e){
e.printStackTrace();
}
}

/**
* 后置通知, 在方法执行之后执行
* @param point
*/

@After("@annotation(dataSource)")
public void afterSwitchDS(JoinPoint point,DataSource dataSource) {
logger.info(String.format("当前数据源 %s 执行清理方法", dataSource.value().getName()));
JdbcContextHolder.clearDatabaseSource();
}


}

数据源Bean配置:DataSourceBean.java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
  
import org.springframework.boot.context.properties.ConfigurationProperties;

import org.springframework.stereotype.Component;

import java.util.Map;

/**
* 数据源实体类
*/

@Component
@ConfigurationProperties(prefix = "spring.datasource") //application.yml中对应属性的前缀
public class DataSourceBean {
/**
主数据源配置
*/
private Map<String,String> master;
/**
*第二数据源配置
*/
private Map<String,String> slave;
public Map<String, String> getMaster() {
return master;
}

public void setMaster(Map<String, String> master) {
this.master = master;
}
public Map<String, String> getSlave() {
return slave;
}
public void setSlave(Map<String, String> slave) {
this.slave = slave;
}
}

数据库配置类:DataSourceConfig.java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94

import com.alibaba.druid.support.http.StatViewServlet;

import com.alibaba.druid.support.http.WebStatFilter;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.boot.web.servlet.ServletRegistrationBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;

import javax.annotation.Resource;
import javax.sql.DataSource;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
* 数据库配置
*/
@SuppressWarnings("AlibabaRemoveCommentedCode")
@Configuration
public class DataSourceConfig {

private Logger logger = LoggerFactory.getLogger(DataSourceConfig.class);


@Resource
private DataSourceBean dataSourceBean;

@Bean(name = "dynamicDataSource")
@Primary //优先使用,多数据源
public DataSource dataSource(){
DynamicDataSource dynamicDataSource = new DynamicDataSource();
//配置多个数据源
Map<Object,Object> map = new HashMap<>();

List<String> fields = DataSourceUtil.getClassFields(DataSourceBean.class);
int i = 0;
for (String field:fields){
Map<String,String> config = null;
try {
config = (Map<String, String>) DataSourceUtil.getFieldValueByName(field,dataSourceBean);
} catch (Exception e) {
e.printStackTrace();
}
if (config == null){
logger.error("数据源配置失败:"+field);
continue;
}
try {
DataSource dataSource = DataSourceUtil.getDataSource(config);
if (i == 0){
logger.debug("设置默认数据源:"+field);
dynamicDataSource.setDefaultTargetDataSource(dataSource);
}
map.put(field,DataSourceUtil.getDataSource(config));
logger.debug("链接数据库:"+field);
i++;
} catch (SQLException e) {
logger.error("druid configuration initialization filter", e);
}
}
logger.debug("共配置了"+i+"个数据源");
dynamicDataSource.setTargetDataSources(map);
return dynamicDataSource;
}

@Bean(name="druidServlet")
public ServletRegistrationBean druidServlet() {
ServletRegistrationBean reg = new ServletRegistrationBean();
reg.setServlet(new StatViewServlet());
reg.addUrlMappings("/druid/*");
reg.addInitParameter("allow", ""); //白名单
return reg;
}

@Bean(name = "filterRegistrationBean")
public FilterRegistrationBean filterRegistrationBean() {
FilterRegistrationBean filterRegistrationBean = new FilterRegistrationBean();
filterRegistrationBean.setFilter(new WebStatFilter());
filterRegistrationBean.addUrlPatterns("/*");
filterRegistrationBean.addInitParameter("exclusions", "*.js,*.gif,*.jpg,*.png,*.css,*.ico,/druid/*");
filterRegistrationBean.addInitParameter("profileEnable", "true");
filterRegistrationBean.addInitParameter("principalCookieName","USER_COOKIE");
filterRegistrationBean.addInitParameter("principalSessionName","USER_SESSION");
filterRegistrationBean.addInitParameter("DruidWebStatFilter","/*");
return filterRegistrationBean;
}
}

数据源工具:DataSourceUtil.java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104

import com.alibaba.druid.pool.DruidDataSource;

import javax.sql.DataSource;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
* 数据源工具
*/
public class DataSourceUtil {

/**
* 获取指定类的成员变量
* @param clazz
* @return 成员变量名的List
*/
public static List<String> getClassFields(Class clazz){
List<String> list = new ArrayList<>();
Field[] fields = clazz.getDeclaredFields();
for (Field field : fields){
list.add(field.getName());
}
return list;
}

/**
* 依据成员变量获取值
* @param fieldName 变量名
* @param o 已注入的实体
* @return Object
* @throws Exception 抛出异常
*/
public static Object getFieldValueByName(String fieldName, Object o) throws Exception{
String firstLetter = fieldName.substring(0, 1).toUpperCase();
String getter = "get" + firstLetter + fieldName.substring(1);
Method method = o.getClass().getMethod(getter, new Class[] {});
Object value = method.invoke(o, new Object[] {});
return value;
}

/**
* 依据数据配置 获取datasource 对象
* @param params Map 数据配置
* @return 返回datasource
* @throws SQLException 抛出Sql 异常
*/
public static DataSource getDataSource(Map<String,String> params) throws SQLException {
DruidDataSource datasource = new DruidDataSource();
datasource.setUrl(params.get("url"));
datasource.setUsername(params.get("username"));
datasource.setPassword(params.get("password"));
datasource.setDriverClassName(params.get("driverClassName"));
if (params.containsKey("initialSize")) {
datasource.setInitialSize(Integer.parseInt(params.get("initialSize")));
}
if (params.containsKey("minIdle")) {
datasource.setMinIdle(Integer.parseInt(params.get("minIdle")));
}
if (params.containsKey("maxActive")) {
datasource.setMaxActive(Integer.parseInt(params.get("maxActive")));
}
if (params.containsKey("maxWait")){
datasource.setMaxWait(Long.parseLong(params.get("maxWait")));
}
if (params.containsKey("timeBetweenEvictionRunsMillis")){
datasource.setTimeBetweenEvictionRunsMillis(Long.parseLong(params.get("timeBetweenEvictionRunsMillis")));
}
if (params.containsKey("minEvictableIdleTimeMillis")){
datasource.setMinEvictableIdleTimeMillis(Long.parseLong(params.get("minEvictableIdleTimeMillis")));
}
if (params.containsKey("validationQuery")){
datasource.setValidationQuery(params.get("validationQuery"));
}
if (params.containsKey("testWhileIdle")){
datasource.setTestWhileIdle(Boolean.parseBoolean(params.get("testWhileIdle")));
}
if (params.containsKey("testOnBorrow")){
datasource.setTestOnBorrow(Boolean.parseBoolean(params.get("testOnBorrow")));
}
if (params.containsKey("testOnReturn")){
datasource.setTestOnBorrow(Boolean.parseBoolean(params.get("testOnReturn")));
}
if (params.containsKey("poolPreparedStatements")){
datasource.setPoolPreparedStatements(Boolean.parseBoolean(params.get("poolPreparedStatements")));
}
if (params.containsKey("maxPoolPreparedStatementPerConnectionSize")){
datasource.setMaxPoolPreparedStatementPerConnectionSize(
Integer.parseInt(params.get("maxPoolPreparedStatementPerConnectionSize")));
}
if (params.containsKey("filters")){
datasource.setFilters(params.get("filters"));
}
if (params.containsKey("connectionProperties")){
datasource.setConnectionProperties(params.get("connectionProperties"));
}
return datasource;
}
}

多数据源类型:DataSourceType.java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
/**
* 多数据源类型
*/
public enum DataSourceType {
Master("master"),
Slave("slave");

private String name;

DataSourceType(String name) {
this.name = name;
}

public String getName() {
return name;
}

public void setName(String name) {
this.name = name;
}
}

实现AOP动态切换:DynamicDataSource.java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19

import org.springframework.jdbc.datasource.lookup.AbstractRoutingDataSource;

/**
* AbstractRoutingDataSource实现类DynamicDataSource
* 实现AOP动态切换的关键
*/
public class DynamicDataSource extends AbstractRoutingDataSource {
@Override
protected Object determineCurrentLookupKey() {
String dbName = JdbcContextHolder.getDataSource();
if (dbName == null ){
dbName = DataSourceType.Master.getName();
}
logger.debug("数据源为:"+dbName);
return dbName;
}
}

定义上下文数据源:JdbcContextHolder.java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24


/**
* 通过ThreadLocal定义上下文数据源标识
*/
public class JdbcContextHolder {

// ThreadLocal每个线程都独有的保存其线程所属的变量值
private final static ThreadLocal<String> local = new ThreadLocal<>();

//赋值
public static void putDataSource(String name){
local.set(name);
}

//取值
public static String getDataSource(){
return local.get();
}
//清除
public static void clearDatabaseSource() {
local.remove();
}
}

数据源配置已经完成,如果在代码中进行使用如图
Service.png
数据源的切换我选择在Service层进行,如果在controller进行切换的话,单一不灵活,而选择在Service层进行,只需要在controller里面注入多个service实现即可,如果开启了事务的话,需要在事务执行之前将数据进行切换,否则会出现数据源切换失败,原因:事务本身也是通过AOP配置的,因为它先走了事务切面,在事务还未结束的时候去切换数据源的话会出错的,设置一下切换数据源的AOP的优先级,确保在事务执行之前就已经切换数据源。进行测试,没有发现问题。
数据源动态切换已经完成。

结语

此篇内容仅仅只是记录自己在完成该功能的时候的一些过程,以及自己的思路。其中也遇到一些问题,通过网上查找了相关一些资料,如有不对的地方或者待完善的地方,请多多指教。

思考

如果有多个数据源,需要同时执行DML的时候,此时如果出现异常的话,事务该如何保证数据源都能回滚?

-------------本文结束感谢您的阅读-------------