1. package org.jeecg.modules.ky.common.util;
    2. import lombok.SneakyThrows;
    3. import lombok.extern.slf4j.Slf4j;
    4. import org.springframework.beans.BeansException;
    5. import org.springframework.beans.factory.BeanDefinitionStoreException;
    6. import org.springframework.beans.factory.config.BeanDefinition;
    7. import org.springframework.beans.factory.config.BeanDefinitionHolder;
    8. import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
    9. import org.springframework.beans.factory.support.*;
    10. import org.springframework.context.annotation.ClassPathScanningCandidateComponentProvider;
    11. import org.springframework.core.type.filter.AssignableTypeFilter;
    12. import org.springframework.lang.NonNull;
    13. import org.springframework.lang.Nullable;
    14. import org.springframework.stereotype.Component;
    15. import org.springframework.util.CollectionUtils;
    16. import java.lang.reflect.Modifier;
    17. import java.util.*;
    18. @Slf4j
    19. @Component
    20. public class SpringUtil implements BeanDefinitionRegistryPostProcessor {
    21. private static ConfigurableListableBeanFactory beanFactory;
    22. private static BeanDefinitionRegistry registry;
    23. @Override
    24. public void postProcessBeanFactory(@NonNull ConfigurableListableBeanFactory beanFactory)
    25. throws BeansException {
    26. SpringUtil.beanFactory = beanFactory;
    27. }
    28. @Override
    29. public void postProcessBeanDefinitionRegistry(@NonNull BeanDefinitionRegistry registry)
    30. throws BeansException {
    31. SpringUtil.registry = registry;
    32. }
    33. /**
    34. * 通过扫描类的包路径下的指定子类向系统中注册 bean <br>
    35. * 扫描到的类会走一遍bean的生命周期,其中的依赖也会自动注入
    36. *
    37. * @param packagePath 需要扫描类的包路径
    38. * @return 自动生成的bean名称
    39. */
    40. @SuppressWarnings("UnusedReturnValue")
    41. public static <T> Map<Class<T>, String> registerBeansByParentClass(
    42. Class<T> parentClass, String packagePath) throws BeanDefinitionStoreException {
    43. List<Class<T>> subClasses = ReflectionUtil.getSubClasses(parentClass, packagePath);
    44. if (CollectionUtils.isEmpty(subClasses)) {
    45. return Collections.emptyMap();
    46. }
    47. Map<Class<T>, String> result = new HashMap<>(subClasses.size());
    48. // 获取该包下所有parentClass的子类,初始化bean并注册
    49. for (Class<T> beanClazz : subClasses) {
    50. result.put(beanClazz, registerBean(beanClazz));
    51. }
    52. return result;
    53. }
    54. /**
    55. * 主动向Spring容器中注册bean
    56. *
    57. * @param registry Bean定义注册表
    58. * @param beanName BeanName
    59. * @param aliases 别名
    60. * @param beanClazz 注册的bean的类性
    61. * @param args 构造方法的必要参数,顺序和类型要求和clazz中定义的一致
    62. */
    63. public static void registerBean(
    64. BeanDefinitionRegistry registry,
    65. String beanName,
    66. @Nullable String[] aliases,
    67. Class<?> beanClazz,
    68. Object... args)
    69. throws BeanDefinitionStoreException {
    70. BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(beanClazz);
    71. if (args != null && args.length > 0) {
    72. for (Object arg : args) {
    73. builder.addConstructorArgValue(arg);
    74. }
    75. }
    76. registerBean(
    77. registry, new BeanDefinitionHolder(builder.getRawBeanDefinition(), beanName, aliases));
    78. }
    79. /**
    80. * 通过自动生成的bean名称注册bean,会自动注入依赖
    81. *
    82. * @param beanClazz 要生成的beanClass
    83. * @return 返回生成的Bean名称 {@link BeanDefinitionReaderUtils#generateBeanName}
    84. */
    85. public static String registerBean(Class<?> beanClazz) throws BeanDefinitionStoreException {
    86. return BeanDefinitionReaderUtils.registerWithGeneratedName(
    87. getAutowireBeanDefinition(beanClazz), registry);
    88. }
    89. /**
    90. * 注册bean,会自动注入依赖
    91. *
    92. * @param beanClazz 要生成的beanClass
    93. * @param beanName Bean名称
    94. * @param aliases 别名
    95. */
    96. public static void registerBean(Class<?> beanClazz, String beanName, String... aliases)
    97. throws BeanDefinitionStoreException {
    98. registerBean(registry, beanClazz, beanName, aliases);
    99. }
    100. /**
    101. * 注册bean,会自动注入依赖
    102. *
    103. * @param registry Bean定义注册表
    104. * @param beanClazz 要生成的beanClass
    105. * @param beanName Bean名称
    106. * @param aliases 别名
    107. */
    108. public static void registerBean(
    109. BeanDefinitionRegistry registry, Class<?> beanClazz, String beanName, String... aliases)
    110. throws BeanDefinitionStoreException {
    111. BeanDefinitionHolder beanDefinitionHolder =
    112. new BeanDefinitionHolder(getAutowireBeanDefinition(beanClazz), beanName, aliases);
    113. registerBean(registry, beanDefinitionHolder);
    114. }
    115. /**
    116. * 注册bean,会自动注入依赖
    117. *
    118. * @param registry Bean定义注册表
    119. * @param definitionHolder 带有名称和别名的Bean定义的持有者
    120. */
    121. public static void registerBean(
    122. BeanDefinitionRegistry registry, BeanDefinitionHolder definitionHolder)
    123. throws BeanDefinitionStoreException {
    124. validateBeanName(definitionHolder);
    125. BeanDefinitionReaderUtils.registerBeanDefinition(definitionHolder, registry);
    126. }
    127. /** 获取bean定义,已经设置好按照类型自动注入依赖 */
    128. private static GenericBeanDefinition getAutowireBeanDefinition(Class<?> beanClazz) {
    129. BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(beanClazz);
    130. builder.setAutowireMode(GenericBeanDefinition.AUTOWIRE_BY_TYPE);
    131. // org.springframework.beans.factory.support.DefaultListableBeanFactory.registerBeanDefinition
    132. builder.setRole(BeanDefinition.ROLE_APPLICATION);
    133. return (GenericBeanDefinition) builder.getRawBeanDefinition();
    134. }
    135. private static void validateBeanName(BeanDefinitionHolder definitionHolder) {
    136. String beanName = definitionHolder.getBeanName();
    137. if (registry.isBeanNameInUse(beanName)) {
    138. if (log.isDebugEnabled()) {
    139. if (registry.isAlias(beanName)) {
    140. log.debug(
    141. "Overriding bean alias with a different definition: replacing One bean alias with ["
    142. + definitionHolder.getBeanDefinition()
    143. + "]");
    144. } else {
    145. log.debug(
    146. "Overriding bean definition for bean '"
    147. + beanName
    148. + "' with a different definition: replacing ["
    149. + registry.getBeanDefinition(beanName)
    150. + "] with ["
    151. + definitionHolder.getBeanDefinition()
    152. + "]");
    153. }
    154. } else {
    155. log.info("Overriding bean definition for bean '" + beanName + "'");
    156. }
    157. }
    158. if (definitionHolder.getAliases() != null) {
    159. List<String> usedAliases = new ArrayList<>(definitionHolder.getAliases().length);
    160. for (String alias : definitionHolder.getAliases()) {
    161. if (registry.isAlias(alias)) {
    162. usedAliases.add(alias);
    163. }
    164. }
    165. if (!usedAliases.isEmpty()) {
    166. if (log.isDebugEnabled()) {
    167. log.debug(
    168. "Overriding bean alias with a different definition: replacing One bean alias "
    169. + Arrays.toString(usedAliases.toArray())
    170. + "with ["
    171. + definitionHolder.getBeanDefinition()
    172. + "]");
    173. }
    174. }
    175. }
    176. }
    177. /**
    178. * 通过父类class和类路径获取该路径下父类的所有子类列表
    179. *
    180. * @param parentClass 父类或接口的class
    181. * @param packagePath 类路径
    182. * @return 所有该类子类或实现类的列表
    183. */
    184. @SneakyThrows(ClassNotFoundException.class)
    185. public static <T> List<Class<T>> getSubClasses(
    186. final Class<T> parentClass, final String packagePath) {
    187. final ClassPathScanningCandidateComponentProvider provider =
    188. new ClassPathScanningCandidateComponentProvider(false);
    189. provider.addIncludeFilter(new AssignableTypeFilter(parentClass));
    190. final Set<BeanDefinition> components = provider.findCandidateComponents(packagePath);
    191. final List<Class<T>> subClasses = new ArrayList<>();
    192. for (final BeanDefinition component : components) {
    193. @SuppressWarnings("unchecked")
    194. final Class<T> cls = (Class<T>) Class.forName(component.getBeanClassName());
    195. if (Modifier.isAbstract(cls.getModifiers())) {
    196. continue;
    197. }
    198. subClasses.add(cls);
    199. }
    200. return subClasses;
    201. }
    202. }