pull/917/head
seanyu 3 years ago
parent 03775cf872
commit c6e368e8bd

@ -72,14 +72,7 @@ public class PolarisCircuitBreakerFallbackFactory implements FallbackFactory {
.status(fallbackInfo.getCode());
if (fallbackInfo.getHeaders() != null) {
Map<String, Collection<String>> headers = new HashMap<>();
fallbackInfo.getHeaders().forEach((k, v) -> {
if (headers.containsKey(k)) {
headers.get(k).add(v);
}
else {
headers.put(k, new HashSet<>(Collections.singleton(v)));
}
});
fallbackInfo.getHeaders().forEach((k, v) -> headers.put(k, Collections.singleton(v)));
responseBuilder.headers(headers);
}
if (fallbackInfo.getBody() != null) {

@ -72,7 +72,7 @@ public class PolarisCircuitBreakerHttpResponse extends AbstractClientHttpRespons
}
@Override
public final String getStatusText() throws IOException {
public final String getStatusText() {
HttpStatus status = HttpStatus.resolve(getRawStatusCode());
return (status != null ? status.getReasonPhrase() : "");
}

@ -20,6 +20,7 @@ package com.tencent.cloud.polaris.circuitbreaker.resttemplate;
import java.io.IOException;
import java.lang.reflect.Method;
import com.tencent.cloud.rpc.enhancement.resttemplate.EnhancedRestTemplateReporter;
import com.tencent.polaris.api.pojo.CircuitBreakerStatus;
import com.tencent.polaris.circuitbreak.client.exception.CallAbortedException;
@ -31,6 +32,7 @@ import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestTemplate;
import static com.tencent.cloud.rpc.enhancement.resttemplate.EnhancedRestTemplateReporter.HEADER_HAS_ERROR;
@ -71,8 +73,12 @@ public class PolarisCircuitBreakerRestTemplateInterceptor implements ClientHttpR
// pre handle response error
// EnhancedRestTemplateReporter always return true,
// so we need to check header set by EnhancedRestTemplateReporter
restTemplate.getErrorHandler().hasError(response);
if (Boolean.parseBoolean(response.getHeaders().getFirst(HEADER_HAS_ERROR))) {
ResponseErrorHandler errorHandler = restTemplate.getErrorHandler();
boolean hasError = errorHandler.hasError(response);
if (errorHandler instanceof EnhancedRestTemplateReporter) {
hasError = Boolean.parseBoolean(response.getHeaders().getFirst(HEADER_HAS_ERROR));
}
if (hasError) {
restTemplate.getErrorHandler().handleError(request.getURI(), request.getMethod(), response);
}
return response;

@ -21,7 +21,10 @@ import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.stream.Collectors;
import com.google.protobuf.InvalidProtocolBufferException;
@ -48,16 +51,27 @@ import org.springframework.boot.autoconfigure.ImportAutoConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.cloud.client.loadbalancer.LoadBalanced;
import org.springframework.cloud.openfeign.EnableFeignClients;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.test.context.junit.jupiter.SpringExtension;
import org.springframework.test.web.client.ExpectedCount;
import org.springframework.test.web.client.MockRestServiceServer;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.DefaultUriBuilderFactory;
import static com.tencent.polaris.test.common.TestUtils.SERVER_ADDRESS_ENV;
import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.boot.test.context.SpringBootTest.WebEnvironment.RANDOM_PORT;
import static org.springframework.test.web.client.match.MockRestRequestMatchers.method;
import static org.springframework.test.web.client.match.MockRestRequestMatchers.requestTo;
import static org.springframework.test.web.client.response.MockRestResponseCreators.withStatus;
/**
* @author sean yu
@ -69,7 +83,7 @@ import static org.springframework.boot.test.context.SpringBootTest.WebEnvironmen
"spring.cloud.gateway.enabled=false",
"feign.circuitbreaker.enabled=true",
"spring.cloud.polaris.namespace=default",
"spring.cloud.polaris.service=Test"
"spring.cloud.polaris.service=test"
})
@DirtiesContext
public class PolarisCircuitBreakerRestTemplateIntegrationTest {
@ -85,6 +99,10 @@ public class PolarisCircuitBreakerRestTemplateIntegrationTest {
}
}
@Autowired
@Qualifier("defaultRestTemplate")
private RestTemplate defaultRestTemplate;
@Autowired
@Qualifier("restTemplateFallbackFromPolaris")
private RestTemplate restTemplateFallbackFromPolaris;
@ -97,20 +115,48 @@ public class PolarisCircuitBreakerRestTemplateIntegrationTest {
@Qualifier("restTemplateFallbackFromCode2")
private RestTemplate restTemplateFallbackFromCode2;
@Autowired
@Qualifier("restTemplateFallbackFromCode3")
private RestTemplate restTemplateFallbackFromCode3;
@Autowired
@Qualifier("restTemplateFallbackFromCode4")
private RestTemplate restTemplateFallbackFromCode4;
@Autowired
private ApplicationContext applicationContext;
@Test
public void testRestTemplate() {
assertThat(restTemplateFallbackFromCode.getForObject("/example/service/b/info", String.class)).isEqualTo("\"this is a fallback class\"");
Utils.sleepUninterrupted(2000);
public void testRestTemplate() throws URISyntaxException {
MockRestServiceServer mockServer = MockRestServiceServer.createServer(defaultRestTemplate);
mockServer
.expect(ExpectedCount.once(), requestTo(new URI("http://localhost:18001/example/service/b/info")))
.andExpect(method(HttpMethod.GET))
.andRespond(withStatus(HttpStatus.OK).body("OK"));
assertThat(defaultRestTemplate.getForObject("http://localhost:18001/example/service/b/info", String.class)).isEqualTo("OK");
mockServer.verify();
mockServer.reset();
mockServer
.expect(ExpectedCount.once(), requestTo(new URI("http://localhost:18001/example/service/b/info")))
.andExpect(method(HttpMethod.GET))
.andRespond(withStatus(HttpStatus.BAD_GATEWAY).body("BAD_GATEWAY"));
assertThat(defaultRestTemplate.getForObject("http://localhost:18001/example/service/b/info", String.class)).isEqualTo("fallback");
mockServer.verify();
mockServer.reset();
assertThat(restTemplateFallbackFromCode.getForObject("/example/service/b/info", String.class)).isEqualTo("\"this is a fallback class\"");
Utils.sleepUninterrupted(2000);
assertThat(restTemplateFallbackFromCode2.getForObject("/example/service/b/info", String.class)).isEqualTo("fallback");
assertThat(restTemplateFallbackFromCode2.getForObject("/example/service/b/info", String.class)).isEqualTo("\"this is a fallback class\"");
Utils.sleepUninterrupted(2000);
assertThat(restTemplateFallbackFromCode2.getForObject("/example/service/b/info", String.class)).isEqualTo("fallback");
assertThat(restTemplateFallbackFromCode3.getForEntity("/example/service/b/info", String.class).getStatusCode()).isEqualTo(HttpStatus.OK);
Utils.sleepUninterrupted(2000);
assertThat(restTemplateFallbackFromPolaris.getForObject("/example/service/b/info", String.class)).isEqualTo("\"fallback from polaris server\"");
assertThat(restTemplateFallbackFromCode4.getForObject("/example/service/b/info", String.class)).isEqualTo("fallback");
Utils.sleepUninterrupted(2000);
assertThat(restTemplateFallbackFromPolaris.getForObject("/example/service/b/info", String.class)).isEqualTo("\"fallback from polaris server\"");
// just for code coverage
PolarisCircuitBreakerHttpResponse response = ((CustomPolarisCircuitBreakerFallback) applicationContext.getBean("customPolarisCircuitBreakerFallback")).fallback();
assertThat(response.getStatusText()).isEqualTo("OK");
assertThat(response.getFallbackInfo().getCode()).isEqualTo(200);
}
@Configuration
@ -119,6 +165,12 @@ public class PolarisCircuitBreakerRestTemplateIntegrationTest {
@EnableFeignClients
public static class TestConfig {
@Bean
@PolarisCircuitBreakerRestTemplate(fallback = "fallback")
public RestTemplate defaultRestTemplate() {
return new RestTemplate();
}
@Bean
@LoadBalanced
@PolarisCircuitBreakerRestTemplate
@ -141,7 +193,7 @@ public class PolarisCircuitBreakerRestTemplateIntegrationTest {
@Bean
@LoadBalanced
@PolarisCircuitBreakerRestTemplate(fallback = "fallback")
@PolarisCircuitBreakerRestTemplate(fallbackClass = CustomPolarisCircuitBreakerFallback2.class)
public RestTemplate restTemplateFallbackFromCode2() {
DefaultUriBuilderFactory uriBuilderFactory = new DefaultUriBuilderFactory("http://" + TEST_SERVICE_NAME);
RestTemplate restTemplate = new RestTemplate();
@ -149,11 +201,41 @@ public class PolarisCircuitBreakerRestTemplateIntegrationTest {
return restTemplate;
}
@Bean
@LoadBalanced
@PolarisCircuitBreakerRestTemplate(fallbackClass = CustomPolarisCircuitBreakerFallback3.class)
public RestTemplate restTemplateFallbackFromCode3() {
DefaultUriBuilderFactory uriBuilderFactory = new DefaultUriBuilderFactory("http://" + TEST_SERVICE_NAME);
RestTemplate restTemplate = new RestTemplate();
restTemplate.setUriTemplateHandler(uriBuilderFactory);
return restTemplate;
}
@Bean
@LoadBalanced
@PolarisCircuitBreakerRestTemplate(fallback = "fallback")
public RestTemplate restTemplateFallbackFromCode4() {
DefaultUriBuilderFactory uriBuilderFactory = new DefaultUriBuilderFactory("http://" + TEST_SERVICE_NAME);
RestTemplate restTemplate = new RestTemplate();
restTemplate.setUriTemplateHandler(uriBuilderFactory);
return restTemplate;
}
@Bean
public CustomPolarisCircuitBreakerFallback customPolarisCircuitBreakerFallback() {
return new CustomPolarisCircuitBreakerFallback();
}
@Bean
public CustomPolarisCircuitBreakerFallback2 customPolarisCircuitBreakerFallback2() {
return new CustomPolarisCircuitBreakerFallback2();
}
@Bean
public CustomPolarisCircuitBreakerFallback3 customPolarisCircuitBreakerFallback3() {
return new CustomPolarisCircuitBreakerFallback3();
}
@Bean
public CircuitBreakAPI circuitBreakAPI() throws InvalidProtocolBufferException {
try {
@ -176,6 +258,22 @@ public class PolarisCircuitBreakerRestTemplateIntegrationTest {
return CircuitBreakAPIFactory.createCircuitBreakAPIByConfig(configuration);
}
@RestController
@RequestMapping("/example/service/b")
public class ServiceBController {
/**
* Get service information.
*
* @return service information
*/
@GetMapping("/info")
public String info() {
return "hello world ! I'm a service B1";
}
}
}
public static class CustomPolarisCircuitBreakerFallback implements PolarisCircuitBreakerFallback {
@ -183,9 +281,31 @@ public class PolarisCircuitBreakerRestTemplateIntegrationTest {
public PolarisCircuitBreakerHttpResponse fallback() {
return new PolarisCircuitBreakerHttpResponse(
200,
new HashMap<String, String>() {{
put("xxx", "xxx");
}},
"\"this is a fallback class\"");
}
}
public static class CustomPolarisCircuitBreakerFallback2 implements PolarisCircuitBreakerFallback {
@Override
public PolarisCircuitBreakerHttpResponse fallback() {
return new PolarisCircuitBreakerHttpResponse(
200,
"\"this is a fallback class\""
);
}
}
public static class CustomPolarisCircuitBreakerFallback3 implements PolarisCircuitBreakerFallback {
@Override
public PolarisCircuitBreakerHttpResponse fallback() {
return new PolarisCircuitBreakerHttpResponse(
200
);
}
}
}

Loading…
Cancel
Save