use arguments instead of label in ratelimit and add caller info in ratelimit

pull/902/head
seanyu 3 years ago
parent 2d9afcc4fb
commit 22ece0aba5

@ -41,6 +41,7 @@ import org.springframework.web.server.WebFilterChain;
import static com.tencent.cloud.common.constant.ContextConstant.UTF_8;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_DISPOSABLE_METADATA;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_METADATA;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.DEFAULT_DISPOSABLE_METADATA;
/**
* Filter used for storing the metadata from upstream temporarily when web application is
@ -69,10 +70,10 @@ public class DecodeTransferMetadataReactiveFilter implements WebFilter, Ordered
mergedTransitiveMetadata.putAll(internalTransitiveMetadata);
mergedTransitiveMetadata.putAll(customTransitiveMetadata);
Map<String, String> internalDisposableMetadata = getIntervalMetadata(serverHttpRequest, CUSTOM_DISPOSABLE_METADATA);
Map<String, String> mergedDisposableMetadata = new HashMap<>(internalDisposableMetadata);
Map<String, String> internalCustomDisposableMetadata = getIntervalMetadata(serverHttpRequest, CUSTOM_DISPOSABLE_METADATA);
Map<String, String> internalDefaultDisposableMetadata = getIntervalMetadata(serverHttpRequest, DEFAULT_DISPOSABLE_METADATA);
MetadataContextHolder.init(mergedTransitiveMetadata, mergedDisposableMetadata);
MetadataContextHolder.init(mergedTransitiveMetadata, internalCustomDisposableMetadata, internalDefaultDisposableMetadata);
// Save to ServerWebExchange.
serverWebExchange.getAttributes().put(

@ -43,6 +43,7 @@ import org.springframework.web.filter.OncePerRequestFilter;
import static com.tencent.cloud.common.constant.ContextConstant.UTF_8;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_DISPOSABLE_METADATA;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_METADATA;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.DEFAULT_DISPOSABLE_METADATA;
/**
* Filter used for storing the metadata from upstream temporarily when web application is
@ -66,10 +67,10 @@ public class DecodeTransferMetadataServletFilter extends OncePerRequestFilter {
mergedTransitiveMetadata.putAll(internalTransitiveMetadata);
mergedTransitiveMetadata.putAll(customTransitiveMetadata);
Map<String, String> internalDisposableMetadata = getInternalMetadata(httpServletRequest, CUSTOM_DISPOSABLE_METADATA);
Map<String, String> mergedDisposableMetadata = new HashMap<>(internalDisposableMetadata);
Map<String, String> internalCustomDisposableMetadata = getInternalMetadata(httpServletRequest, CUSTOM_DISPOSABLE_METADATA);
Map<String, String> internalDefaultDisposableMetadata = getInternalMetadata(httpServletRequest, DEFAULT_DISPOSABLE_METADATA);
MetadataContextHolder.init(mergedTransitiveMetadata, mergedDisposableMetadata);
MetadataContextHolder.init(mergedTransitiveMetadata, internalCustomDisposableMetadata, internalDefaultDisposableMetadata);
TransHeadersTransfer.transfer(httpServletRequest);

@ -25,6 +25,7 @@ import com.tencent.cloud.common.constant.MetadataConstant;
import com.tencent.cloud.common.metadata.MetadataContext;
import com.tencent.cloud.common.metadata.MetadataContextHolder;
import com.tencent.cloud.common.util.JacksonUtils;
import com.tencent.cloud.metadata.util.DefaultTransferMedataUtils;
import feign.RequestInterceptor;
import feign.RequestTemplate;
import org.slf4j.Logger;
@ -37,6 +38,7 @@ import org.springframework.web.client.RestTemplate;
import static com.tencent.cloud.common.constant.ContextConstant.UTF_8;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_DISPOSABLE_METADATA;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_METADATA;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.DEFAULT_DISPOSABLE_METADATA;
import static java.net.URLEncoder.encode;
/**
@ -61,7 +63,12 @@ public class EncodeTransferMedataFeignInterceptor implements RequestInterceptor,
Map<String, String> customMetadata = metadataContext.getCustomMetadata();
Map<String, String> disposableMetadata = metadataContext.getDisposableMetadata();
Map<String, String> transHeaders = metadataContext.getTransHeadersKV();
Map<String, String> defaultMetadata = DefaultTransferMedataUtils.getDefaultTransferMedata();
// build default disposable metadata request header
this.buildMetadataHeader(requestTemplate, defaultMetadata, DEFAULT_DISPOSABLE_METADATA);
// build custom disposable metadata request header
this.buildMetadataHeader(requestTemplate, disposableMetadata, CUSTOM_DISPOSABLE_METADATA);
// process custom metadata

@ -27,6 +27,7 @@ import com.tencent.cloud.common.constant.MetadataConstant;
import com.tencent.cloud.common.metadata.MetadataContext;
import com.tencent.cloud.common.metadata.MetadataContextHolder;
import com.tencent.cloud.common.util.JacksonUtils;
import com.tencent.cloud.metadata.util.DefaultTransferMedataUtils;
import org.springframework.core.Ordered;
import org.springframework.http.HttpRequest;
@ -39,6 +40,7 @@ import org.springframework.util.CollectionUtils;
import static com.tencent.cloud.common.constant.ContextConstant.UTF_8;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_DISPOSABLE_METADATA;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_METADATA;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.DEFAULT_DISPOSABLE_METADATA;
/**
* Interceptor used for adding the metadata in http headers from context when web client
@ -61,6 +63,10 @@ public class EncodeTransferMedataRestTemplateInterceptor implements ClientHttpRe
Map<String, String> customMetadata = metadataContext.getCustomMetadata();
Map<String, String> disposableMetadata = metadataContext.getDisposableMetadata();
Map<String, String> transHeaders = metadataContext.getTransHeadersKV();
Map<String, String> defaultMetadata = DefaultTransferMedataUtils.getDefaultTransferMedata();
// build default disposable metadata request header
this.buildMetadataHeader(httpRequest, defaultMetadata, DEFAULT_DISPOSABLE_METADATA);
// build custom disposable metadata request header
this.buildMetadataHeader(httpRequest, disposableMetadata, CUSTOM_DISPOSABLE_METADATA);

@ -26,6 +26,7 @@ import com.tencent.cloud.common.constant.MetadataConstant;
import com.tencent.cloud.common.metadata.MetadataContext;
import com.tencent.cloud.common.metadata.MetadataContextHolder;
import com.tencent.cloud.common.util.JacksonUtils;
import com.tencent.cloud.metadata.util.DefaultTransferMedataUtils;
import reactor.core.publisher.Mono;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
@ -38,6 +39,7 @@ import org.springframework.web.server.ServerWebExchange;
import static com.tencent.cloud.common.constant.ContextConstant.UTF_8;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_DISPOSABLE_METADATA;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_METADATA;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.DEFAULT_DISPOSABLE_METADATA;
import static org.springframework.cloud.gateway.filter.ReactiveLoadBalancerClientFilter.LOAD_BALANCER_CLIENT_FILTER_ORDER;
/**
@ -67,9 +69,11 @@ public class EncodeTransferMedataScgFilter implements GlobalFilter, Ordered {
Map<String, String> customMetadata = metadataContext.getCustomMetadata();
Map<String, String> disposableMetadata = metadataContext.getDisposableMetadata();
Map<String, String> defaultMetadata = DefaultTransferMedataUtils.getDefaultTransferMedata();
this.buildMetadataHeader(builder, customMetadata, CUSTOM_METADATA);
this.buildMetadataHeader(builder, disposableMetadata, CUSTOM_DISPOSABLE_METADATA);
this.buildMetadataHeader(builder, defaultMetadata, DEFAULT_DISPOSABLE_METADATA);
TransHeadersTransfer.transfer(exchange.getRequest());

@ -0,0 +1,23 @@
package com.tencent.cloud.metadata.util;
import java.util.HashMap;
import java.util.Map;
import static com.tencent.cloud.common.constant.MetadataConstant.DefaultMetadata.DEFAULT_METADATA_SOURCE_SERVICE_NAME;
import static com.tencent.cloud.common.constant.MetadataConstant.DefaultMetadata.DEFAULT_METADATA_SOURCE_SERVICE_NAMESPACE;
import static com.tencent.cloud.common.metadata.MetadataContext.LOCAL_NAMESPACE;
import static com.tencent.cloud.common.metadata.MetadataContext.LOCAL_SERVICE;
public final class DefaultTransferMedataUtils {
private DefaultTransferMedataUtils() {
}
public static Map<String, String> getDefaultTransferMedata() {
return new HashMap<String, String>() {{
put(DEFAULT_METADATA_SOURCE_SERVICE_NAMESPACE, LOCAL_NAMESPACE);
put(DEFAULT_METADATA_SOURCE_SERVICE_NAME, LOCAL_SERVICE);
}};
}
}

@ -121,7 +121,7 @@ public class PolarisCircuitBreakerMockServerTest {
}
}, t -> "fallback");
resList.add(res);
Utils.sleepUninterrupted(1000);
Utils.sleepUninterrupted(2000);
}
assertThat(resList).isEqualTo(Arrays.asList("invoke success", "fallback", "fallback", "fallback", "fallback"));

@ -19,6 +19,11 @@
<groupId>com.tencent.cloud</groupId>
<artifactId>spring-cloud-tencent-polaris-context</artifactId>
</dependency>
<dependency>
<groupId>com.tencent.cloud</groupId>
<artifactId>spring-cloud-starter-tencent-metadata-transfer</artifactId>
</dependency>
<!-- Spring Cloud Tencent dependencies end -->
<!-- Polaris dependencies start -->

@ -36,6 +36,7 @@ import org.springframework.util.CollectionUtils;
*
*@author lepdou 2022-05-13
*/
@Deprecated
public class RateLimitRuleLabelResolver {
private final ServiceRuleManager serviceRuleManager;

@ -20,10 +20,10 @@ package com.tencent.cloud.polaris.ratelimit.config;
import com.tencent.cloud.polaris.context.ServiceRuleManager;
import com.tencent.cloud.polaris.context.config.PolarisContextAutoConfiguration;
import com.tencent.cloud.polaris.ratelimit.RateLimitRuleLabelResolver;
import com.tencent.cloud.polaris.ratelimit.constant.RateLimitConstant;
import com.tencent.cloud.polaris.ratelimit.filter.QuotaCheckReactiveFilter;
import com.tencent.cloud.polaris.ratelimit.filter.QuotaCheckServletFilter;
import com.tencent.cloud.polaris.ratelimit.resolver.RateLimitRuleArgumentReactiveResolver;
import com.tencent.cloud.polaris.ratelimit.resolver.RateLimitRuleArgumentServletResolver;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLabelReactiveResolver;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLabelServletResolver;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLimitedFallback;
@ -31,6 +31,7 @@ import com.tencent.polaris.client.api.SDKContext;
import com.tencent.polaris.ratelimit.api.core.LimitAPI;
import com.tencent.polaris.ratelimit.factory.LimitAPIFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.AutoConfigureAfter;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnWebApplication;
@ -39,6 +40,7 @@ import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.lang.Nullable;
import static com.tencent.cloud.polaris.ratelimit.constant.RateLimitConstant.FILTER_ORDER;
import static com.tencent.cloud.polaris.ratelimit.filter.QuotaCheckServletFilter.QUOTA_FILTER_BEAN_NAME;
import static javax.servlet.DispatcherType.ASYNC;
import static javax.servlet.DispatcherType.ERROR;
@ -62,11 +64,6 @@ public class PolarisRateLimitAutoConfiguration {
return LimitAPIFactory.createLimitAPIByContext(polarisContext);
}
@Bean
public RateLimitRuleLabelResolver rateLimitRuleLabelService(ServiceRuleManager serviceRuleManager) {
return new RateLimitRuleLabelResolver(serviceRuleManager);
}
/**
* Create when web application type is SERVLET.
*/
@ -74,15 +71,19 @@ public class PolarisRateLimitAutoConfiguration {
@ConditionalOnWebApplication(type = ConditionalOnWebApplication.Type.SERVLET)
protected static class QuotaCheckFilterConfig {
@Bean
public RateLimitRuleArgumentServletResolver rateLimitRuleArgumentResolver(ServiceRuleManager serviceRuleManager,
@Autowired(required = false) PolarisRateLimiterLabelServletResolver labelResolver) {
return new RateLimitRuleArgumentServletResolver(serviceRuleManager, labelResolver);
}
@Bean
@ConditionalOnMissingBean
public QuotaCheckServletFilter quotaCheckFilter(LimitAPI limitAPI,
@Nullable PolarisRateLimiterLabelServletResolver labelResolver,
PolarisRateLimitProperties polarisRateLimitProperties,
RateLimitRuleLabelResolver rateLimitRuleLabelResolver,
@Nullable PolarisRateLimiterLimitedFallback polarisRateLimiterLimitedFallback) {
return new QuotaCheckServletFilter(limitAPI, labelResolver,
polarisRateLimitProperties, rateLimitRuleLabelResolver, polarisRateLimiterLimitedFallback);
RateLimitRuleArgumentServletResolver rateLimitRuleArgumentResolver,
@Autowired(required = false) PolarisRateLimiterLimitedFallback polarisRateLimiterLimitedFallback) {
return new QuotaCheckServletFilter(limitAPI, polarisRateLimitProperties, rateLimitRuleArgumentResolver, polarisRateLimiterLimitedFallback);
}
@Bean
@ -92,9 +93,11 @@ public class PolarisRateLimitAutoConfiguration {
quotaCheckServletFilter);
registrationBean.setDispatcherTypes(ASYNC, ERROR, FORWARD, INCLUDE, REQUEST);
registrationBean.setName(QUOTA_FILTER_BEAN_NAME);
registrationBean.setOrder(RateLimitConstant.FILTER_ORDER);
registrationBean.setOrder(FILTER_ORDER);
return registrationBean;
}
}
/**
@ -104,14 +107,18 @@ public class PolarisRateLimitAutoConfiguration {
@ConditionalOnWebApplication(type = ConditionalOnWebApplication.Type.REACTIVE)
protected static class MetadataReactiveFilterConfig {
@Bean
public RateLimitRuleArgumentReactiveResolver rateLimitRuleArgumentResolver(ServiceRuleManager serviceRuleManager,
@Autowired(required = false) PolarisRateLimiterLabelReactiveResolver labelResolver) {
return new RateLimitRuleArgumentReactiveResolver(serviceRuleManager, labelResolver);
}
@Bean
public QuotaCheckReactiveFilter quotaCheckReactiveFilter(LimitAPI limitAPI,
@Nullable PolarisRateLimiterLabelReactiveResolver labelResolver,
PolarisRateLimitProperties polarisRateLimitProperties,
RateLimitRuleLabelResolver rateLimitRuleLabelResolver,
RateLimitRuleArgumentReactiveResolver rateLimitRuleArgumentResolver,
@Nullable PolarisRateLimiterLimitedFallback polarisRateLimiterLimitedFallback) {
return new QuotaCheckReactiveFilter(limitAPI, labelResolver,
polarisRateLimitProperties, rateLimitRuleLabelResolver, polarisRateLimiterLimitedFallback);
return new QuotaCheckReactiveFilter(limitAPI, polarisRateLimitProperties, rateLimitRuleArgumentResolver, polarisRateLimiterLimitedFallback);
}
}
}

@ -20,27 +20,22 @@ package com.tencent.cloud.polaris.ratelimit.filter;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import javax.annotation.PostConstruct;
import com.google.common.collect.Maps;
import com.tencent.cloud.common.metadata.MetadataContext;
import com.tencent.cloud.common.util.expresstion.SpringWebExpressionLabelUtils;
import com.tencent.cloud.polaris.ratelimit.RateLimitRuleLabelResolver;
import com.tencent.cloud.polaris.ratelimit.config.PolarisRateLimitProperties;
import com.tencent.cloud.polaris.ratelimit.constant.RateLimitConstant;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLabelReactiveResolver;
import com.tencent.cloud.polaris.ratelimit.resolver.RateLimitRuleArgumentReactiveResolver;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLimitedFallback;
import com.tencent.cloud.polaris.ratelimit.utils.QuotaCheckUtils;
import com.tencent.cloud.polaris.ratelimit.utils.RateLimitUtils;
import com.tencent.polaris.ratelimit.api.core.LimitAPI;
import com.tencent.polaris.ratelimit.api.rpc.Argument;
import com.tencent.polaris.ratelimit.api.rpc.QuotaResponse;
import com.tencent.polaris.ratelimit.api.rpc.QuotaResultCode;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;
@ -54,8 +49,6 @@ import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import static com.tencent.cloud.polaris.ratelimit.constant.RateLimitConstant.LABEL_METHOD;
/**
* Reactive filter to check quota.
*
@ -67,11 +60,9 @@ public class QuotaCheckReactiveFilter implements WebFilter, Ordered {
private final LimitAPI limitAPI;
private final PolarisRateLimiterLabelReactiveResolver labelResolver;
private final PolarisRateLimitProperties polarisRateLimitProperties;
private final RateLimitRuleLabelResolver rateLimitRuleLabelResolver;
private final RateLimitRuleArgumentReactiveResolver rateLimitRuleArgumentResolver;
private final PolarisRateLimiterLimitedFallback polarisRateLimiterLimitedFallback;
@ -79,14 +70,12 @@ public class QuotaCheckReactiveFilter implements WebFilter, Ordered {
private String rejectTips;
public QuotaCheckReactiveFilter(LimitAPI limitAPI,
PolarisRateLimiterLabelReactiveResolver labelResolver,
PolarisRateLimitProperties polarisRateLimitProperties,
RateLimitRuleLabelResolver rateLimitRuleLabelResolver,
RateLimitRuleArgumentReactiveResolver rateLimitRuleArgumentResolver,
@Nullable PolarisRateLimiterLimitedFallback polarisRateLimiterLimitedFallback) {
this.limitAPI = limitAPI;
this.labelResolver = labelResolver;
this.polarisRateLimitProperties = polarisRateLimitProperties;
this.rateLimitRuleLabelResolver = rateLimitRuleLabelResolver;
this.rateLimitRuleArgumentResolver = rateLimitRuleArgumentResolver;
this.polarisRateLimiterLimitedFallback = polarisRateLimiterLimitedFallback;
}
@ -105,12 +94,12 @@ public class QuotaCheckReactiveFilter implements WebFilter, Ordered {
String localNamespace = MetadataContext.LOCAL_NAMESPACE;
String localService = MetadataContext.LOCAL_SERVICE;
Map<String, String> labels = getRequestLabels(exchange, localNamespace, localService);
Set<Argument> arguments = rateLimitRuleArgumentResolver.getArguments(exchange, localNamespace, localService);
long waitMs = -1;
try {
String path = exchange.getRequest().getURI().getPath();
QuotaResponse quotaResponse = QuotaCheckUtils.getQuota(limitAPI,
localNamespace, localService, 1, labels, path);
QuotaResponse quotaResponse = QuotaCheckUtils.getQuota(
limitAPI, localNamespace, localService, 1, arguments, path);
if (quotaResponse.getCode() == QuotaResultCode.QuotaResultLimited) {
ServerHttpResponse response = exchange.getResponse();
@ -149,40 +138,4 @@ public class QuotaCheckReactiveFilter implements WebFilter, Ordered {
}
}
private Map<String, String> getRequestLabels(ServerWebExchange exchange, String localNamespace, String localService) {
Map<String, String> labels = new HashMap<>();
// add build in labels
String path = exchange.getRequest().getURI().getPath();
if (StringUtils.isNotBlank(path)) {
labels.put(LABEL_METHOD, path);
}
// add rule expression labels
Map<String, String> expressionLabels = getRuleExpressionLabels(exchange, localNamespace, localService);
labels.putAll(expressionLabels);
// add custom labels
Map<String, String> customResolvedLabels = getCustomResolvedLabels(exchange);
labels.putAll(customResolvedLabels);
return labels;
}
private Map<String, String> getCustomResolvedLabels(ServerWebExchange exchange) {
if (labelResolver != null) {
try {
return labelResolver.resolve(exchange);
}
catch (Throwable e) {
LOGGER.error("resolve custom label failed. resolver = {}", labelResolver.getClass().getName(), e);
}
}
return Maps.newHashMap();
}
private Map<String, String> getRuleExpressionLabels(ServerWebExchange exchange, String namespace, String service) {
Set<String> expressionLabels = rateLimitRuleLabelResolver.getExpressionLabelKeys(namespace, service);
return SpringWebExpressionLabelUtils.resolve(exchange, expressionLabels);
}
}

@ -19,9 +19,6 @@
package com.tencent.cloud.polaris.ratelimit.filter;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
@ -32,18 +29,16 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import com.tencent.cloud.common.metadata.MetadataContext;
import com.tencent.cloud.common.util.expresstion.ServletExpressionLabelUtils;
import com.tencent.cloud.polaris.ratelimit.RateLimitRuleLabelResolver;
import com.tencent.cloud.polaris.ratelimit.config.PolarisRateLimitProperties;
import com.tencent.cloud.polaris.ratelimit.constant.RateLimitConstant;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLabelServletResolver;
import com.tencent.cloud.polaris.ratelimit.resolver.RateLimitRuleArgumentServletResolver;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLimitedFallback;
import com.tencent.cloud.polaris.ratelimit.utils.QuotaCheckUtils;
import com.tencent.cloud.polaris.ratelimit.utils.RateLimitUtils;
import com.tencent.polaris.ratelimit.api.core.LimitAPI;
import com.tencent.polaris.ratelimit.api.rpc.Argument;
import com.tencent.polaris.ratelimit.api.rpc.QuotaResponse;
import com.tencent.polaris.ratelimit.api.rpc.QuotaResultCode;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -53,8 +48,6 @@ import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.web.filter.OncePerRequestFilter;
import static com.tencent.cloud.polaris.ratelimit.constant.RateLimitConstant.LABEL_METHOD;
/**
* Servlet filter to check quota.
*
@ -70,25 +63,21 @@ public class QuotaCheckServletFilter extends OncePerRequestFilter {
private static final Logger LOG = LoggerFactory.getLogger(QuotaCheckServletFilter.class);
private final LimitAPI limitAPI;
private final PolarisRateLimiterLabelServletResolver labelResolver;
private final PolarisRateLimitProperties polarisRateLimitProperties;
private final RateLimitRuleLabelResolver rateLimitRuleLabelResolver;
private final RateLimitRuleArgumentServletResolver rateLimitRuleArgumentResolver;
private final PolarisRateLimiterLimitedFallback polarisRateLimiterLimitedFallback;
private String rejectTips;
public QuotaCheckServletFilter(LimitAPI limitAPI,
PolarisRateLimiterLabelServletResolver labelResolver,
PolarisRateLimitProperties polarisRateLimitProperties,
RateLimitRuleLabelResolver rateLimitRuleLabelResolver,
RateLimitRuleArgumentServletResolver rateLimitRuleArgumentResolver,
@Nullable PolarisRateLimiterLimitedFallback polarisRateLimiterLimitedFallback) {
this.limitAPI = limitAPI;
this.labelResolver = labelResolver;
this.polarisRateLimitProperties = polarisRateLimitProperties;
this.rateLimitRuleLabelResolver = rateLimitRuleLabelResolver;
this.rateLimitRuleArgumentResolver = rateLimitRuleArgumentResolver;
this.polarisRateLimiterLimitedFallback = polarisRateLimiterLimitedFallback;
}
@ -104,11 +93,11 @@ public class QuotaCheckServletFilter extends OncePerRequestFilter {
String localNamespace = MetadataContext.LOCAL_NAMESPACE;
String localService = MetadataContext.LOCAL_SERVICE;
Map<String, String> labels = getRequestLabels(request, localNamespace, localService);
Set<Argument> arguments = rateLimitRuleArgumentResolver.getArguments(request, localNamespace, localService);
try {
QuotaResponse quotaResponse = QuotaCheckUtils.getQuota(limitAPI,
localNamespace, localService, 1, labels, request.getRequestURI());
localNamespace, localService, 1, arguments, request.getRequestURI());
if (quotaResponse.getCode() == QuotaResultCode.QuotaResultLimited) {
if (!Objects.isNull(polarisRateLimiterLimitedFallback)) {
@ -140,40 +129,4 @@ public class QuotaCheckServletFilter extends OncePerRequestFilter {
filterChain.doFilter(request, response);
}
private Map<String, String> getRequestLabels(HttpServletRequest request, String localNamespace, String localService) {
Map<String, String> labels = new HashMap<>();
// add build in labels
String path = request.getRequestURI();
if (StringUtils.isNotBlank(path)) {
labels.put(LABEL_METHOD, path);
}
// add rule expression labels
Map<String, String> expressionLabels = getRuleExpressionLabels(request, localNamespace, localService);
labels.putAll(expressionLabels);
// add custom resolved labels
Map<String, String> customLabels = getCustomResolvedLabels(request);
labels.putAll(customLabels);
return labels;
}
private Map<String, String> getCustomResolvedLabels(HttpServletRequest request) {
if (labelResolver != null) {
try {
return labelResolver.resolve(request);
}
catch (Throwable e) {
LOG.error("resolve custom label failed. resolver = {}", labelResolver.getClass().getName(), e);
}
}
return Collections.emptyMap();
}
private Map<String, String> getRuleExpressionLabels(HttpServletRequest request, String namespace, String service) {
Set<String> expressionLabels = rateLimitRuleLabelResolver.getExpressionLabelKeys(namespace, service);
return ServletExpressionLabelUtils.resolve(request, expressionLabels);
}
}

@ -0,0 +1,124 @@
/*
* Tencent is pleased to support the open source community by making Spring Cloud Tencent available.
*
* Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the BSD 3-Clause License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://opensource.org/licenses/BSD-3-Clause
*
* Unless required by applicable law or agreed to in writing, software distributed
* under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
* CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*
*/
package com.tencent.cloud.polaris.ratelimit.resolver;
import java.net.InetSocketAddress;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import com.tencent.cloud.common.metadata.MetadataContextHolder;
import com.tencent.cloud.polaris.context.ServiceRuleManager;
import com.tencent.cloud.polaris.ratelimit.filter.QuotaCheckServletFilter;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLabelReactiveResolver;
import com.tencent.polaris.ratelimit.api.rpc.Argument;
import com.tencent.polaris.specification.api.v1.traffic.manage.RateLimitProto;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;
import org.springframework.web.server.ServerWebExchange;
import static com.tencent.cloud.common.constant.MetadataConstant.DefaultMetadata.DEFAULT_METADATA_SOURCE_SERVICE_NAME;
import static com.tencent.cloud.common.constant.MetadataConstant.DefaultMetadata.DEFAULT_METADATA_SOURCE_SERVICE_NAMESPACE;
/**
* resolve arguments from rate limit rule for Reactive.
*
* @author seansyyu 2023-03-09
*/
public class RateLimitRuleArgumentReactiveResolver {
private static final Logger LOG = LoggerFactory.getLogger(QuotaCheckServletFilter.class);
private final ServiceRuleManager serviceRuleManager;
private final PolarisRateLimiterLabelReactiveResolver labelResolver;
public RateLimitRuleArgumentReactiveResolver(ServiceRuleManager serviceRuleManager, PolarisRateLimiterLabelReactiveResolver labelResolver) {
this.serviceRuleManager = serviceRuleManager;
this.labelResolver = labelResolver;
}
public Set<Argument> getArguments(ServerWebExchange request, String namespace, String service) {
RateLimitProto.RateLimit rateLimitRule = serviceRuleManager.getServiceRateLimitRule(namespace, service);
if (rateLimitRule == null) {
return Collections.emptySet();
}
List<RateLimitProto.Rule> rules = rateLimitRule.getRulesList();
if (CollectionUtils.isEmpty(rules)) {
return Collections.emptySet();
}
return rules.stream()
.flatMap(rule -> rule.getArgumentsList().stream())
.map(matchArgument -> {
String matchKey = matchArgument.getKey();
Argument argument = null;
switch (matchArgument.getType()) {
case CUSTOM:
argument = StringUtils.isBlank(matchKey) ? null :
Argument.buildCustom(matchKey, Optional.ofNullable(getCustomResolvedLabels(request).get(matchKey)).orElse(StringUtils.EMPTY));
break;
case METHOD:
argument = Argument.buildMethod(request.getRequest().getMethodValue());
break;
case HEADER:
argument = StringUtils.isBlank(matchKey) ? null :
Argument.buildHeader(matchKey, Optional.ofNullable(request.getRequest().getHeaders().getFirst(matchKey)).orElse(StringUtils.EMPTY));
break;
case QUERY:
argument = StringUtils.isBlank(matchKey) ? null :
Argument.buildQuery(matchKey, Optional.ofNullable(request.getRequest().getQueryParams().getFirst(matchKey)).orElse(StringUtils.EMPTY));
break;
case CALLER_SERVICE:
String sourceServiceNamespace = MetadataContextHolder.getDisposableMetadata(DEFAULT_METADATA_SOURCE_SERVICE_NAMESPACE, true).orElse(StringUtils.EMPTY);
String sourceServiceName = MetadataContextHolder.getDisposableMetadata(DEFAULT_METADATA_SOURCE_SERVICE_NAME, true).orElse(StringUtils.EMPTY);
if (!StringUtils.isEmpty(sourceServiceNamespace) && !StringUtils.isEmpty(sourceServiceName)) {
argument = Argument.buildCallerService(sourceServiceNamespace, sourceServiceName);
}
break;
case CALLER_IP:
InetSocketAddress remoteAddress = request.getRequest().getRemoteAddress();
argument = Argument.buildCallerIP(remoteAddress != null ? remoteAddress.getAddress().getHostAddress() : StringUtils.EMPTY);
break;
default:
break;
}
return argument;
}).filter(Objects::nonNull).collect(Collectors.toSet());
}
private Map<String, String> getCustomResolvedLabels(ServerWebExchange request) {
if (labelResolver != null) {
try {
return labelResolver.resolve(request);
}
catch (Throwable e) {
LOG.error("resolve custom label failed. resolver = {}", labelResolver.getClass().getName(), e);
}
}
return Collections.emptyMap();
}
}

@ -0,0 +1,123 @@
/*
* Tencent is pleased to support the open source community by making Spring Cloud Tencent available.
*
* Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the BSD 3-Clause License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://opensource.org/licenses/BSD-3-Clause
*
* Unless required by applicable law or agreed to in writing, software distributed
* under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
* CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*
*/
package com.tencent.cloud.polaris.ratelimit.resolver;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import javax.servlet.http.HttpServletRequest;
import com.tencent.cloud.common.metadata.MetadataContextHolder;
import com.tencent.cloud.polaris.context.ServiceRuleManager;
import com.tencent.cloud.polaris.ratelimit.filter.QuotaCheckServletFilter;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLabelServletResolver;
import com.tencent.polaris.ratelimit.api.rpc.Argument;
import com.tencent.polaris.specification.api.v1.traffic.manage.RateLimitProto;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;
import static com.tencent.cloud.common.constant.MetadataConstant.DefaultMetadata.DEFAULT_METADATA_SOURCE_SERVICE_NAME;
import static com.tencent.cloud.common.constant.MetadataConstant.DefaultMetadata.DEFAULT_METADATA_SOURCE_SERVICE_NAMESPACE;
/**
* resolve arguments from rate limit rule for Servlet.
*
* @author seansyyu 2023-03-09
*/
public class RateLimitRuleArgumentServletResolver {
private static final Logger LOG = LoggerFactory.getLogger(QuotaCheckServletFilter.class);
private final ServiceRuleManager serviceRuleManager;
private final PolarisRateLimiterLabelServletResolver labelResolver;
public RateLimitRuleArgumentServletResolver(ServiceRuleManager serviceRuleManager, PolarisRateLimiterLabelServletResolver labelResolver) {
this.serviceRuleManager = serviceRuleManager;
this.labelResolver = labelResolver;
}
public Set<Argument> getArguments(HttpServletRequest request, String namespace, String service) {
RateLimitProto.RateLimit rateLimitRule = serviceRuleManager.getServiceRateLimitRule(namespace, service);
if (rateLimitRule == null) {
return Collections.emptySet();
}
List<RateLimitProto.Rule> rules = rateLimitRule.getRulesList();
if (CollectionUtils.isEmpty(rules)) {
return Collections.emptySet();
}
return rules.stream()
.flatMap(rule -> rule.getArgumentsList().stream())
.map(matchArgument -> {
String matchKey = matchArgument.getKey();
Argument argument = null;
switch (matchArgument.getType()) {
case CUSTOM:
argument = StringUtils.isBlank(matchKey) ? null :
Argument.buildCustom(matchKey, Optional.ofNullable(getCustomResolvedLabels(request).get(matchKey)).orElse(StringUtils.EMPTY));
break;
case METHOD:
argument = Argument.buildMethod(request.getMethod());
break;
case HEADER:
argument = StringUtils.isBlank(matchKey) ? null :
Argument.buildHeader(matchKey, Optional.ofNullable(request.getHeader(matchKey)).orElse(StringUtils.EMPTY));
break;
case QUERY:
argument = StringUtils.isBlank(matchKey) ? null :
Argument.buildQuery(matchKey, Optional.ofNullable(request.getParameter(matchKey)).orElse(StringUtils.EMPTY));
break;
case CALLER_SERVICE:
String sourceServiceNamespace = MetadataContextHolder.getDisposableMetadata(DEFAULT_METADATA_SOURCE_SERVICE_NAMESPACE, true).orElse(StringUtils.EMPTY);
String sourceServiceName = MetadataContextHolder.getDisposableMetadata(DEFAULT_METADATA_SOURCE_SERVICE_NAME, true).orElse(StringUtils.EMPTY);
if (!StringUtils.isEmpty(sourceServiceNamespace) && !StringUtils.isEmpty(sourceServiceName)) {
argument = Argument.buildCallerService(sourceServiceNamespace, sourceServiceName);
}
break;
case CALLER_IP:
argument = Argument.buildCallerIP(Optional.ofNullable(request.getRemoteAddr()).orElse(StringUtils.EMPTY));
break;
default:
break;
}
return argument;
}).filter(Objects::nonNull).collect(Collectors.toSet());
}
private Map<String, String> getCustomResolvedLabels(HttpServletRequest request) {
if (labelResolver != null) {
try {
return labelResolver.resolve(request);
}
catch (Throwable e) {
LOG.error("resolve custom label failed. resolver = {}", labelResolver.getClass().getName(), e);
}
}
return Collections.emptyMap();
}
}

@ -18,9 +18,11 @@
package com.tencent.cloud.polaris.ratelimit.utils;
import java.util.Map;
import java.util.Set;
import com.tencent.polaris.api.plugin.ratelimiter.QuotaResult;
import com.tencent.polaris.ratelimit.api.core.LimitAPI;
import com.tencent.polaris.ratelimit.api.rpc.Argument;
import com.tencent.polaris.ratelimit.api.rpc.QuotaRequest;
import com.tencent.polaris.ratelimit.api.rpc.QuotaResponse;
import org.slf4j.Logger;
@ -38,6 +40,7 @@ public final class QuotaCheckUtils {
private QuotaCheckUtils() {
}
@Deprecated
public static QuotaResponse getQuota(LimitAPI limitAPI, String namespace, String service, int count,
Map<String, String> labels, String method) {
// build quota request
@ -56,4 +59,23 @@ public final class QuotaCheckUtils {
return new QuotaResponse(new QuotaResult(QuotaResult.Code.QuotaResultOk, 0, "get quota failed"));
}
}
public static QuotaResponse getQuota(LimitAPI limitAPI, String namespace, String service, int count,
Set<Argument> arguments, String method) {
// build quota request
QuotaRequest quotaRequest = new QuotaRequest();
quotaRequest.setNamespace(namespace);
quotaRequest.setService(service);
quotaRequest.setCount(count);
quotaRequest.setArguments(arguments);
quotaRequest.setMethod(method);
try {
return limitAPI.getQuota(quotaRequest);
}
catch (Throwable throwable) {
LOG.error("fail to invoke getQuota of LimitAPI with QuotaRequest[{}].", quotaRequest, throwable);
return new QuotaResponse(new QuotaResult(QuotaResult.Code.QuotaResultOk, 0, "get quota failed"));
}
}
}

@ -18,9 +18,10 @@
package com.tencent.cloud.polaris.ratelimit.config;
import com.tencent.cloud.polaris.context.config.PolarisContextAutoConfiguration;
import com.tencent.cloud.polaris.ratelimit.RateLimitRuleLabelResolver;
import com.tencent.cloud.polaris.ratelimit.filter.QuotaCheckReactiveFilter;
import com.tencent.cloud.polaris.ratelimit.filter.QuotaCheckServletFilter;
import com.tencent.cloud.polaris.ratelimit.resolver.RateLimitRuleArgumentReactiveResolver;
import com.tencent.cloud.polaris.ratelimit.resolver.RateLimitRuleArgumentServletResolver;
import com.tencent.polaris.ratelimit.api.core.LimitAPI;
import org.junit.jupiter.api.Test;
@ -55,7 +56,8 @@ public class PolarisRateLimitAutoConfigurationTest {
PolarisRateLimitAutoConfiguration.class))
.run(context -> {
assertThat(context).hasSingleBean(LimitAPI.class);
assertThat(context).hasSingleBean(RateLimitRuleLabelResolver.class);
assertThat(context).doesNotHaveBean(RateLimitRuleArgumentServletResolver.class);
assertThat(context).doesNotHaveBean(RateLimitRuleArgumentReactiveResolver.class);
assertThat(context).doesNotHaveBean(PolarisRateLimitAutoConfiguration.QuotaCheckFilterConfig.class);
assertThat(context).doesNotHaveBean(QuotaCheckServletFilter.class);
assertThat(context).doesNotHaveBean(FilterRegistrationBean.class);
@ -72,12 +74,13 @@ public class PolarisRateLimitAutoConfigurationTest {
PolarisRateLimitAutoConfiguration.class))
.run(context -> {
assertThat(context).hasSingleBean(LimitAPI.class);
assertThat(context).hasSingleBean(RateLimitRuleLabelResolver.class);
assertThat(context).hasSingleBean(RateLimitRuleArgumentServletResolver.class);
assertThat(context).hasSingleBean(PolarisRateLimitAutoConfiguration.QuotaCheckFilterConfig.class);
assertThat(context).hasSingleBean(QuotaCheckServletFilter.class);
assertThat(context).hasSingleBean(FilterRegistrationBean.class);
assertThat(context).doesNotHaveBean(PolarisRateLimitAutoConfiguration.MetadataReactiveFilterConfig.class);
assertThat(context).doesNotHaveBean(QuotaCheckReactiveFilter.class);
assertThat(context).doesNotHaveBean(RateLimitRuleArgumentReactiveResolver.class);
});
}
@ -89,7 +92,8 @@ public class PolarisRateLimitAutoConfigurationTest {
PolarisRateLimitAutoConfiguration.class))
.run(context -> {
assertThat(context).hasSingleBean(LimitAPI.class);
assertThat(context).hasSingleBean(RateLimitRuleLabelResolver.class);
assertThat(context).doesNotHaveBean(RateLimitRuleArgumentServletResolver.class);
assertThat(context).hasSingleBean(RateLimitRuleArgumentReactiveResolver.class);
assertThat(context).doesNotHaveBean(PolarisRateLimitAutoConfiguration.QuotaCheckFilterConfig.class);
assertThat(context).doesNotHaveBean(QuotaCheckServletFilter.class);
assertThat(context).doesNotHaveBean(FilterRegistrationBean.class);

@ -17,32 +17,32 @@
package com.tencent.cloud.polaris.ratelimit.filter;
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.stream.Collectors;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.util.JsonFormat;
import com.tencent.cloud.common.metadata.MetadataContext;
import com.tencent.cloud.common.util.ApplicationContextAwareUtils;
import com.tencent.cloud.common.util.expresstion.SpringWebExpressionLabelUtils;
import com.tencent.cloud.polaris.ratelimit.RateLimitRuleLabelResolver;
import com.tencent.cloud.polaris.context.ServiceRuleManager;
import com.tencent.cloud.polaris.ratelimit.config.PolarisRateLimitProperties;
import com.tencent.cloud.polaris.ratelimit.constant.RateLimitConstant;
import com.tencent.cloud.polaris.ratelimit.resolver.RateLimitRuleArgumentReactiveResolver;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLabelReactiveResolver;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLimitedFallback;
import com.tencent.polaris.api.plugin.ratelimiter.QuotaResult;
import com.tencent.polaris.ratelimit.api.core.LimitAPI;
import com.tencent.polaris.ratelimit.api.rpc.QuotaRequest;
import com.tencent.polaris.ratelimit.api.rpc.QuotaResponse;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import com.tencent.polaris.specification.api.v1.traffic.manage.RateLimitProto;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;
@ -60,10 +60,8 @@ import org.springframework.web.server.WebFilterChain;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anySet;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.when;
/**
@ -76,36 +74,15 @@ import static org.mockito.Mockito.when;
@SpringBootTest(classes = QuotaCheckReactiveFilterTest.TestApplication.class,
properties = {"spring.cloud.polaris.namespace=Test", "spring.cloud.polaris.service=TestApp"})
public class QuotaCheckReactiveFilterTest {
private static MockedStatic<ApplicationContextAwareUtils> mockedApplicationContextAwareUtils;
private static MockedStatic<SpringWebExpressionLabelUtils> expressionLabelUtilsMockedStatic;
private final PolarisRateLimiterLabelReactiveResolver labelResolver =
exchange -> Collections.singletonMap("ReactiveResolver", "ReactiveResolver");
exchange -> Collections.singletonMap("xxx", "xxx");
private QuotaCheckReactiveFilter quotaCheckReactiveFilter;
private QuotaCheckReactiveFilter quotaCheckWithRateLimiterLimitedFallbackReactiveFilter;
private PolarisRateLimiterLimitedFallback polarisRateLimiterLimitedFallback;
@BeforeAll
static void beforeAll() {
expressionLabelUtilsMockedStatic = mockStatic(SpringWebExpressionLabelUtils.class);
when(SpringWebExpressionLabelUtils.resolve(any(ServerWebExchange.class), anySet()))
.thenReturn(Collections.singletonMap("RuleLabelResolver", "RuleLabelResolver"));
mockedApplicationContextAwareUtils = Mockito.mockStatic(ApplicationContextAwareUtils.class);
mockedApplicationContextAwareUtils.when(() -> ApplicationContextAwareUtils.getProperties(anyString()))
.thenReturn("unit-test");
}
@AfterAll
static void afterAll() {
mockedApplicationContextAwareUtils.close();
expressionLabelUtilsMockedStatic.close();
}
@BeforeEach
void setUp() {
void setUp() throws InvalidProtocolBufferException {
MetadataContext.LOCAL_NAMESPACE = "TEST";
polarisRateLimiterLimitedFallback = new JsonPolarisRateLimiterLimitedFallback();
LimitAPI limitAPI = mock(LimitAPI.class);
when(limitAPI.getQuota(any(QuotaRequest.class))).thenAnswer(invocationOnMock -> {
@ -128,14 +105,24 @@ public class QuotaCheckReactiveFilterTest {
polarisRateLimitProperties.setRejectRequestTips("RejectRequestTips提示消息");
polarisRateLimitProperties.setRejectHttpCode(419);
RateLimitRuleLabelResolver rateLimitRuleLabelResolver = mock(RateLimitRuleLabelResolver.class);
when(rateLimitRuleLabelResolver.getExpressionLabelKeys(anyString(), anyString()))
.thenReturn(Collections.emptySet());
PolarisRateLimitProperties polarisRateLimitWithHtmlRejectTipsProperties = new PolarisRateLimitProperties();
polarisRateLimitWithHtmlRejectTipsProperties.setRejectRequestTips("<h1>RejectRequestTips提示消息</h1>");
polarisRateLimitWithHtmlRejectTipsProperties.setRejectHttpCode(419);
this.quotaCheckReactiveFilter = new QuotaCheckReactiveFilter(
limitAPI, labelResolver, polarisRateLimitProperties, rateLimitRuleLabelResolver, null);
this.quotaCheckWithRateLimiterLimitedFallbackReactiveFilter = new QuotaCheckReactiveFilter(
limitAPI, labelResolver, polarisRateLimitProperties, rateLimitRuleLabelResolver, polarisRateLimiterLimitedFallback);
ServiceRuleManager serviceRuleManager = mock(ServiceRuleManager.class);
RateLimitProto.Rule.Builder ratelimitRuleBuilder = RateLimitProto.Rule.newBuilder();
InputStream inputStream = QuotaCheckServletFilterTest.class.getClassLoader().getResourceAsStream("ratelimit.json");
String json = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)).lines().collect(Collectors.joining(""));
JsonFormat.parser().ignoringUnknownFields().merge(json, ratelimitRuleBuilder);
RateLimitProto.Rule rateLimitRule = ratelimitRuleBuilder.build();
RateLimitProto.RateLimit rateLimit = RateLimitProto.RateLimit.newBuilder().addRules(rateLimitRule).build();
when(serviceRuleManager.getServiceRateLimitRule(anyString(), anyString())).thenReturn(rateLimit);
RateLimitRuleArgumentReactiveResolver rateLimitRuleArgumentReactiveResolver = new RateLimitRuleArgumentReactiveResolver(serviceRuleManager, labelResolver);
this.quotaCheckReactiveFilter = new QuotaCheckReactiveFilter(limitAPI, polarisRateLimitProperties, rateLimitRuleArgumentReactiveResolver, null);
this.polarisRateLimiterLimitedFallback = new JsonPolarisRateLimiterLimitedFallback();
this.quotaCheckWithRateLimiterLimitedFallbackReactiveFilter = new QuotaCheckReactiveFilter(limitAPI, polarisRateLimitWithHtmlRejectTipsProperties, rateLimitRuleArgumentReactiveResolver, polarisRateLimiterLimitedFallback);
}
@Test
@ -156,42 +143,6 @@ public class QuotaCheckReactiveFilterTest {
}
}
@Test
public void testGetRuleExpressionLabels() {
try {
Method getCustomResolvedLabels =
QuotaCheckReactiveFilter.class.getDeclaredMethod("getCustomResolvedLabels", ServerWebExchange.class);
getCustomResolvedLabels.setAccessible(true);
// Mock request
MockServerHttpRequest request = MockServerHttpRequest.get("http://localhost:8080/test").build();
ServerWebExchange exchange = MockServerWebExchange.from(request);
// labelResolver != null
Map<String, String> result = (Map<String, String>) getCustomResolvedLabels.invoke(quotaCheckReactiveFilter, exchange);
assertThat(result.size()).isEqualTo(1);
assertThat(result.get("ReactiveResolver")).isEqualTo("ReactiveResolver");
// throw exception
PolarisRateLimiterLabelReactiveResolver exceptionLabelResolver = exchange1 -> {
throw new RuntimeException("Mock exception.");
};
quotaCheckReactiveFilter = new QuotaCheckReactiveFilter(null, exceptionLabelResolver, null, null, null);
result = (Map<String, String>) getCustomResolvedLabels.invoke(quotaCheckReactiveFilter, exchange);
assertThat(result.size()).isEqualTo(0);
// labelResolver == null
quotaCheckReactiveFilter = new QuotaCheckReactiveFilter(null, null, null, null, null);
result = (Map<String, String>) getCustomResolvedLabels.invoke(quotaCheckReactiveFilter, exchange);
assertThat(result.size()).isEqualTo(0);
getCustomResolvedLabels.setAccessible(false);
}
catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
fail("Exception encountered.", e);
}
}
@Test
public void testFilter() {
// Create mock WebFilterChain
@ -199,19 +150,20 @@ public class QuotaCheckReactiveFilterTest {
// Mock request
MockServerHttpRequest request = MockServerHttpRequest.get("http://localhost:8080/test").build();
ServerWebExchange exchange = MockServerWebExchange.from(request);
quotaCheckReactiveFilter.init();
// Pass
MetadataContext.LOCAL_SERVICE = "TestApp1";
quotaCheckReactiveFilter.filter(exchange, webFilterChain);
ServerWebExchange testApp1Exchange = MockServerWebExchange.from(request);
quotaCheckReactiveFilter.filter(testApp1Exchange, webFilterChain);
// Unirate waiting 1000ms
MetadataContext.LOCAL_SERVICE = "TestApp2";
ServerWebExchange testApp2Exchange = MockServerWebExchange.from(request);
long startTimestamp = System.currentTimeMillis();
CountDownLatch countDownLatch = new CountDownLatch(1);
quotaCheckReactiveFilter.filter(exchange, webFilterChain).subscribe(e -> {
quotaCheckReactiveFilter.filter(testApp2Exchange, webFilterChain).subscribe(e -> {
}, t -> {
}, countDownLatch::countDown);
try {
@ -224,14 +176,16 @@ public class QuotaCheckReactiveFilterTest {
// Rate limited
MetadataContext.LOCAL_SERVICE = "TestApp3";
quotaCheckReactiveFilter.filter(exchange, webFilterChain);
ServerHttpResponse response = exchange.getResponse();
ServerWebExchange testApp3Exchange = MockServerWebExchange.from(request);
quotaCheckReactiveFilter.filter(testApp3Exchange, webFilterChain);
ServerHttpResponse response = testApp3Exchange.getResponse();
assertThat(response.getRawStatusCode()).isEqualTo(419);
assertThat(response.getStatusCode()).isEqualTo(HttpStatus.INSUFFICIENT_SPACE_ON_RESOURCE);
// Exception
MetadataContext.LOCAL_SERVICE = "TestApp4";
quotaCheckReactiveFilter.filter(exchange, webFilterChain);
ServerWebExchange testApp4Exchange = MockServerWebExchange.from(request);
quotaCheckReactiveFilter.filter(testApp4Exchange, webFilterChain);
}
@Test

@ -17,53 +17,47 @@
package com.tencent.cloud.polaris.ratelimit.filter;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Map;
import java.util.stream.Collectors;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.util.JsonFormat;
import com.tencent.cloud.common.metadata.MetadataContext;
import com.tencent.cloud.common.util.ApplicationContextAwareUtils;
import com.tencent.cloud.common.util.expresstion.SpringWebExpressionLabelUtils;
import com.tencent.cloud.polaris.ratelimit.RateLimitRuleLabelResolver;
import com.tencent.cloud.polaris.context.ServiceRuleManager;
import com.tencent.cloud.polaris.ratelimit.config.PolarisRateLimitProperties;
import com.tencent.cloud.polaris.ratelimit.resolver.RateLimitRuleArgumentServletResolver;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLabelServletResolver;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLimitedFallback;
import com.tencent.polaris.api.plugin.ratelimiter.QuotaResult;
import com.tencent.polaris.ratelimit.api.core.LimitAPI;
import com.tencent.polaris.ratelimit.api.rpc.QuotaRequest;
import com.tencent.polaris.ratelimit.api.rpc.QuotaResponse;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import com.tencent.polaris.specification.api.v1.traffic.manage.RateLimitProto;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.http.MediaType;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.test.context.junit.jupiter.SpringExtension;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anySet;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.when;
/**
@ -71,41 +65,23 @@ import static org.mockito.Mockito.when;
*
* @author Haotian Zhang, cheese8
*/
@ExtendWith(MockitoExtension.class)
@MockitoSettings(strictness = Strictness.LENIENT)
@SpringBootTest(classes = QuotaCheckServletFilterTest.TestApplication.class, properties = {
@ExtendWith(SpringExtension.class)
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT,
classes = QuotaCheckServletFilterTest.TestApplication.class,
properties = {
"spring.cloud.polaris.namespace=Test", "spring.cloud.polaris.service=TestApp"
})
public class QuotaCheckServletFilterTest {
private static MockedStatic<ApplicationContextAwareUtils> mockedApplicationContextAwareUtils;
private static MockedStatic<SpringWebExpressionLabelUtils> expressionLabelUtilsMockedStatic;
private final PolarisRateLimiterLabelServletResolver labelResolver =
exchange -> Collections.singletonMap("ServletResolver", "ServletResolver");
exchange -> Collections.singletonMap("xxx", "xxx");
private QuotaCheckServletFilter quotaCheckServletFilter;
private QuotaCheckServletFilter quotaCheckWithHtmlRejectTipsServletFilter;
private QuotaCheckServletFilter quotaCheckWithRateLimiterLimitedFallbackFilter;
private PolarisRateLimiterLimitedFallback polarisRateLimiterLimitedFallback;
@BeforeAll
static void beforeAll() {
expressionLabelUtilsMockedStatic = mockStatic(SpringWebExpressionLabelUtils.class);
when(SpringWebExpressionLabelUtils.resolve(any(ServerWebExchange.class), anySet()))
.thenReturn(Collections.singletonMap("RuleLabelResolver", "RuleLabelResolver"));
mockedApplicationContextAwareUtils = Mockito.mockStatic(ApplicationContextAwareUtils.class);
mockedApplicationContextAwareUtils.when(() -> ApplicationContextAwareUtils.getProperties(anyString()))
.thenReturn("unit-test");
}
@AfterAll
static void afterAll() {
mockedApplicationContextAwareUtils.close();
expressionLabelUtilsMockedStatic.close();
}
@BeforeEach
void setUp() {
void setUp() throws InvalidProtocolBufferException {
MetadataContext.LOCAL_NAMESPACE = "TEST";
LimitAPI limitAPI = mock(LimitAPI.class);
@ -133,15 +109,21 @@ public class QuotaCheckServletFilterTest {
polarisRateLimitWithHtmlRejectTipsProperties.setRejectRequestTips("<h1>RejectRequestTips提示消息</h1>");
polarisRateLimitWithHtmlRejectTipsProperties.setRejectHttpCode(419);
RateLimitRuleLabelResolver rateLimitRuleLabelResolver = mock(RateLimitRuleLabelResolver.class);
when(rateLimitRuleLabelResolver.getExpressionLabelKeys(anyString(), anyString())).thenReturn(Collections.emptySet());
this.quotaCheckServletFilter = new QuotaCheckServletFilter(limitAPI, labelResolver, polarisRateLimitProperties, rateLimitRuleLabelResolver, null);
this.quotaCheckWithHtmlRejectTipsServletFilter = new QuotaCheckServletFilter(
limitAPI, labelResolver, polarisRateLimitWithHtmlRejectTipsProperties, rateLimitRuleLabelResolver, null);
polarisRateLimiterLimitedFallback = new JsonPolarisRateLimiterLimitedFallback();
this.quotaCheckWithRateLimiterLimitedFallbackFilter = new QuotaCheckServletFilter(
limitAPI, labelResolver, polarisRateLimitWithHtmlRejectTipsProperties, rateLimitRuleLabelResolver, polarisRateLimiterLimitedFallback);
ServiceRuleManager serviceRuleManager = mock(ServiceRuleManager.class);
RateLimitProto.Rule.Builder ratelimitRuleBuilder = RateLimitProto.Rule.newBuilder();
InputStream inputStream = QuotaCheckServletFilterTest.class.getClassLoader().getResourceAsStream("ratelimit.json");
String json = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)).lines().collect(Collectors.joining(""));
JsonFormat.parser().ignoringUnknownFields().merge(json, ratelimitRuleBuilder);
RateLimitProto.Rule rateLimitRule = ratelimitRuleBuilder.build();
RateLimitProto.RateLimit rateLimit = RateLimitProto.RateLimit.newBuilder().addRules(rateLimitRule).build();
when(serviceRuleManager.getServiceRateLimitRule(anyString(), anyString())).thenReturn(rateLimit);
RateLimitRuleArgumentServletResolver rateLimitRuleArgumentServletResolver = new RateLimitRuleArgumentServletResolver(serviceRuleManager, labelResolver);
this.quotaCheckServletFilter = new QuotaCheckServletFilter(limitAPI, polarisRateLimitProperties, rateLimitRuleArgumentServletResolver, null);
this.quotaCheckWithHtmlRejectTipsServletFilter = new QuotaCheckServletFilter(limitAPI, polarisRateLimitWithHtmlRejectTipsProperties, rateLimitRuleArgumentServletResolver, null);
this.polarisRateLimiterLimitedFallback = new JsonPolarisRateLimiterLimitedFallback();
this.quotaCheckWithRateLimiterLimitedFallbackFilter = new QuotaCheckServletFilter(limitAPI, polarisRateLimitWithHtmlRejectTipsProperties, rateLimitRuleArgumentServletResolver, polarisRateLimiterLimitedFallback);
}
@Test
@ -167,40 +149,6 @@ public class QuotaCheckServletFilterTest {
quotaCheckWithRateLimiterLimitedFallbackFilter.init();
}
@Test
public void testGetRuleExpressionLabels() {
try {
Method getCustomResolvedLabels = QuotaCheckServletFilter.class.getDeclaredMethod("getCustomResolvedLabels", HttpServletRequest.class);
getCustomResolvedLabels.setAccessible(true);
// Mock request
MockHttpServletRequest request = new MockHttpServletRequest();
// labelResolver != null
Map<String, String> result = (Map<String, String>) getCustomResolvedLabels.invoke(quotaCheckServletFilter, request);
assertThat(result.size()).isEqualTo(1);
assertThat(result.get("ServletResolver")).isEqualTo("ServletResolver");
// throw exception
PolarisRateLimiterLabelServletResolver exceptionLabelResolver = request1 -> {
throw new RuntimeException("Mock exception.");
};
quotaCheckServletFilter = new QuotaCheckServletFilter(null, exceptionLabelResolver, null, null, null);
result = (Map<String, String>) getCustomResolvedLabels.invoke(quotaCheckServletFilter, request);
assertThat(result.size()).isEqualTo(0);
// labelResolver == null
quotaCheckServletFilter = new QuotaCheckServletFilter(null, null, null, null, null);
result = (Map<String, String>) getCustomResolvedLabels.invoke(quotaCheckServletFilter, request);
assertThat(result.size()).isEqualTo(0);
getCustomResolvedLabels.setAccessible(false);
}
catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
fail("Exception encountered.", e);
}
}
@Test
public void testDoFilterInternal() {
// Create mock FilterChain
@ -210,34 +158,39 @@ public class QuotaCheckServletFilterTest {
// Mock request
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
quotaCheckServletFilter.init();
quotaCheckWithHtmlRejectTipsServletFilter.init();
try {
// Pass
MetadataContext.LOCAL_SERVICE = "TestApp1";
quotaCheckServletFilter.doFilterInternal(request, response, filterChain);
MockHttpServletResponse testApp1Response = new MockHttpServletResponse();
quotaCheckServletFilter.doFilterInternal(request, testApp1Response, filterChain);
// Unirate waiting 1000ms
MetadataContext.LOCAL_SERVICE = "TestApp2";
MockHttpServletResponse testApp2Response = new MockHttpServletResponse();
long startTimestamp = System.currentTimeMillis();
quotaCheckServletFilter.doFilterInternal(request, response, filterChain);
quotaCheckServletFilter.doFilterInternal(request, testApp2Response, filterChain);
assertThat(System.currentTimeMillis() - startTimestamp).isGreaterThanOrEqualTo(1000L);
// Rate limited
MetadataContext.LOCAL_SERVICE = "TestApp3";
quotaCheckServletFilter.doFilterInternal(request, response, filterChain);
assertThat(response.getStatus()).isEqualTo(419);
assertThat(response.getContentAsString()).isEqualTo("RejectRequestTips提示消息");
MockHttpServletResponse testApp3Response = new MockHttpServletResponse();
quotaCheckServletFilter.doFilterInternal(request, testApp3Response, filterChain);
assertThat(testApp3Response.getStatus()).isEqualTo(419);
assertThat(testApp3Response.getContentAsString()).isEqualTo("RejectRequestTips提示消息");
quotaCheckWithHtmlRejectTipsServletFilter.doFilterInternal(request, response, filterChain);
assertThat(response.getStatus()).isEqualTo(419);
assertThat(response.getContentAsString()).isEqualTo("RejectRequestTips提示消息");
MockHttpServletResponse testApp3Response2 = new MockHttpServletResponse();
quotaCheckWithHtmlRejectTipsServletFilter.doFilterInternal(request, testApp3Response2, filterChain);
assertThat(testApp3Response2.getStatus()).isEqualTo(419);
assertThat(testApp3Response2.getContentAsString()).isEqualTo("<h1>RejectRequestTips提示消息</h1>");
// Exception
MockHttpServletResponse testApp4Response = new MockHttpServletResponse();
MetadataContext.LOCAL_SERVICE = "TestApp4";
quotaCheckServletFilter.doFilterInternal(request, response, filterChain);
quotaCheckServletFilter.doFilterInternal(request, testApp4Response, filterChain);
}
catch (ServletException | IOException e) {
fail("Exception encountered.", e);
@ -252,31 +205,34 @@ public class QuotaCheckServletFilterTest {
// Mock request
MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
quotaCheckWithRateLimiterLimitedFallbackFilter.init();
try {
// Pass
MetadataContext.LOCAL_SERVICE = "TestApp1";
quotaCheckWithRateLimiterLimitedFallbackFilter.doFilterInternal(request, response, filterChain);
MockHttpServletResponse testApp1response = new MockHttpServletResponse();
quotaCheckWithRateLimiterLimitedFallbackFilter.doFilterInternal(request, testApp1response, filterChain);
// Unirate waiting 1000ms
MetadataContext.LOCAL_SERVICE = "TestApp2";
MockHttpServletResponse testApp2response = new MockHttpServletResponse();
long startTimestamp = System.currentTimeMillis();
quotaCheckWithRateLimiterLimitedFallbackFilter.doFilterInternal(request, response, filterChain);
quotaCheckWithRateLimiterLimitedFallbackFilter.doFilterInternal(request, testApp2response, filterChain);
assertThat(System.currentTimeMillis() - startTimestamp).isGreaterThanOrEqualTo(1000L);
// Rate limited
MetadataContext.LOCAL_SERVICE = "TestApp3";
MockHttpServletResponse testApp3response = new MockHttpServletResponse();
String contentType = new MediaType(polarisRateLimiterLimitedFallback.mediaType(), polarisRateLimiterLimitedFallback.charset()).toString();
quotaCheckWithRateLimiterLimitedFallbackFilter.doFilterInternal(request, response, filterChain);
assertThat(response.getStatus()).isEqualTo(polarisRateLimiterLimitedFallback.rejectHttpCode());
assertThat(response.getContentAsString()).isEqualTo(polarisRateLimiterLimitedFallback.rejectTips());
assertThat(response.getContentType()).isEqualTo(contentType);
quotaCheckWithRateLimiterLimitedFallbackFilter.doFilterInternal(request, testApp3response, filterChain);
assertThat(testApp3response.getStatus()).isEqualTo(polarisRateLimiterLimitedFallback.rejectHttpCode());
assertThat(testApp3response.getContentAsString()).isEqualTo(polarisRateLimiterLimitedFallback.rejectTips());
assertThat(testApp3response.getContentType()).isEqualTo(contentType);
// Exception
MetadataContext.LOCAL_SERVICE = "TestApp4";
quotaCheckWithRateLimiterLimitedFallbackFilter.doFilterInternal(request, response, filterChain);
MockHttpServletResponse testApp4response = new MockHttpServletResponse();
quotaCheckWithRateLimiterLimitedFallbackFilter.doFilterInternal(request, testApp4response, filterChain);
}
catch (ServletException | IOException e) {
fail("Exception encountered.", e);

@ -0,0 +1,109 @@
/*
* Tencent is pleased to support the open source community by making Spring Cloud Tencent available.
*
* Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the BSD 3-Clause License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://opensource.org/licenses/BSD-3-Clause
*
* Unless required by applicable law or agreed to in writing, software distributed
* under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
* CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package com.tencent.cloud.polaris.ratelimit.resolver;
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.util.JsonFormat;
import com.tencent.cloud.common.metadata.MetadataContext;
import com.tencent.cloud.polaris.context.ServiceRuleManager;
import com.tencent.cloud.polaris.ratelimit.filter.QuotaCheckServletFilterTest;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLabelReactiveResolver;
import com.tencent.polaris.ratelimit.api.rpc.Argument;
import com.tencent.polaris.specification.api.v1.traffic.manage.RateLimitProto;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.test.context.junit.jupiter.SpringExtension;
import org.springframework.web.server.ServerWebExchange;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ExtendWith(SpringExtension.class)
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT,
classes = RateLimitRuleArgumentReactiveResolverTest.TestApplication.class,
properties = {
"spring.cloud.polaris.namespace=Test", "spring.cloud.polaris.service=TestApp"
})
public class RateLimitRuleArgumentReactiveResolverTest {
private final PolarisRateLimiterLabelReactiveResolver labelResolver =
exchange -> Collections.singletonMap("xxx", "xxx");
private RateLimitRuleArgumentReactiveResolver rateLimitRuleArgumentReactiveResolver;
@BeforeEach
void setUp() throws InvalidProtocolBufferException {
MetadataContext.LOCAL_NAMESPACE = "TEST";
ServiceRuleManager serviceRuleManager = mock(ServiceRuleManager.class);
RateLimitProto.Rule.Builder ratelimitRuleBuilder = RateLimitProto.Rule.newBuilder();
InputStream inputStream = QuotaCheckServletFilterTest.class.getClassLoader().getResourceAsStream("ratelimit.json");
String json = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)).lines().collect(Collectors.joining(""));
JsonFormat.parser().ignoringUnknownFields().merge(json, ratelimitRuleBuilder);
RateLimitProto.Rule rateLimitRule = ratelimitRuleBuilder.build();
RateLimitProto.RateLimit rateLimit = RateLimitProto.RateLimit.newBuilder().addRules(rateLimitRule).build();
when(serviceRuleManager.getServiceRateLimitRule(anyString(), anyString())).thenReturn(rateLimit);
this.rateLimitRuleArgumentReactiveResolver = new RateLimitRuleArgumentReactiveResolver(serviceRuleManager, labelResolver);
}
@Test
public void testGetRuleArguments() {
// Mock request
MetadataContext.LOCAL_SERVICE = "Test";
// Mock request
MockServerHttpRequest request = MockServerHttpRequest.get("http://127.0.0.1:8080/test")
.remoteAddress(new InetSocketAddress("127.0.0.1", 8080))
.header("xxx", "xxx")
.queryParam("yyy", "yyy")
.build();
ServerWebExchange exchange = MockServerWebExchange.from(request);
Set<Argument> arguments = rateLimitRuleArgumentReactiveResolver.getArguments(exchange, MetadataContext.LOCAL_NAMESPACE, MetadataContext.LOCAL_SERVICE);
Set<Argument> exceptRes = new HashSet<>();
exceptRes.add(Argument.buildMethod("GET"));
exceptRes.add(Argument.buildHeader("xxx", "xxx"));
exceptRes.add(Argument.buildQuery("yyy", "yyy"));
exceptRes.add(Argument.buildCallerIP("127.0.0.1"));
exceptRes.add(Argument.buildCustom("xxx", "xxx"));
assertThat(arguments).isEqualTo(exceptRes);
}
@SpringBootApplication
protected static class TestApplication {
}
}

@ -0,0 +1,100 @@
/*
* Tencent is pleased to support the open source community by making Spring Cloud Tencent available.
*
* Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the BSD 3-Clause License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://opensource.org/licenses/BSD-3-Clause
*
* Unless required by applicable law or agreed to in writing, software distributed
* under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
* CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package com.tencent.cloud.polaris.ratelimit.resolver;
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.util.JsonFormat;
import com.tencent.cloud.common.metadata.MetadataContext;
import com.tencent.cloud.polaris.context.ServiceRuleManager;
import com.tencent.cloud.polaris.ratelimit.filter.QuotaCheckServletFilterTest;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLabelServletResolver;
import com.tencent.polaris.ratelimit.api.rpc.Argument;
import com.tencent.polaris.specification.api.v1.traffic.manage.RateLimitProto;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.test.context.junit.jupiter.SpringExtension;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ExtendWith(SpringExtension.class)
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT,
classes = RateLimitRuleArgumentServletResolverTest.TestApplication.class,
properties = {
"spring.cloud.polaris.namespace=Test", "spring.cloud.polaris.service=TestApp"
})
public class RateLimitRuleArgumentServletResolverTest {
private final PolarisRateLimiterLabelServletResolver labelResolver =
exchange -> Collections.singletonMap("xxx", "xxx");
private RateLimitRuleArgumentServletResolver rateLimitRuleArgumentServletResolver;
@BeforeEach
void setUp() throws InvalidProtocolBufferException {
MetadataContext.LOCAL_NAMESPACE = "TEST";
ServiceRuleManager serviceRuleManager = mock(ServiceRuleManager.class);
RateLimitProto.Rule.Builder ratelimitRuleBuilder = RateLimitProto.Rule.newBuilder();
InputStream inputStream = QuotaCheckServletFilterTest.class.getClassLoader().getResourceAsStream("ratelimit.json");
String json = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)).lines().collect(Collectors.joining(""));
JsonFormat.parser().ignoringUnknownFields().merge(json, ratelimitRuleBuilder);
RateLimitProto.Rule rateLimitRule = ratelimitRuleBuilder.build();
RateLimitProto.RateLimit rateLimit = RateLimitProto.RateLimit.newBuilder().addRules(rateLimitRule).build();
when(serviceRuleManager.getServiceRateLimitRule(anyString(), anyString())).thenReturn(rateLimit);
this.rateLimitRuleArgumentServletResolver = new RateLimitRuleArgumentServletResolver(serviceRuleManager, labelResolver);
}
@Test
public void testGetRuleArguments() {
// Mock request
MetadataContext.LOCAL_SERVICE = "Test";
MockHttpServletRequest request = new MockHttpServletRequest(null, "GET", "/xxx");
request.setParameter("yyy", "yyy");
request.addHeader("xxx", "xxx");
Set<Argument> arguments = rateLimitRuleArgumentServletResolver.getArguments(request, MetadataContext.LOCAL_NAMESPACE, MetadataContext.LOCAL_SERVICE);
Set<Argument> exceptRes = new HashSet<>();
exceptRes.add(Argument.buildMethod("GET"));
exceptRes.add(Argument.buildHeader("xxx", "xxx"));
exceptRes.add(Argument.buildQuery("yyy", "yyy"));
exceptRes.add(Argument.buildCallerIP("127.0.0.1"));
exceptRes.add(Argument.buildCustom("xxx", "xxx"));
assertThat(arguments).isEqualTo(exceptRes);
}
@SpringBootApplication
protected static class TestApplication {
}
}

@ -17,6 +17,8 @@
package com.tencent.cloud.polaris.ratelimit.utils;
import java.util.HashSet;
import com.tencent.polaris.api.plugin.ratelimiter.QuotaResult;
import com.tencent.polaris.ratelimit.api.core.LimitAPI;
import com.tencent.polaris.ratelimit.api.rpc.QuotaRequest;
@ -66,28 +68,28 @@ public class QuotaCheckUtilsTest {
public void testGetQuota() {
// Pass
String serviceName = "TestApp1";
QuotaResponse quotaResponse = QuotaCheckUtils.getQuota(limitAPI, null, serviceName, 1, null, null);
QuotaResponse quotaResponse = QuotaCheckUtils.getQuota(limitAPI, null, serviceName, 1, new HashSet<>(), null);
assertThat(quotaResponse.getCode()).isEqualTo(QuotaResultCode.QuotaResultOk);
assertThat(quotaResponse.getWaitMs()).isEqualTo(0);
assertThat(quotaResponse.getInfo()).isEqualTo("QuotaResultOk");
// Unirate waiting 1000ms
serviceName = "TestApp2";
quotaResponse = QuotaCheckUtils.getQuota(limitAPI, null, serviceName, 1, null, null);
quotaResponse = QuotaCheckUtils.getQuota(limitAPI, null, serviceName, 1, new HashSet<>(), null);
assertThat(quotaResponse.getCode()).isEqualTo(QuotaResultCode.QuotaResultOk);
assertThat(quotaResponse.getWaitMs()).isEqualTo(1000);
assertThat(quotaResponse.getInfo()).isEqualTo("QuotaResultOk");
// Rate limited
serviceName = "TestApp3";
quotaResponse = QuotaCheckUtils.getQuota(limitAPI, null, serviceName, 1, null, null);
quotaResponse = QuotaCheckUtils.getQuota(limitAPI, null, serviceName, 1, new HashSet<>(), null);
assertThat(quotaResponse.getCode()).isEqualTo(QuotaResultCode.QuotaResultLimited);
assertThat(quotaResponse.getWaitMs()).isEqualTo(0);
assertThat(quotaResponse.getInfo()).isEqualTo("QuotaResultLimited");
// Exception
serviceName = "TestApp4";
quotaResponse = QuotaCheckUtils.getQuota(limitAPI, null, serviceName, 1, null, null);
quotaResponse = QuotaCheckUtils.getQuota(limitAPI, null, serviceName, 1, new HashSet<>(), null);
assertThat(quotaResponse.getCode()).isEqualTo(QuotaResultCode.QuotaResultOk);
assertThat(quotaResponse.getWaitMs()).isEqualTo(0);
assertThat(quotaResponse.getInfo()).isEqualTo("get quota failed");

@ -0,0 +1,95 @@
{
"id": "f7560cce829e4d4c8556a6be63539af5",
"service": "xxx",
"namespace": "default",
"subset": {},
"priority": 0,
"resource": "QPS",
"type": "GLOBAL",
"labels": {},
"amounts": [
{
"maxAmount": 1,
"validDuration": "1s",
"precision": null,
"startAmount": null,
"minAmount": null
}
],
"action": "REJECT",
"disable": false,
"report": null,
"ctime": "2022-12-11 21:56:59",
"mtime": "2023-03-10 15:40:33",
"revision": "6eec6f416bee40ecbf664c93add61358",
"service_token": null,
"adjuster": null,
"regex_combine": true,
"amountMode": "GLOBAL_TOTAL",
"failover": "FAILOVER_LOCAL",
"cluster": null,
"method": {
"type": "EXACT",
"value": "/xxx",
"value_type": "TEXT"
},
"arguments": [
{
"type": "CALLER_SERVICE",
"key": "default",
"value": {
"type": "EXACT",
"value": "xxx",
"value_type": "TEXT"
}
},
{
"type": "HEADER",
"key": "xxx",
"value": {
"type": "EXACT",
"value": "xxx",
"value_type": "TEXT"
}
},
{
"type": "QUERY",
"key": "yyy",
"value": {
"type": "EXACT",
"value": "yyy",
"value_type": "TEXT"
}
},
{
"type": "METHOD",
"key": "$method",
"value": {
"type": "EXACT",
"value": "GET",
"value_type": "TEXT"
}
},
{
"type": "CALLER_IP",
"key": "$caller_ip",
"value": {
"type": "EXACT",
"value": "127.0.0.1",
"value_type": "TEXT"
}
},
{
"type": "CUSTOM",
"key": "xxx",
"value": {
"type": "EXACT",
"value": "xxx",
"value_type": "TEXT"
}
}
],
"name": "xxx",
"etime": "2023-03-10 15:40:33",
"max_queue_delay": 30
}

@ -55,7 +55,7 @@ public final class MetadataConstant {
/**
* Order of filter.
*/
public static final int WEB_FILTER_ORDER = Ordered.HIGHEST_PRECEDENCE + 13;
public static final int WEB_FILTER_ORDER = Ordered.HIGHEST_PRECEDENCE + 9;
/**
* Order of MetadataFirstFeignInterceptor.
@ -83,6 +83,11 @@ public final class MetadataConstant {
*/
public static final String CUSTOM_DISPOSABLE_METADATA = "SCT-CUSTOM-DISPOSABLE-METADATA";
/**
* Custom Disposable Metadata.
*/
public static final String DEFAULT_DISPOSABLE_METADATA = "SCT-DEFAULT-DISPOSABLE-METADATA";
/**
* System Metadata.
*/
@ -93,4 +98,19 @@ public final class MetadataConstant {
*/
public static final String METADATA_CONTEXT = "SCT-METADATA-CONTEXT";
}
public static class DefaultMetadata {
/**
* Default Metadata Source Service Namespace Key.
*/
public static final String DEFAULT_METADATA_SOURCE_SERVICE_NAMESPACE = "source_service_namespace";
/**
* Default Metadata Source Service Name Key.
*/
public static final String DEFAULT_METADATA_SOURCE_SERVICE_NAME = "source_service_name";
}
}

@ -121,13 +121,13 @@ public final class MetadataContextHolder {
/**
* Save metadata map to thread local.
* @param dynamicTransitiveMetadata custom metadata collection
* @param dynamicDisposableMetadata custom disposable metadata connection
* @param dynamicCustomDisposableMetadata custom disposable metadata collection
* @param dynamicDefaultDisposableMetadata default disposable metadata collection
*/
public static void init(Map<String, String> dynamicTransitiveMetadata, Map<String, String> dynamicDisposableMetadata) {
public static void init(Map<String, String> dynamicTransitiveMetadata, Map<String, String> dynamicCustomDisposableMetadata, Map<String, String> dynamicDefaultDisposableMetadata) {
// Init ThreadLocal.
MetadataContextHolder.remove();
MetadataContext metadataContext = MetadataContextHolder.get();
// Save transitive metadata to ThreadLocal.
if (!CollectionUtils.isEmpty(dynamicTransitiveMetadata)) {
Map<String, String> staticTransitiveMetadata = metadataContext.getTransitiveMetadata();
@ -135,13 +135,20 @@ public final class MetadataContextHolder {
mergedTransitiveMetadata.putAll(staticTransitiveMetadata);
mergedTransitiveMetadata.putAll(dynamicTransitiveMetadata);
metadataContext.setTransitiveMetadata(Collections.unmodifiableMap(mergedTransitiveMetadata));
Map<String, String> mergedDisposableMetadata = new HashMap<>(dynamicDisposableMetadata);
metadataContext.setUpstreamDisposableMetadata(Collections.unmodifiableMap(mergedDisposableMetadata));
Map<String, String> staticDisposableMetadata = metadataContext.getFragmentContext(FRAGMENT_DISPOSABLE);
metadataContext.setDisposableMetadata(Collections.unmodifiableMap(staticDisposableMetadata));
}
if (!CollectionUtils.isEmpty(dynamicCustomDisposableMetadata) || !CollectionUtils.isEmpty(dynamicDefaultDisposableMetadata)) {
Map<String, String> mergedUpstreamDisposableMetadata = new HashMap<>();
if (!CollectionUtils.isEmpty(dynamicCustomDisposableMetadata)) {
mergedUpstreamDisposableMetadata.putAll(dynamicCustomDisposableMetadata);
}
if (!CollectionUtils.isEmpty(dynamicDefaultDisposableMetadata)) {
mergedUpstreamDisposableMetadata.putAll(dynamicDefaultDisposableMetadata);
}
metadataContext.setUpstreamDisposableMetadata(Collections.unmodifiableMap(mergedUpstreamDisposableMetadata));
}
Map<String, String> staticDisposableMetadata = metadataContext.getFragmentContext(FRAGMENT_DISPOSABLE);
metadataContext.setDisposableMetadata(Collections.unmodifiableMap(staticDisposableMetadata));
MetadataContextHolder.set(metadataContext);
}

@ -58,7 +58,7 @@ public class MetadataContextHolderTest {
customMetadata.put("a", "1");
customMetadata.put("b", "22");
customMetadata.put("c", "3");
MetadataContextHolder.init(customMetadata, new HashMap<>());
MetadataContextHolder.init(customMetadata, new HashMap<>(), new HashMap<>());
metadataContext = MetadataContextHolder.get();
customMetadata = metadataContext.getTransitiveMetadata();
Assertions.assertThat(customMetadata.get("a")).isEqualTo("1");

Loading…
Cancel
Save