diff --git a/spring-cloud-tencent-commons/src/main/java/com/tencent/cloud/common/util/BeanFactoryUtils.java b/spring-cloud-tencent-commons/src/main/java/com/tencent/cloud/common/util/BeanFactoryUtils.java index 7bfd0c872..9a315a341 100644 --- a/spring-cloud-tencent-commons/src/main/java/com/tencent/cloud/common/util/BeanFactoryUtils.java +++ b/spring-cloud-tencent-commons/src/main/java/com/tencent/cloud/common/util/BeanFactoryUtils.java @@ -24,11 +24,12 @@ import java.util.List; import java.util.stream.Collectors; import org.springframework.beans.factory.BeanFactory; -import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.ListableBeanFactory; + +import static org.springframework.beans.factory.BeanFactoryUtils.beanNamesForTypeIncludingAncestors; /** * the utils for bean factory. - * * @author lepdou 2022-05-23 */ public final class BeanFactoryUtils { @@ -37,13 +38,12 @@ public final class BeanFactoryUtils { } public static List getBeans(BeanFactory beanFactory, Class requiredType) { - if (!(beanFactory instanceof DefaultListableBeanFactory)) { + if (!(beanFactory instanceof ListableBeanFactory)) { throw new RuntimeException("bean factory not support get list bean. factory type = " + beanFactory.getClass() .getName()); } - String[] beanNames = ((DefaultListableBeanFactory) beanFactory).getBeanNamesForType(requiredType); - + String[] beanNames = beanNamesForTypeIncludingAncestors((ListableBeanFactory) beanFactory, requiredType); if (beanNames.length == 0) { return Collections.emptyList(); } diff --git a/spring-cloud-tencent-commons/src/test/java/com/tencent/cloud/common/util/BeanFactoryUtilsTest.java b/spring-cloud-tencent-commons/src/test/java/com/tencent/cloud/common/util/BeanFactoryUtilsTest.java new file mode 100644 index 000000000..d6dd7c25a --- /dev/null +++ b/spring-cloud-tencent-commons/src/test/java/com/tencent/cloud/common/util/BeanFactoryUtilsTest.java @@ -0,0 +1,30 @@ +package com.tencent.cloud.common.util; + +import org.junit.Assert; +import org.junit.Test; + +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.support.RootBeanDefinition; + +/** + * Test for {@link BeanFactoryUtils}. + * + * @author Derek Yi 2022-08-18 + */ +public class BeanFactoryUtilsTest { + + @Test + public void testGetBeansIncludingAncestors() { + DefaultListableBeanFactory parentBeanFactory = new DefaultListableBeanFactory(); + parentBeanFactory.registerBeanDefinition("foo", new RootBeanDefinition(Foo.class)); + + DefaultListableBeanFactory childBeanFactory = new DefaultListableBeanFactory(parentBeanFactory); + Assert.assertTrue(childBeanFactory.getBeansOfType(Foo.class).isEmpty()); + + Assert.assertTrue(BeanFactoryUtils.getBeans(childBeanFactory, Foo.class).size() == 1); + } + + static class Foo { + + } +}