use arguments instead of label in ratelimit and add caller info in ratelimit

pull/902/head
seanyu 3 years ago
parent 2d9afcc4fb
commit 22ece0aba5

@ -41,6 +41,7 @@ import org.springframework.web.server.WebFilterChain;
import static com.tencent.cloud.common.constant.ContextConstant.UTF_8; import static com.tencent.cloud.common.constant.ContextConstant.UTF_8;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_DISPOSABLE_METADATA; import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_DISPOSABLE_METADATA;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_METADATA; import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_METADATA;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.DEFAULT_DISPOSABLE_METADATA;
/** /**
* Filter used for storing the metadata from upstream temporarily when web application is * Filter used for storing the metadata from upstream temporarily when web application is
@ -69,10 +70,10 @@ public class DecodeTransferMetadataReactiveFilter implements WebFilter, Ordered
mergedTransitiveMetadata.putAll(internalTransitiveMetadata); mergedTransitiveMetadata.putAll(internalTransitiveMetadata);
mergedTransitiveMetadata.putAll(customTransitiveMetadata); mergedTransitiveMetadata.putAll(customTransitiveMetadata);
Map<String, String> internalDisposableMetadata = getIntervalMetadata(serverHttpRequest, CUSTOM_DISPOSABLE_METADATA); Map<String, String> internalCustomDisposableMetadata = getIntervalMetadata(serverHttpRequest, CUSTOM_DISPOSABLE_METADATA);
Map<String, String> mergedDisposableMetadata = new HashMap<>(internalDisposableMetadata); Map<String, String> internalDefaultDisposableMetadata = getIntervalMetadata(serverHttpRequest, DEFAULT_DISPOSABLE_METADATA);
MetadataContextHolder.init(mergedTransitiveMetadata, mergedDisposableMetadata); MetadataContextHolder.init(mergedTransitiveMetadata, internalCustomDisposableMetadata, internalDefaultDisposableMetadata);
// Save to ServerWebExchange. // Save to ServerWebExchange.
serverWebExchange.getAttributes().put( serverWebExchange.getAttributes().put(

@ -43,6 +43,7 @@ import org.springframework.web.filter.OncePerRequestFilter;
import static com.tencent.cloud.common.constant.ContextConstant.UTF_8; import static com.tencent.cloud.common.constant.ContextConstant.UTF_8;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_DISPOSABLE_METADATA; import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_DISPOSABLE_METADATA;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_METADATA; import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_METADATA;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.DEFAULT_DISPOSABLE_METADATA;
/** /**
* Filter used for storing the metadata from upstream temporarily when web application is * Filter used for storing the metadata from upstream temporarily when web application is
@ -66,10 +67,10 @@ public class DecodeTransferMetadataServletFilter extends OncePerRequestFilter {
mergedTransitiveMetadata.putAll(internalTransitiveMetadata); mergedTransitiveMetadata.putAll(internalTransitiveMetadata);
mergedTransitiveMetadata.putAll(customTransitiveMetadata); mergedTransitiveMetadata.putAll(customTransitiveMetadata);
Map<String, String> internalDisposableMetadata = getInternalMetadata(httpServletRequest, CUSTOM_DISPOSABLE_METADATA); Map<String, String> internalCustomDisposableMetadata = getInternalMetadata(httpServletRequest, CUSTOM_DISPOSABLE_METADATA);
Map<String, String> mergedDisposableMetadata = new HashMap<>(internalDisposableMetadata); Map<String, String> internalDefaultDisposableMetadata = getInternalMetadata(httpServletRequest, DEFAULT_DISPOSABLE_METADATA);
MetadataContextHolder.init(mergedTransitiveMetadata, mergedDisposableMetadata); MetadataContextHolder.init(mergedTransitiveMetadata, internalCustomDisposableMetadata, internalDefaultDisposableMetadata);
TransHeadersTransfer.transfer(httpServletRequest); TransHeadersTransfer.transfer(httpServletRequest);

@ -25,6 +25,7 @@ import com.tencent.cloud.common.constant.MetadataConstant;
import com.tencent.cloud.common.metadata.MetadataContext; import com.tencent.cloud.common.metadata.MetadataContext;
import com.tencent.cloud.common.metadata.MetadataContextHolder; import com.tencent.cloud.common.metadata.MetadataContextHolder;
import com.tencent.cloud.common.util.JacksonUtils; import com.tencent.cloud.common.util.JacksonUtils;
import com.tencent.cloud.metadata.util.DefaultTransferMedataUtils;
import feign.RequestInterceptor; import feign.RequestInterceptor;
import feign.RequestTemplate; import feign.RequestTemplate;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -37,6 +38,7 @@ import org.springframework.web.client.RestTemplate;
import static com.tencent.cloud.common.constant.ContextConstant.UTF_8; import static com.tencent.cloud.common.constant.ContextConstant.UTF_8;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_DISPOSABLE_METADATA; import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_DISPOSABLE_METADATA;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_METADATA; import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_METADATA;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.DEFAULT_DISPOSABLE_METADATA;
import static java.net.URLEncoder.encode; import static java.net.URLEncoder.encode;
/** /**
@ -61,7 +63,12 @@ public class EncodeTransferMedataFeignInterceptor implements RequestInterceptor,
Map<String, String> customMetadata = metadataContext.getCustomMetadata(); Map<String, String> customMetadata = metadataContext.getCustomMetadata();
Map<String, String> disposableMetadata = metadataContext.getDisposableMetadata(); Map<String, String> disposableMetadata = metadataContext.getDisposableMetadata();
Map<String, String> transHeaders = metadataContext.getTransHeadersKV(); Map<String, String> transHeaders = metadataContext.getTransHeadersKV();
Map<String, String> defaultMetadata = DefaultTransferMedataUtils.getDefaultTransferMedata();
// build default disposable metadata request header
this.buildMetadataHeader(requestTemplate, defaultMetadata, DEFAULT_DISPOSABLE_METADATA);
// build custom disposable metadata request header
this.buildMetadataHeader(requestTemplate, disposableMetadata, CUSTOM_DISPOSABLE_METADATA); this.buildMetadataHeader(requestTemplate, disposableMetadata, CUSTOM_DISPOSABLE_METADATA);
// process custom metadata // process custom metadata

@ -27,6 +27,7 @@ import com.tencent.cloud.common.constant.MetadataConstant;
import com.tencent.cloud.common.metadata.MetadataContext; import com.tencent.cloud.common.metadata.MetadataContext;
import com.tencent.cloud.common.metadata.MetadataContextHolder; import com.tencent.cloud.common.metadata.MetadataContextHolder;
import com.tencent.cloud.common.util.JacksonUtils; import com.tencent.cloud.common.util.JacksonUtils;
import com.tencent.cloud.metadata.util.DefaultTransferMedataUtils;
import org.springframework.core.Ordered; import org.springframework.core.Ordered;
import org.springframework.http.HttpRequest; import org.springframework.http.HttpRequest;
@ -39,6 +40,7 @@ import org.springframework.util.CollectionUtils;
import static com.tencent.cloud.common.constant.ContextConstant.UTF_8; import static com.tencent.cloud.common.constant.ContextConstant.UTF_8;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_DISPOSABLE_METADATA; import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_DISPOSABLE_METADATA;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_METADATA; import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_METADATA;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.DEFAULT_DISPOSABLE_METADATA;
/** /**
* Interceptor used for adding the metadata in http headers from context when web client * Interceptor used for adding the metadata in http headers from context when web client
@ -61,6 +63,10 @@ public class EncodeTransferMedataRestTemplateInterceptor implements ClientHttpRe
Map<String, String> customMetadata = metadataContext.getCustomMetadata(); Map<String, String> customMetadata = metadataContext.getCustomMetadata();
Map<String, String> disposableMetadata = metadataContext.getDisposableMetadata(); Map<String, String> disposableMetadata = metadataContext.getDisposableMetadata();
Map<String, String> transHeaders = metadataContext.getTransHeadersKV(); Map<String, String> transHeaders = metadataContext.getTransHeadersKV();
Map<String, String> defaultMetadata = DefaultTransferMedataUtils.getDefaultTransferMedata();
// build default disposable metadata request header
this.buildMetadataHeader(httpRequest, defaultMetadata, DEFAULT_DISPOSABLE_METADATA);
// build custom disposable metadata request header // build custom disposable metadata request header
this.buildMetadataHeader(httpRequest, disposableMetadata, CUSTOM_DISPOSABLE_METADATA); this.buildMetadataHeader(httpRequest, disposableMetadata, CUSTOM_DISPOSABLE_METADATA);

@ -26,6 +26,7 @@ import com.tencent.cloud.common.constant.MetadataConstant;
import com.tencent.cloud.common.metadata.MetadataContext; import com.tencent.cloud.common.metadata.MetadataContext;
import com.tencent.cloud.common.metadata.MetadataContextHolder; import com.tencent.cloud.common.metadata.MetadataContextHolder;
import com.tencent.cloud.common.util.JacksonUtils; import com.tencent.cloud.common.util.JacksonUtils;
import com.tencent.cloud.metadata.util.DefaultTransferMedataUtils;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import org.springframework.cloud.gateway.filter.GatewayFilterChain; import org.springframework.cloud.gateway.filter.GatewayFilterChain;
@ -38,6 +39,7 @@ import org.springframework.web.server.ServerWebExchange;
import static com.tencent.cloud.common.constant.ContextConstant.UTF_8; import static com.tencent.cloud.common.constant.ContextConstant.UTF_8;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_DISPOSABLE_METADATA; import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_DISPOSABLE_METADATA;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_METADATA; import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.CUSTOM_METADATA;
import static com.tencent.cloud.common.constant.MetadataConstant.HeaderName.DEFAULT_DISPOSABLE_METADATA;
import static org.springframework.cloud.gateway.filter.ReactiveLoadBalancerClientFilter.LOAD_BALANCER_CLIENT_FILTER_ORDER; import static org.springframework.cloud.gateway.filter.ReactiveLoadBalancerClientFilter.LOAD_BALANCER_CLIENT_FILTER_ORDER;
/** /**
@ -67,9 +69,11 @@ public class EncodeTransferMedataScgFilter implements GlobalFilter, Ordered {
Map<String, String> customMetadata = metadataContext.getCustomMetadata(); Map<String, String> customMetadata = metadataContext.getCustomMetadata();
Map<String, String> disposableMetadata = metadataContext.getDisposableMetadata(); Map<String, String> disposableMetadata = metadataContext.getDisposableMetadata();
Map<String, String> defaultMetadata = DefaultTransferMedataUtils.getDefaultTransferMedata();
this.buildMetadataHeader(builder, customMetadata, CUSTOM_METADATA); this.buildMetadataHeader(builder, customMetadata, CUSTOM_METADATA);
this.buildMetadataHeader(builder, disposableMetadata, CUSTOM_DISPOSABLE_METADATA); this.buildMetadataHeader(builder, disposableMetadata, CUSTOM_DISPOSABLE_METADATA);
this.buildMetadataHeader(builder, defaultMetadata, DEFAULT_DISPOSABLE_METADATA);
TransHeadersTransfer.transfer(exchange.getRequest()); TransHeadersTransfer.transfer(exchange.getRequest());

@ -0,0 +1,23 @@
package com.tencent.cloud.metadata.util;
import java.util.HashMap;
import java.util.Map;
import static com.tencent.cloud.common.constant.MetadataConstant.DefaultMetadata.DEFAULT_METADATA_SOURCE_SERVICE_NAME;
import static com.tencent.cloud.common.constant.MetadataConstant.DefaultMetadata.DEFAULT_METADATA_SOURCE_SERVICE_NAMESPACE;
import static com.tencent.cloud.common.metadata.MetadataContext.LOCAL_NAMESPACE;
import static com.tencent.cloud.common.metadata.MetadataContext.LOCAL_SERVICE;
public final class DefaultTransferMedataUtils {
private DefaultTransferMedataUtils() {
}
public static Map<String, String> getDefaultTransferMedata() {
return new HashMap<String, String>() {{
put(DEFAULT_METADATA_SOURCE_SERVICE_NAMESPACE, LOCAL_NAMESPACE);
put(DEFAULT_METADATA_SOURCE_SERVICE_NAME, LOCAL_SERVICE);
}};
}
}

@ -121,7 +121,7 @@ public class PolarisCircuitBreakerMockServerTest {
} }
}, t -> "fallback"); }, t -> "fallback");
resList.add(res); resList.add(res);
Utils.sleepUninterrupted(1000); Utils.sleepUninterrupted(2000);
} }
assertThat(resList).isEqualTo(Arrays.asList("invoke success", "fallback", "fallback", "fallback", "fallback")); assertThat(resList).isEqualTo(Arrays.asList("invoke success", "fallback", "fallback", "fallback", "fallback"));

@ -19,6 +19,11 @@
<groupId>com.tencent.cloud</groupId> <groupId>com.tencent.cloud</groupId>
<artifactId>spring-cloud-tencent-polaris-context</artifactId> <artifactId>spring-cloud-tencent-polaris-context</artifactId>
</dependency> </dependency>
<dependency>
<groupId>com.tencent.cloud</groupId>
<artifactId>spring-cloud-starter-tencent-metadata-transfer</artifactId>
</dependency>
<!-- Spring Cloud Tencent dependencies end --> <!-- Spring Cloud Tencent dependencies end -->
<!-- Polaris dependencies start --> <!-- Polaris dependencies start -->

@ -36,6 +36,7 @@ import org.springframework.util.CollectionUtils;
* *
*@author lepdou 2022-05-13 *@author lepdou 2022-05-13
*/ */
@Deprecated
public class RateLimitRuleLabelResolver { public class RateLimitRuleLabelResolver {
private final ServiceRuleManager serviceRuleManager; private final ServiceRuleManager serviceRuleManager;

@ -20,10 +20,10 @@ package com.tencent.cloud.polaris.ratelimit.config;
import com.tencent.cloud.polaris.context.ServiceRuleManager; import com.tencent.cloud.polaris.context.ServiceRuleManager;
import com.tencent.cloud.polaris.context.config.PolarisContextAutoConfiguration; import com.tencent.cloud.polaris.context.config.PolarisContextAutoConfiguration;
import com.tencent.cloud.polaris.ratelimit.RateLimitRuleLabelResolver;
import com.tencent.cloud.polaris.ratelimit.constant.RateLimitConstant;
import com.tencent.cloud.polaris.ratelimit.filter.QuotaCheckReactiveFilter; import com.tencent.cloud.polaris.ratelimit.filter.QuotaCheckReactiveFilter;
import com.tencent.cloud.polaris.ratelimit.filter.QuotaCheckServletFilter; import com.tencent.cloud.polaris.ratelimit.filter.QuotaCheckServletFilter;
import com.tencent.cloud.polaris.ratelimit.resolver.RateLimitRuleArgumentReactiveResolver;
import com.tencent.cloud.polaris.ratelimit.resolver.RateLimitRuleArgumentServletResolver;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLabelReactiveResolver; import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLabelReactiveResolver;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLabelServletResolver; import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLabelServletResolver;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLimitedFallback; import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLimitedFallback;
@ -31,6 +31,7 @@ import com.tencent.polaris.client.api.SDKContext;
import com.tencent.polaris.ratelimit.api.core.LimitAPI; import com.tencent.polaris.ratelimit.api.core.LimitAPI;
import com.tencent.polaris.ratelimit.factory.LimitAPIFactory; import com.tencent.polaris.ratelimit.factory.LimitAPIFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.AutoConfigureAfter; import org.springframework.boot.autoconfigure.AutoConfigureAfter;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnWebApplication; import org.springframework.boot.autoconfigure.condition.ConditionalOnWebApplication;
@ -39,6 +40,7 @@ import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import static com.tencent.cloud.polaris.ratelimit.constant.RateLimitConstant.FILTER_ORDER;
import static com.tencent.cloud.polaris.ratelimit.filter.QuotaCheckServletFilter.QUOTA_FILTER_BEAN_NAME; import static com.tencent.cloud.polaris.ratelimit.filter.QuotaCheckServletFilter.QUOTA_FILTER_BEAN_NAME;
import static javax.servlet.DispatcherType.ASYNC; import static javax.servlet.DispatcherType.ASYNC;
import static javax.servlet.DispatcherType.ERROR; import static javax.servlet.DispatcherType.ERROR;
@ -62,11 +64,6 @@ public class PolarisRateLimitAutoConfiguration {
return LimitAPIFactory.createLimitAPIByContext(polarisContext); return LimitAPIFactory.createLimitAPIByContext(polarisContext);
} }
@Bean
public RateLimitRuleLabelResolver rateLimitRuleLabelService(ServiceRuleManager serviceRuleManager) {
return new RateLimitRuleLabelResolver(serviceRuleManager);
}
/** /**
* Create when web application type is SERVLET. * Create when web application type is SERVLET.
*/ */
@ -74,15 +71,19 @@ public class PolarisRateLimitAutoConfiguration {
@ConditionalOnWebApplication(type = ConditionalOnWebApplication.Type.SERVLET) @ConditionalOnWebApplication(type = ConditionalOnWebApplication.Type.SERVLET)
protected static class QuotaCheckFilterConfig { protected static class QuotaCheckFilterConfig {
@Bean
public RateLimitRuleArgumentServletResolver rateLimitRuleArgumentResolver(ServiceRuleManager serviceRuleManager,
@Autowired(required = false) PolarisRateLimiterLabelServletResolver labelResolver) {
return new RateLimitRuleArgumentServletResolver(serviceRuleManager, labelResolver);
}
@Bean @Bean
@ConditionalOnMissingBean @ConditionalOnMissingBean
public QuotaCheckServletFilter quotaCheckFilter(LimitAPI limitAPI, public QuotaCheckServletFilter quotaCheckFilter(LimitAPI limitAPI,
@Nullable PolarisRateLimiterLabelServletResolver labelResolver,
PolarisRateLimitProperties polarisRateLimitProperties, PolarisRateLimitProperties polarisRateLimitProperties,
RateLimitRuleLabelResolver rateLimitRuleLabelResolver, RateLimitRuleArgumentServletResolver rateLimitRuleArgumentResolver,
@Nullable PolarisRateLimiterLimitedFallback polarisRateLimiterLimitedFallback) { @Autowired(required = false) PolarisRateLimiterLimitedFallback polarisRateLimiterLimitedFallback) {
return new QuotaCheckServletFilter(limitAPI, labelResolver, return new QuotaCheckServletFilter(limitAPI, polarisRateLimitProperties, rateLimitRuleArgumentResolver, polarisRateLimiterLimitedFallback);
polarisRateLimitProperties, rateLimitRuleLabelResolver, polarisRateLimiterLimitedFallback);
} }
@Bean @Bean
@ -92,9 +93,11 @@ public class PolarisRateLimitAutoConfiguration {
quotaCheckServletFilter); quotaCheckServletFilter);
registrationBean.setDispatcherTypes(ASYNC, ERROR, FORWARD, INCLUDE, REQUEST); registrationBean.setDispatcherTypes(ASYNC, ERROR, FORWARD, INCLUDE, REQUEST);
registrationBean.setName(QUOTA_FILTER_BEAN_NAME); registrationBean.setName(QUOTA_FILTER_BEAN_NAME);
registrationBean.setOrder(RateLimitConstant.FILTER_ORDER); registrationBean.setOrder(FILTER_ORDER);
return registrationBean; return registrationBean;
} }
} }
/** /**
@ -104,14 +107,18 @@ public class PolarisRateLimitAutoConfiguration {
@ConditionalOnWebApplication(type = ConditionalOnWebApplication.Type.REACTIVE) @ConditionalOnWebApplication(type = ConditionalOnWebApplication.Type.REACTIVE)
protected static class MetadataReactiveFilterConfig { protected static class MetadataReactiveFilterConfig {
@Bean
public RateLimitRuleArgumentReactiveResolver rateLimitRuleArgumentResolver(ServiceRuleManager serviceRuleManager,
@Autowired(required = false) PolarisRateLimiterLabelReactiveResolver labelResolver) {
return new RateLimitRuleArgumentReactiveResolver(serviceRuleManager, labelResolver);
}
@Bean @Bean
public QuotaCheckReactiveFilter quotaCheckReactiveFilter(LimitAPI limitAPI, public QuotaCheckReactiveFilter quotaCheckReactiveFilter(LimitAPI limitAPI,
@Nullable PolarisRateLimiterLabelReactiveResolver labelResolver,
PolarisRateLimitProperties polarisRateLimitProperties, PolarisRateLimitProperties polarisRateLimitProperties,
RateLimitRuleLabelResolver rateLimitRuleLabelResolver, RateLimitRuleArgumentReactiveResolver rateLimitRuleArgumentResolver,
@Nullable PolarisRateLimiterLimitedFallback polarisRateLimiterLimitedFallback) { @Nullable PolarisRateLimiterLimitedFallback polarisRateLimiterLimitedFallback) {
return new QuotaCheckReactiveFilter(limitAPI, labelResolver, return new QuotaCheckReactiveFilter(limitAPI, polarisRateLimitProperties, rateLimitRuleArgumentResolver, polarisRateLimiterLimitedFallback);
polarisRateLimitProperties, rateLimitRuleLabelResolver, polarisRateLimiterLimitedFallback);
} }
} }
} }

@ -20,27 +20,22 @@ package com.tencent.cloud.polaris.ratelimit.filter;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.time.Duration; import java.time.Duration;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.Set;
import javax.annotation.PostConstruct; import javax.annotation.PostConstruct;
import com.google.common.collect.Maps;
import com.tencent.cloud.common.metadata.MetadataContext; import com.tencent.cloud.common.metadata.MetadataContext;
import com.tencent.cloud.common.util.expresstion.SpringWebExpressionLabelUtils;
import com.tencent.cloud.polaris.ratelimit.RateLimitRuleLabelResolver;
import com.tencent.cloud.polaris.ratelimit.config.PolarisRateLimitProperties; import com.tencent.cloud.polaris.ratelimit.config.PolarisRateLimitProperties;
import com.tencent.cloud.polaris.ratelimit.constant.RateLimitConstant; import com.tencent.cloud.polaris.ratelimit.constant.RateLimitConstant;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLabelReactiveResolver; import com.tencent.cloud.polaris.ratelimit.resolver.RateLimitRuleArgumentReactiveResolver;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLimitedFallback; import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLimitedFallback;
import com.tencent.cloud.polaris.ratelimit.utils.QuotaCheckUtils; import com.tencent.cloud.polaris.ratelimit.utils.QuotaCheckUtils;
import com.tencent.cloud.polaris.ratelimit.utils.RateLimitUtils; import com.tencent.cloud.polaris.ratelimit.utils.RateLimitUtils;
import com.tencent.polaris.ratelimit.api.core.LimitAPI; import com.tencent.polaris.ratelimit.api.core.LimitAPI;
import com.tencent.polaris.ratelimit.api.rpc.Argument;
import com.tencent.polaris.ratelimit.api.rpc.QuotaResponse; import com.tencent.polaris.ratelimit.api.rpc.QuotaResponse;
import com.tencent.polaris.ratelimit.api.rpc.QuotaResultCode; import com.tencent.polaris.ratelimit.api.rpc.QuotaResultCode;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
@ -54,8 +49,6 @@ import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain; import org.springframework.web.server.WebFilterChain;
import static com.tencent.cloud.polaris.ratelimit.constant.RateLimitConstant.LABEL_METHOD;
/** /**
* Reactive filter to check quota. * Reactive filter to check quota.
* *
@ -67,11 +60,9 @@ public class QuotaCheckReactiveFilter implements WebFilter, Ordered {
private final LimitAPI limitAPI; private final LimitAPI limitAPI;
private final PolarisRateLimiterLabelReactiveResolver labelResolver;
private final PolarisRateLimitProperties polarisRateLimitProperties; private final PolarisRateLimitProperties polarisRateLimitProperties;
private final RateLimitRuleLabelResolver rateLimitRuleLabelResolver; private final RateLimitRuleArgumentReactiveResolver rateLimitRuleArgumentResolver;
private final PolarisRateLimiterLimitedFallback polarisRateLimiterLimitedFallback; private final PolarisRateLimiterLimitedFallback polarisRateLimiterLimitedFallback;
@ -79,14 +70,12 @@ public class QuotaCheckReactiveFilter implements WebFilter, Ordered {
private String rejectTips; private String rejectTips;
public QuotaCheckReactiveFilter(LimitAPI limitAPI, public QuotaCheckReactiveFilter(LimitAPI limitAPI,
PolarisRateLimiterLabelReactiveResolver labelResolver,
PolarisRateLimitProperties polarisRateLimitProperties, PolarisRateLimitProperties polarisRateLimitProperties,
RateLimitRuleLabelResolver rateLimitRuleLabelResolver, RateLimitRuleArgumentReactiveResolver rateLimitRuleArgumentResolver,
@Nullable PolarisRateLimiterLimitedFallback polarisRateLimiterLimitedFallback) { @Nullable PolarisRateLimiterLimitedFallback polarisRateLimiterLimitedFallback) {
this.limitAPI = limitAPI; this.limitAPI = limitAPI;
this.labelResolver = labelResolver;
this.polarisRateLimitProperties = polarisRateLimitProperties; this.polarisRateLimitProperties = polarisRateLimitProperties;
this.rateLimitRuleLabelResolver = rateLimitRuleLabelResolver; this.rateLimitRuleArgumentResolver = rateLimitRuleArgumentResolver;
this.polarisRateLimiterLimitedFallback = polarisRateLimiterLimitedFallback; this.polarisRateLimiterLimitedFallback = polarisRateLimiterLimitedFallback;
} }
@ -105,12 +94,12 @@ public class QuotaCheckReactiveFilter implements WebFilter, Ordered {
String localNamespace = MetadataContext.LOCAL_NAMESPACE; String localNamespace = MetadataContext.LOCAL_NAMESPACE;
String localService = MetadataContext.LOCAL_SERVICE; String localService = MetadataContext.LOCAL_SERVICE;
Map<String, String> labels = getRequestLabels(exchange, localNamespace, localService); Set<Argument> arguments = rateLimitRuleArgumentResolver.getArguments(exchange, localNamespace, localService);
long waitMs = -1; long waitMs = -1;
try { try {
String path = exchange.getRequest().getURI().getPath(); String path = exchange.getRequest().getURI().getPath();
QuotaResponse quotaResponse = QuotaCheckUtils.getQuota(limitAPI, QuotaResponse quotaResponse = QuotaCheckUtils.getQuota(
localNamespace, localService, 1, labels, path); limitAPI, localNamespace, localService, 1, arguments, path);
if (quotaResponse.getCode() == QuotaResultCode.QuotaResultLimited) { if (quotaResponse.getCode() == QuotaResultCode.QuotaResultLimited) {
ServerHttpResponse response = exchange.getResponse(); ServerHttpResponse response = exchange.getResponse();
@ -149,40 +138,4 @@ public class QuotaCheckReactiveFilter implements WebFilter, Ordered {
} }
} }
private Map<String, String> getRequestLabels(ServerWebExchange exchange, String localNamespace, String localService) {
Map<String, String> labels = new HashMap<>();
// add build in labels
String path = exchange.getRequest().getURI().getPath();
if (StringUtils.isNotBlank(path)) {
labels.put(LABEL_METHOD, path);
}
// add rule expression labels
Map<String, String> expressionLabels = getRuleExpressionLabels(exchange, localNamespace, localService);
labels.putAll(expressionLabels);
// add custom labels
Map<String, String> customResolvedLabels = getCustomResolvedLabels(exchange);
labels.putAll(customResolvedLabels);
return labels;
}
private Map<String, String> getCustomResolvedLabels(ServerWebExchange exchange) {
if (labelResolver != null) {
try {
return labelResolver.resolve(exchange);
}
catch (Throwable e) {
LOGGER.error("resolve custom label failed. resolver = {}", labelResolver.getClass().getName(), e);
}
}
return Maps.newHashMap();
}
private Map<String, String> getRuleExpressionLabels(ServerWebExchange exchange, String namespace, String service) {
Set<String> expressionLabels = rateLimitRuleLabelResolver.getExpressionLabelKeys(namespace, service);
return SpringWebExpressionLabelUtils.resolve(exchange, expressionLabels);
}
} }

@ -19,9 +19,6 @@
package com.tencent.cloud.polaris.ratelimit.filter; package com.tencent.cloud.polaris.ratelimit.filter;
import java.io.IOException; import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.Set;
@ -32,18 +29,16 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import com.tencent.cloud.common.metadata.MetadataContext; import com.tencent.cloud.common.metadata.MetadataContext;
import com.tencent.cloud.common.util.expresstion.ServletExpressionLabelUtils;
import com.tencent.cloud.polaris.ratelimit.RateLimitRuleLabelResolver;
import com.tencent.cloud.polaris.ratelimit.config.PolarisRateLimitProperties; import com.tencent.cloud.polaris.ratelimit.config.PolarisRateLimitProperties;
import com.tencent.cloud.polaris.ratelimit.constant.RateLimitConstant; import com.tencent.cloud.polaris.ratelimit.constant.RateLimitConstant;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLabelServletResolver; import com.tencent.cloud.polaris.ratelimit.resolver.RateLimitRuleArgumentServletResolver;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLimitedFallback; import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLimitedFallback;
import com.tencent.cloud.polaris.ratelimit.utils.QuotaCheckUtils; import com.tencent.cloud.polaris.ratelimit.utils.QuotaCheckUtils;
import com.tencent.cloud.polaris.ratelimit.utils.RateLimitUtils; import com.tencent.cloud.polaris.ratelimit.utils.RateLimitUtils;
import com.tencent.polaris.ratelimit.api.core.LimitAPI; import com.tencent.polaris.ratelimit.api.core.LimitAPI;
import com.tencent.polaris.ratelimit.api.rpc.Argument;
import com.tencent.polaris.ratelimit.api.rpc.QuotaResponse; import com.tencent.polaris.ratelimit.api.rpc.QuotaResponse;
import com.tencent.polaris.ratelimit.api.rpc.QuotaResultCode; import com.tencent.polaris.ratelimit.api.rpc.QuotaResultCode;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -53,8 +48,6 @@ import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.filter.OncePerRequestFilter;
import static com.tencent.cloud.polaris.ratelimit.constant.RateLimitConstant.LABEL_METHOD;
/** /**
* Servlet filter to check quota. * Servlet filter to check quota.
* *
@ -70,25 +63,21 @@ public class QuotaCheckServletFilter extends OncePerRequestFilter {
private static final Logger LOG = LoggerFactory.getLogger(QuotaCheckServletFilter.class); private static final Logger LOG = LoggerFactory.getLogger(QuotaCheckServletFilter.class);
private final LimitAPI limitAPI; private final LimitAPI limitAPI;
private final PolarisRateLimiterLabelServletResolver labelResolver;
private final PolarisRateLimitProperties polarisRateLimitProperties; private final PolarisRateLimitProperties polarisRateLimitProperties;
private final RateLimitRuleLabelResolver rateLimitRuleLabelResolver; private final RateLimitRuleArgumentServletResolver rateLimitRuleArgumentResolver;
private final PolarisRateLimiterLimitedFallback polarisRateLimiterLimitedFallback; private final PolarisRateLimiterLimitedFallback polarisRateLimiterLimitedFallback;
private String rejectTips; private String rejectTips;
public QuotaCheckServletFilter(LimitAPI limitAPI, public QuotaCheckServletFilter(LimitAPI limitAPI,
PolarisRateLimiterLabelServletResolver labelResolver,
PolarisRateLimitProperties polarisRateLimitProperties, PolarisRateLimitProperties polarisRateLimitProperties,
RateLimitRuleLabelResolver rateLimitRuleLabelResolver, RateLimitRuleArgumentServletResolver rateLimitRuleArgumentResolver,
@Nullable PolarisRateLimiterLimitedFallback polarisRateLimiterLimitedFallback) { @Nullable PolarisRateLimiterLimitedFallback polarisRateLimiterLimitedFallback) {
this.limitAPI = limitAPI; this.limitAPI = limitAPI;
this.labelResolver = labelResolver;
this.polarisRateLimitProperties = polarisRateLimitProperties; this.polarisRateLimitProperties = polarisRateLimitProperties;
this.rateLimitRuleLabelResolver = rateLimitRuleLabelResolver; this.rateLimitRuleArgumentResolver = rateLimitRuleArgumentResolver;
this.polarisRateLimiterLimitedFallback = polarisRateLimiterLimitedFallback; this.polarisRateLimiterLimitedFallback = polarisRateLimiterLimitedFallback;
} }
@ -104,11 +93,11 @@ public class QuotaCheckServletFilter extends OncePerRequestFilter {
String localNamespace = MetadataContext.LOCAL_NAMESPACE; String localNamespace = MetadataContext.LOCAL_NAMESPACE;
String localService = MetadataContext.LOCAL_SERVICE; String localService = MetadataContext.LOCAL_SERVICE;
Map<String, String> labels = getRequestLabels(request, localNamespace, localService); Set<Argument> arguments = rateLimitRuleArgumentResolver.getArguments(request, localNamespace, localService);
try { try {
QuotaResponse quotaResponse = QuotaCheckUtils.getQuota(limitAPI, QuotaResponse quotaResponse = QuotaCheckUtils.getQuota(limitAPI,
localNamespace, localService, 1, labels, request.getRequestURI()); localNamespace, localService, 1, arguments, request.getRequestURI());
if (quotaResponse.getCode() == QuotaResultCode.QuotaResultLimited) { if (quotaResponse.getCode() == QuotaResultCode.QuotaResultLimited) {
if (!Objects.isNull(polarisRateLimiterLimitedFallback)) { if (!Objects.isNull(polarisRateLimiterLimitedFallback)) {
@ -140,40 +129,4 @@ public class QuotaCheckServletFilter extends OncePerRequestFilter {
filterChain.doFilter(request, response); filterChain.doFilter(request, response);
} }
private Map<String, String> getRequestLabels(HttpServletRequest request, String localNamespace, String localService) {
Map<String, String> labels = new HashMap<>();
// add build in labels
String path = request.getRequestURI();
if (StringUtils.isNotBlank(path)) {
labels.put(LABEL_METHOD, path);
}
// add rule expression labels
Map<String, String> expressionLabels = getRuleExpressionLabels(request, localNamespace, localService);
labels.putAll(expressionLabels);
// add custom resolved labels
Map<String, String> customLabels = getCustomResolvedLabels(request);
labels.putAll(customLabels);
return labels;
}
private Map<String, String> getCustomResolvedLabels(HttpServletRequest request) {
if (labelResolver != null) {
try {
return labelResolver.resolve(request);
}
catch (Throwable e) {
LOG.error("resolve custom label failed. resolver = {}", labelResolver.getClass().getName(), e);
}
}
return Collections.emptyMap();
}
private Map<String, String> getRuleExpressionLabels(HttpServletRequest request, String namespace, String service) {
Set<String> expressionLabels = rateLimitRuleLabelResolver.getExpressionLabelKeys(namespace, service);
return ServletExpressionLabelUtils.resolve(request, expressionLabels);
}
} }

@ -0,0 +1,124 @@
/*
* Tencent is pleased to support the open source community by making Spring Cloud Tencent available.
*
* Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the BSD 3-Clause License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://opensource.org/licenses/BSD-3-Clause
*
* Unless required by applicable law or agreed to in writing, software distributed
* under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
* CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*
*/
package com.tencent.cloud.polaris.ratelimit.resolver;
import java.net.InetSocketAddress;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import com.tencent.cloud.common.metadata.MetadataContextHolder;
import com.tencent.cloud.polaris.context.ServiceRuleManager;
import com.tencent.cloud.polaris.ratelimit.filter.QuotaCheckServletFilter;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLabelReactiveResolver;
import com.tencent.polaris.ratelimit.api.rpc.Argument;
import com.tencent.polaris.specification.api.v1.traffic.manage.RateLimitProto;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;
import org.springframework.web.server.ServerWebExchange;
import static com.tencent.cloud.common.constant.MetadataConstant.DefaultMetadata.DEFAULT_METADATA_SOURCE_SERVICE_NAME;
import static com.tencent.cloud.common.constant.MetadataConstant.DefaultMetadata.DEFAULT_METADATA_SOURCE_SERVICE_NAMESPACE;
/**
* resolve arguments from rate limit rule for Reactive.
*
* @author seansyyu 2023-03-09
*/
public class RateLimitRuleArgumentReactiveResolver {
private static final Logger LOG = LoggerFactory.getLogger(QuotaCheckServletFilter.class);
private final ServiceRuleManager serviceRuleManager;
private final PolarisRateLimiterLabelReactiveResolver labelResolver;
public RateLimitRuleArgumentReactiveResolver(ServiceRuleManager serviceRuleManager, PolarisRateLimiterLabelReactiveResolver labelResolver) {
this.serviceRuleManager = serviceRuleManager;
this.labelResolver = labelResolver;
}
public Set<Argument> getArguments(ServerWebExchange request, String namespace, String service) {
RateLimitProto.RateLimit rateLimitRule = serviceRuleManager.getServiceRateLimitRule(namespace, service);
if (rateLimitRule == null) {
return Collections.emptySet();
}
List<RateLimitProto.Rule> rules = rateLimitRule.getRulesList();
if (CollectionUtils.isEmpty(rules)) {
return Collections.emptySet();
}
return rules.stream()
.flatMap(rule -> rule.getArgumentsList().stream())
.map(matchArgument -> {
String matchKey = matchArgument.getKey();
Argument argument = null;
switch (matchArgument.getType()) {
case CUSTOM:
argument = StringUtils.isBlank(matchKey) ? null :
Argument.buildCustom(matchKey, Optional.ofNullable(getCustomResolvedLabels(request).get(matchKey)).orElse(StringUtils.EMPTY));
break;
case METHOD:
argument = Argument.buildMethod(request.getRequest().getMethodValue());
break;
case HEADER:
argument = StringUtils.isBlank(matchKey) ? null :
Argument.buildHeader(matchKey, Optional.ofNullable(request.getRequest().getHeaders().getFirst(matchKey)).orElse(StringUtils.EMPTY));
break;
case QUERY:
argument = StringUtils.isBlank(matchKey) ? null :
Argument.buildQuery(matchKey, Optional.ofNullable(request.getRequest().getQueryParams().getFirst(matchKey)).orElse(StringUtils.EMPTY));
break;
case CALLER_SERVICE:
String sourceServiceNamespace = MetadataContextHolder.getDisposableMetadata(DEFAULT_METADATA_SOURCE_SERVICE_NAMESPACE, true).orElse(StringUtils.EMPTY);
String sourceServiceName = MetadataContextHolder.getDisposableMetadata(DEFAULT_METADATA_SOURCE_SERVICE_NAME, true).orElse(StringUtils.EMPTY);
if (!StringUtils.isEmpty(sourceServiceNamespace) && !StringUtils.isEmpty(sourceServiceName)) {
argument = Argument.buildCallerService(sourceServiceNamespace, sourceServiceName);
}
break;
case CALLER_IP:
InetSocketAddress remoteAddress = request.getRequest().getRemoteAddress();
argument = Argument.buildCallerIP(remoteAddress != null ? remoteAddress.getAddress().getHostAddress() : StringUtils.EMPTY);
break;
default:
break;
}
return argument;
}).filter(Objects::nonNull).collect(Collectors.toSet());
}
private Map<String, String> getCustomResolvedLabels(ServerWebExchange request) {
if (labelResolver != null) {
try {
return labelResolver.resolve(request);
}
catch (Throwable e) {
LOG.error("resolve custom label failed. resolver = {}", labelResolver.getClass().getName(), e);
}
}
return Collections.emptyMap();
}
}

@ -0,0 +1,123 @@
/*
* Tencent is pleased to support the open source community by making Spring Cloud Tencent available.
*
* Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the BSD 3-Clause License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://opensource.org/licenses/BSD-3-Clause
*
* Unless required by applicable law or agreed to in writing, software distributed
* under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
* CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*
*/
package com.tencent.cloud.polaris.ratelimit.resolver;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import javax.servlet.http.HttpServletRequest;
import com.tencent.cloud.common.metadata.MetadataContextHolder;
import com.tencent.cloud.polaris.context.ServiceRuleManager;
import com.tencent.cloud.polaris.ratelimit.filter.QuotaCheckServletFilter;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLabelServletResolver;
import com.tencent.polaris.ratelimit.api.rpc.Argument;
import com.tencent.polaris.specification.api.v1.traffic.manage.RateLimitProto;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;
import static com.tencent.cloud.common.constant.MetadataConstant.DefaultMetadata.DEFAULT_METADATA_SOURCE_SERVICE_NAME;
import static com.tencent.cloud.common.constant.MetadataConstant.DefaultMetadata.DEFAULT_METADATA_SOURCE_SERVICE_NAMESPACE;
/**
* resolve arguments from rate limit rule for Servlet.
*
* @author seansyyu 2023-03-09
*/
public class RateLimitRuleArgumentServletResolver {
private static final Logger LOG = LoggerFactory.getLogger(QuotaCheckServletFilter.class);
private final ServiceRuleManager serviceRuleManager;
private final PolarisRateLimiterLabelServletResolver labelResolver;
public RateLimitRuleArgumentServletResolver(ServiceRuleManager serviceRuleManager, PolarisRateLimiterLabelServletResolver labelResolver) {
this.serviceRuleManager = serviceRuleManager;
this.labelResolver = labelResolver;
}
public Set<Argument> getArguments(HttpServletRequest request, String namespace, String service) {
RateLimitProto.RateLimit rateLimitRule = serviceRuleManager.getServiceRateLimitRule(namespace, service);
if (rateLimitRule == null) {
return Collections.emptySet();
}
List<RateLimitProto.Rule> rules = rateLimitRule.getRulesList();
if (CollectionUtils.isEmpty(rules)) {
return Collections.emptySet();
}
return rules.stream()
.flatMap(rule -> rule.getArgumentsList().stream())
.map(matchArgument -> {
String matchKey = matchArgument.getKey();
Argument argument = null;
switch (matchArgument.getType()) {
case CUSTOM:
argument = StringUtils.isBlank(matchKey) ? null :
Argument.buildCustom(matchKey, Optional.ofNullable(getCustomResolvedLabels(request).get(matchKey)).orElse(StringUtils.EMPTY));
break;
case METHOD:
argument = Argument.buildMethod(request.getMethod());
break;
case HEADER:
argument = StringUtils.isBlank(matchKey) ? null :
Argument.buildHeader(matchKey, Optional.ofNullable(request.getHeader(matchKey)).orElse(StringUtils.EMPTY));
break;
case QUERY:
argument = StringUtils.isBlank(matchKey) ? null :
Argument.buildQuery(matchKey, Optional.ofNullable(request.getParameter(matchKey)).orElse(StringUtils.EMPTY));
break;
case CALLER_SERVICE:
String sourceServiceNamespace = MetadataContextHolder.getDisposableMetadata(DEFAULT_METADATA_SOURCE_SERVICE_NAMESPACE, true).orElse(StringUtils.EMPTY);
String sourceServiceName = MetadataContextHolder.getDisposableMetadata(DEFAULT_METADATA_SOURCE_SERVICE_NAME, true).orElse(StringUtils.EMPTY);
if (!StringUtils.isEmpty(sourceServiceNamespace) && !StringUtils.isEmpty(sourceServiceName)) {
argument = Argument.buildCallerService(sourceServiceNamespace, sourceServiceName);
}
break;
case CALLER_IP:
argument = Argument.buildCallerIP(Optional.ofNullable(request.getRemoteAddr()).orElse(StringUtils.EMPTY));
break;
default:
break;
}
return argument;
}).filter(Objects::nonNull).collect(Collectors.toSet());
}
private Map<String, String> getCustomResolvedLabels(HttpServletRequest request) {
if (labelResolver != null) {
try {
return labelResolver.resolve(request);
}
catch (Throwable e) {
LOG.error("resolve custom label failed. resolver = {}", labelResolver.getClass().getName(), e);
}
}
return Collections.emptyMap();
}
}

@ -18,9 +18,11 @@
package com.tencent.cloud.polaris.ratelimit.utils; package com.tencent.cloud.polaris.ratelimit.utils;
import java.util.Map; import java.util.Map;
import java.util.Set;
import com.tencent.polaris.api.plugin.ratelimiter.QuotaResult; import com.tencent.polaris.api.plugin.ratelimiter.QuotaResult;
import com.tencent.polaris.ratelimit.api.core.LimitAPI; import com.tencent.polaris.ratelimit.api.core.LimitAPI;
import com.tencent.polaris.ratelimit.api.rpc.Argument;
import com.tencent.polaris.ratelimit.api.rpc.QuotaRequest; import com.tencent.polaris.ratelimit.api.rpc.QuotaRequest;
import com.tencent.polaris.ratelimit.api.rpc.QuotaResponse; import com.tencent.polaris.ratelimit.api.rpc.QuotaResponse;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -38,6 +40,7 @@ public final class QuotaCheckUtils {
private QuotaCheckUtils() { private QuotaCheckUtils() {
} }
@Deprecated
public static QuotaResponse getQuota(LimitAPI limitAPI, String namespace, String service, int count, public static QuotaResponse getQuota(LimitAPI limitAPI, String namespace, String service, int count,
Map<String, String> labels, String method) { Map<String, String> labels, String method) {
// build quota request // build quota request
@ -56,4 +59,23 @@ public final class QuotaCheckUtils {
return new QuotaResponse(new QuotaResult(QuotaResult.Code.QuotaResultOk, 0, "get quota failed")); return new QuotaResponse(new QuotaResult(QuotaResult.Code.QuotaResultOk, 0, "get quota failed"));
} }
} }
public static QuotaResponse getQuota(LimitAPI limitAPI, String namespace, String service, int count,
Set<Argument> arguments, String method) {
// build quota request
QuotaRequest quotaRequest = new QuotaRequest();
quotaRequest.setNamespace(namespace);
quotaRequest.setService(service);
quotaRequest.setCount(count);
quotaRequest.setArguments(arguments);
quotaRequest.setMethod(method);
try {
return limitAPI.getQuota(quotaRequest);
}
catch (Throwable throwable) {
LOG.error("fail to invoke getQuota of LimitAPI with QuotaRequest[{}].", quotaRequest, throwable);
return new QuotaResponse(new QuotaResult(QuotaResult.Code.QuotaResultOk, 0, "get quota failed"));
}
}
} }

@ -18,9 +18,10 @@
package com.tencent.cloud.polaris.ratelimit.config; package com.tencent.cloud.polaris.ratelimit.config;
import com.tencent.cloud.polaris.context.config.PolarisContextAutoConfiguration; import com.tencent.cloud.polaris.context.config.PolarisContextAutoConfiguration;
import com.tencent.cloud.polaris.ratelimit.RateLimitRuleLabelResolver;
import com.tencent.cloud.polaris.ratelimit.filter.QuotaCheckReactiveFilter; import com.tencent.cloud.polaris.ratelimit.filter.QuotaCheckReactiveFilter;
import com.tencent.cloud.polaris.ratelimit.filter.QuotaCheckServletFilter; import com.tencent.cloud.polaris.ratelimit.filter.QuotaCheckServletFilter;
import com.tencent.cloud.polaris.ratelimit.resolver.RateLimitRuleArgumentReactiveResolver;
import com.tencent.cloud.polaris.ratelimit.resolver.RateLimitRuleArgumentServletResolver;
import com.tencent.polaris.ratelimit.api.core.LimitAPI; import com.tencent.polaris.ratelimit.api.core.LimitAPI;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -55,7 +56,8 @@ public class PolarisRateLimitAutoConfigurationTest {
PolarisRateLimitAutoConfiguration.class)) PolarisRateLimitAutoConfiguration.class))
.run(context -> { .run(context -> {
assertThat(context).hasSingleBean(LimitAPI.class); assertThat(context).hasSingleBean(LimitAPI.class);
assertThat(context).hasSingleBean(RateLimitRuleLabelResolver.class); assertThat(context).doesNotHaveBean(RateLimitRuleArgumentServletResolver.class);
assertThat(context).doesNotHaveBean(RateLimitRuleArgumentReactiveResolver.class);
assertThat(context).doesNotHaveBean(PolarisRateLimitAutoConfiguration.QuotaCheckFilterConfig.class); assertThat(context).doesNotHaveBean(PolarisRateLimitAutoConfiguration.QuotaCheckFilterConfig.class);
assertThat(context).doesNotHaveBean(QuotaCheckServletFilter.class); assertThat(context).doesNotHaveBean(QuotaCheckServletFilter.class);
assertThat(context).doesNotHaveBean(FilterRegistrationBean.class); assertThat(context).doesNotHaveBean(FilterRegistrationBean.class);
@ -72,12 +74,13 @@ public class PolarisRateLimitAutoConfigurationTest {
PolarisRateLimitAutoConfiguration.class)) PolarisRateLimitAutoConfiguration.class))
.run(context -> { .run(context -> {
assertThat(context).hasSingleBean(LimitAPI.class); assertThat(context).hasSingleBean(LimitAPI.class);
assertThat(context).hasSingleBean(RateLimitRuleLabelResolver.class); assertThat(context).hasSingleBean(RateLimitRuleArgumentServletResolver.class);
assertThat(context).hasSingleBean(PolarisRateLimitAutoConfiguration.QuotaCheckFilterConfig.class); assertThat(context).hasSingleBean(PolarisRateLimitAutoConfiguration.QuotaCheckFilterConfig.class);
assertThat(context).hasSingleBean(QuotaCheckServletFilter.class); assertThat(context).hasSingleBean(QuotaCheckServletFilter.class);
assertThat(context).hasSingleBean(FilterRegistrationBean.class); assertThat(context).hasSingleBean(FilterRegistrationBean.class);
assertThat(context).doesNotHaveBean(PolarisRateLimitAutoConfiguration.MetadataReactiveFilterConfig.class); assertThat(context).doesNotHaveBean(PolarisRateLimitAutoConfiguration.MetadataReactiveFilterConfig.class);
assertThat(context).doesNotHaveBean(QuotaCheckReactiveFilter.class); assertThat(context).doesNotHaveBean(QuotaCheckReactiveFilter.class);
assertThat(context).doesNotHaveBean(RateLimitRuleArgumentReactiveResolver.class);
}); });
} }
@ -89,7 +92,8 @@ public class PolarisRateLimitAutoConfigurationTest {
PolarisRateLimitAutoConfiguration.class)) PolarisRateLimitAutoConfiguration.class))
.run(context -> { .run(context -> {
assertThat(context).hasSingleBean(LimitAPI.class); assertThat(context).hasSingleBean(LimitAPI.class);
assertThat(context).hasSingleBean(RateLimitRuleLabelResolver.class); assertThat(context).doesNotHaveBean(RateLimitRuleArgumentServletResolver.class);
assertThat(context).hasSingleBean(RateLimitRuleArgumentReactiveResolver.class);
assertThat(context).doesNotHaveBean(PolarisRateLimitAutoConfiguration.QuotaCheckFilterConfig.class); assertThat(context).doesNotHaveBean(PolarisRateLimitAutoConfiguration.QuotaCheckFilterConfig.class);
assertThat(context).doesNotHaveBean(QuotaCheckServletFilter.class); assertThat(context).doesNotHaveBean(QuotaCheckServletFilter.class);
assertThat(context).doesNotHaveBean(FilterRegistrationBean.class); assertThat(context).doesNotHaveBean(FilterRegistrationBean.class);

@ -17,32 +17,32 @@
package com.tencent.cloud.polaris.ratelimit.filter; package com.tencent.cloud.polaris.ratelimit.filter;
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException; import java.nio.charset.StandardCharsets;
import java.lang.reflect.Method;
import java.util.Collections; import java.util.Collections;
import java.util.Map;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.stream.Collectors;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.util.JsonFormat;
import com.tencent.cloud.common.metadata.MetadataContext; import com.tencent.cloud.common.metadata.MetadataContext;
import com.tencent.cloud.common.util.ApplicationContextAwareUtils; import com.tencent.cloud.polaris.context.ServiceRuleManager;
import com.tencent.cloud.common.util.expresstion.SpringWebExpressionLabelUtils;
import com.tencent.cloud.polaris.ratelimit.RateLimitRuleLabelResolver;
import com.tencent.cloud.polaris.ratelimit.config.PolarisRateLimitProperties; import com.tencent.cloud.polaris.ratelimit.config.PolarisRateLimitProperties;
import com.tencent.cloud.polaris.ratelimit.constant.RateLimitConstant; import com.tencent.cloud.polaris.ratelimit.constant.RateLimitConstant;
import com.tencent.cloud.polaris.ratelimit.resolver.RateLimitRuleArgumentReactiveResolver;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLabelReactiveResolver; import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLabelReactiveResolver;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLimitedFallback; import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLimitedFallback;
import com.tencent.polaris.api.plugin.ratelimiter.QuotaResult; import com.tencent.polaris.api.plugin.ratelimiter.QuotaResult;
import com.tencent.polaris.ratelimit.api.core.LimitAPI; import com.tencent.polaris.ratelimit.api.core.LimitAPI;
import com.tencent.polaris.ratelimit.api.rpc.QuotaRequest; import com.tencent.polaris.ratelimit.api.rpc.QuotaRequest;
import com.tencent.polaris.ratelimit.api.rpc.QuotaResponse; import com.tencent.polaris.ratelimit.api.rpc.QuotaResponse;
import org.junit.jupiter.api.AfterAll; import com.tencent.polaris.specification.api.v1.traffic.manage.RateLimitProto;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings; import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness; import org.mockito.quality.Strictness;
@ -60,10 +60,8 @@ import org.springframework.web.server.WebFilterChain;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.fail; import static org.assertj.core.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anySet;
import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
/** /**
@ -76,36 +74,15 @@ import static org.mockito.Mockito.when;
@SpringBootTest(classes = QuotaCheckReactiveFilterTest.TestApplication.class, @SpringBootTest(classes = QuotaCheckReactiveFilterTest.TestApplication.class,
properties = {"spring.cloud.polaris.namespace=Test", "spring.cloud.polaris.service=TestApp"}) properties = {"spring.cloud.polaris.namespace=Test", "spring.cloud.polaris.service=TestApp"})
public class QuotaCheckReactiveFilterTest { public class QuotaCheckReactiveFilterTest {
private static MockedStatic<ApplicationContextAwareUtils> mockedApplicationContextAwareUtils;
private static MockedStatic<SpringWebExpressionLabelUtils> expressionLabelUtilsMockedStatic;
private final PolarisRateLimiterLabelReactiveResolver labelResolver = private final PolarisRateLimiterLabelReactiveResolver labelResolver =
exchange -> Collections.singletonMap("ReactiveResolver", "ReactiveResolver"); exchange -> Collections.singletonMap("xxx", "xxx");
private QuotaCheckReactiveFilter quotaCheckReactiveFilter; private QuotaCheckReactiveFilter quotaCheckReactiveFilter;
private QuotaCheckReactiveFilter quotaCheckWithRateLimiterLimitedFallbackReactiveFilter; private QuotaCheckReactiveFilter quotaCheckWithRateLimiterLimitedFallbackReactiveFilter;
private PolarisRateLimiterLimitedFallback polarisRateLimiterLimitedFallback; private PolarisRateLimiterLimitedFallback polarisRateLimiterLimitedFallback;
@BeforeAll
static void beforeAll() {
expressionLabelUtilsMockedStatic = mockStatic(SpringWebExpressionLabelUtils.class);
when(SpringWebExpressionLabelUtils.resolve(any(ServerWebExchange.class), anySet()))
.thenReturn(Collections.singletonMap("RuleLabelResolver", "RuleLabelResolver"));
mockedApplicationContextAwareUtils = Mockito.mockStatic(ApplicationContextAwareUtils.class);
mockedApplicationContextAwareUtils.when(() -> ApplicationContextAwareUtils.getProperties(anyString()))
.thenReturn("unit-test");
}
@AfterAll
static void afterAll() {
mockedApplicationContextAwareUtils.close();
expressionLabelUtilsMockedStatic.close();
}
@BeforeEach @BeforeEach
void setUp() { void setUp() throws InvalidProtocolBufferException {
MetadataContext.LOCAL_NAMESPACE = "TEST"; MetadataContext.LOCAL_NAMESPACE = "TEST";
polarisRateLimiterLimitedFallback = new JsonPolarisRateLimiterLimitedFallback();
LimitAPI limitAPI = mock(LimitAPI.class); LimitAPI limitAPI = mock(LimitAPI.class);
when(limitAPI.getQuota(any(QuotaRequest.class))).thenAnswer(invocationOnMock -> { when(limitAPI.getQuota(any(QuotaRequest.class))).thenAnswer(invocationOnMock -> {
@ -128,14 +105,24 @@ public class QuotaCheckReactiveFilterTest {
polarisRateLimitProperties.setRejectRequestTips("RejectRequestTips提示消息"); polarisRateLimitProperties.setRejectRequestTips("RejectRequestTips提示消息");
polarisRateLimitProperties.setRejectHttpCode(419); polarisRateLimitProperties.setRejectHttpCode(419);
RateLimitRuleLabelResolver rateLimitRuleLabelResolver = mock(RateLimitRuleLabelResolver.class); PolarisRateLimitProperties polarisRateLimitWithHtmlRejectTipsProperties = new PolarisRateLimitProperties();
when(rateLimitRuleLabelResolver.getExpressionLabelKeys(anyString(), anyString())) polarisRateLimitWithHtmlRejectTipsProperties.setRejectRequestTips("<h1>RejectRequestTips提示消息</h1>");
.thenReturn(Collections.emptySet()); polarisRateLimitWithHtmlRejectTipsProperties.setRejectHttpCode(419);
ServiceRuleManager serviceRuleManager = mock(ServiceRuleManager.class);
this.quotaCheckReactiveFilter = new QuotaCheckReactiveFilter( RateLimitProto.Rule.Builder ratelimitRuleBuilder = RateLimitProto.Rule.newBuilder();
limitAPI, labelResolver, polarisRateLimitProperties, rateLimitRuleLabelResolver, null); InputStream inputStream = QuotaCheckServletFilterTest.class.getClassLoader().getResourceAsStream("ratelimit.json");
this.quotaCheckWithRateLimiterLimitedFallbackReactiveFilter = new QuotaCheckReactiveFilter( String json = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)).lines().collect(Collectors.joining(""));
limitAPI, labelResolver, polarisRateLimitProperties, rateLimitRuleLabelResolver, polarisRateLimiterLimitedFallback); JsonFormat.parser().ignoringUnknownFields().merge(json, ratelimitRuleBuilder);
RateLimitProto.Rule rateLimitRule = ratelimitRuleBuilder.build();
RateLimitProto.RateLimit rateLimit = RateLimitProto.RateLimit.newBuilder().addRules(rateLimitRule).build();
when(serviceRuleManager.getServiceRateLimitRule(anyString(), anyString())).thenReturn(rateLimit);
RateLimitRuleArgumentReactiveResolver rateLimitRuleArgumentReactiveResolver = new RateLimitRuleArgumentReactiveResolver(serviceRuleManager, labelResolver);
this.quotaCheckReactiveFilter = new QuotaCheckReactiveFilter(limitAPI, polarisRateLimitProperties, rateLimitRuleArgumentReactiveResolver, null);
this.polarisRateLimiterLimitedFallback = new JsonPolarisRateLimiterLimitedFallback();
this.quotaCheckWithRateLimiterLimitedFallbackReactiveFilter = new QuotaCheckReactiveFilter(limitAPI, polarisRateLimitWithHtmlRejectTipsProperties, rateLimitRuleArgumentReactiveResolver, polarisRateLimiterLimitedFallback);
} }
@Test @Test
@ -156,42 +143,6 @@ public class QuotaCheckReactiveFilterTest {
} }
} }
@Test
public void testGetRuleExpressionLabels() {
try {
Method getCustomResolvedLabels =
QuotaCheckReactiveFilter.class.getDeclaredMethod("getCustomResolvedLabels", ServerWebExchange.class);
getCustomResolvedLabels.setAccessible(true);
// Mock request
MockServerHttpRequest request = MockServerHttpRequest.get("http://localhost:8080/test").build();
ServerWebExchange exchange = MockServerWebExchange.from(request);
// labelResolver != null
Map<String, String> result = (Map<String, String>) getCustomResolvedLabels.invoke(quotaCheckReactiveFilter, exchange);
assertThat(result.size()).isEqualTo(1);
assertThat(result.get("ReactiveResolver")).isEqualTo("ReactiveResolver");
// throw exception
PolarisRateLimiterLabelReactiveResolver exceptionLabelResolver = exchange1 -> {
throw new RuntimeException("Mock exception.");
};
quotaCheckReactiveFilter = new QuotaCheckReactiveFilter(null, exceptionLabelResolver, null, null, null);
result = (Map<String, String>) getCustomResolvedLabels.invoke(quotaCheckReactiveFilter, exchange);
assertThat(result.size()).isEqualTo(0);
// labelResolver == null
quotaCheckReactiveFilter = new QuotaCheckReactiveFilter(null, null, null, null, null);
result = (Map<String, String>) getCustomResolvedLabels.invoke(quotaCheckReactiveFilter, exchange);
assertThat(result.size()).isEqualTo(0);
getCustomResolvedLabels.setAccessible(false);
}
catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
fail("Exception encountered.", e);
}
}
@Test @Test
public void testFilter() { public void testFilter() {
// Create mock WebFilterChain // Create mock WebFilterChain
@ -199,19 +150,20 @@ public class QuotaCheckReactiveFilterTest {
// Mock request // Mock request
MockServerHttpRequest request = MockServerHttpRequest.get("http://localhost:8080/test").build(); MockServerHttpRequest request = MockServerHttpRequest.get("http://localhost:8080/test").build();
ServerWebExchange exchange = MockServerWebExchange.from(request);
quotaCheckReactiveFilter.init(); quotaCheckReactiveFilter.init();
// Pass // Pass
MetadataContext.LOCAL_SERVICE = "TestApp1"; MetadataContext.LOCAL_SERVICE = "TestApp1";
quotaCheckReactiveFilter.filter(exchange, webFilterChain); ServerWebExchange testApp1Exchange = MockServerWebExchange.from(request);
quotaCheckReactiveFilter.filter(testApp1Exchange, webFilterChain);
// Unirate waiting 1000ms // Unirate waiting 1000ms
MetadataContext.LOCAL_SERVICE = "TestApp2"; MetadataContext.LOCAL_SERVICE = "TestApp2";
ServerWebExchange testApp2Exchange = MockServerWebExchange.from(request);
long startTimestamp = System.currentTimeMillis(); long startTimestamp = System.currentTimeMillis();
CountDownLatch countDownLatch = new CountDownLatch(1); CountDownLatch countDownLatch = new CountDownLatch(1);
quotaCheckReactiveFilter.filter(exchange, webFilterChain).subscribe(e -> { quotaCheckReactiveFilter.filter(testApp2Exchange, webFilterChain).subscribe(e -> {
}, t -> { }, t -> {
}, countDownLatch::countDown); }, countDownLatch::countDown);
try { try {
@ -224,14 +176,16 @@ public class QuotaCheckReactiveFilterTest {
// Rate limited // Rate limited
MetadataContext.LOCAL_SERVICE = "TestApp3"; MetadataContext.LOCAL_SERVICE = "TestApp3";
quotaCheckReactiveFilter.filter(exchange, webFilterChain); ServerWebExchange testApp3Exchange = MockServerWebExchange.from(request);
ServerHttpResponse response = exchange.getResponse(); quotaCheckReactiveFilter.filter(testApp3Exchange, webFilterChain);
ServerHttpResponse response = testApp3Exchange.getResponse();
assertThat(response.getRawStatusCode()).isEqualTo(419); assertThat(response.getRawStatusCode()).isEqualTo(419);
assertThat(response.getStatusCode()).isEqualTo(HttpStatus.INSUFFICIENT_SPACE_ON_RESOURCE); assertThat(response.getStatusCode()).isEqualTo(HttpStatus.INSUFFICIENT_SPACE_ON_RESOURCE);
// Exception // Exception
MetadataContext.LOCAL_SERVICE = "TestApp4"; MetadataContext.LOCAL_SERVICE = "TestApp4";
quotaCheckReactiveFilter.filter(exchange, webFilterChain); ServerWebExchange testApp4Exchange = MockServerWebExchange.from(request);
quotaCheckReactiveFilter.filter(testApp4Exchange, webFilterChain);
} }
@Test @Test

@ -17,53 +17,47 @@
package com.tencent.cloud.polaris.ratelimit.filter; package com.tencent.cloud.polaris.ratelimit.filter;
import java.io.BufferedReader;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException; import java.nio.charset.StandardCharsets;
import java.lang.reflect.Method;
import java.util.Collections; import java.util.Collections;
import java.util.Map; import java.util.stream.Collectors;
import javax.servlet.FilterChain; import javax.servlet.FilterChain;
import javax.servlet.ServletException; import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.util.JsonFormat;
import com.tencent.cloud.common.metadata.MetadataContext; import com.tencent.cloud.common.metadata.MetadataContext;
import com.tencent.cloud.common.util.ApplicationContextAwareUtils; import com.tencent.cloud.polaris.context.ServiceRuleManager;
import com.tencent.cloud.common.util.expresstion.SpringWebExpressionLabelUtils;
import com.tencent.cloud.polaris.ratelimit.RateLimitRuleLabelResolver;
import com.tencent.cloud.polaris.ratelimit.config.PolarisRateLimitProperties; import com.tencent.cloud.polaris.ratelimit.config.PolarisRateLimitProperties;
import com.tencent.cloud.polaris.ratelimit.resolver.RateLimitRuleArgumentServletResolver;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLabelServletResolver; import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLabelServletResolver;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLimitedFallback; import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLimitedFallback;
import com.tencent.polaris.api.plugin.ratelimiter.QuotaResult; import com.tencent.polaris.api.plugin.ratelimiter.QuotaResult;
import com.tencent.polaris.ratelimit.api.core.LimitAPI; import com.tencent.polaris.ratelimit.api.core.LimitAPI;
import com.tencent.polaris.ratelimit.api.rpc.QuotaRequest; import com.tencent.polaris.ratelimit.api.rpc.QuotaRequest;
import com.tencent.polaris.ratelimit.api.rpc.QuotaResponse; import com.tencent.polaris.ratelimit.api.rpc.QuotaResponse;
import org.junit.jupiter.api.AfterAll; import com.tencent.polaris.specification.api.v1.traffic.manage.RateLimitProto;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;
import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.test.context.SpringBootTest; import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.web.server.ServerWebExchange; import org.springframework.test.context.junit.jupiter.SpringExtension;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.fail; import static org.assertj.core.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anySet;
import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
/** /**
@ -71,41 +65,23 @@ import static org.mockito.Mockito.when;
* *
* @author Haotian Zhang, cheese8 * @author Haotian Zhang, cheese8
*/ */
@ExtendWith(MockitoExtension.class) @ExtendWith(SpringExtension.class)
@MockitoSettings(strictness = Strictness.LENIENT) @SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT,
@SpringBootTest(classes = QuotaCheckServletFilterTest.TestApplication.class, properties = { classes = QuotaCheckServletFilterTest.TestApplication.class,
properties = {
"spring.cloud.polaris.namespace=Test", "spring.cloud.polaris.service=TestApp" "spring.cloud.polaris.namespace=Test", "spring.cloud.polaris.service=TestApp"
}) })
public class QuotaCheckServletFilterTest { public class QuotaCheckServletFilterTest {
private static MockedStatic<ApplicationContextAwareUtils> mockedApplicationContextAwareUtils;
private static MockedStatic<SpringWebExpressionLabelUtils> expressionLabelUtilsMockedStatic;
private final PolarisRateLimiterLabelServletResolver labelResolver = private final PolarisRateLimiterLabelServletResolver labelResolver =
exchange -> Collections.singletonMap("ServletResolver", "ServletResolver"); exchange -> Collections.singletonMap("xxx", "xxx");
private QuotaCheckServletFilter quotaCheckServletFilter; private QuotaCheckServletFilter quotaCheckServletFilter;
private QuotaCheckServletFilter quotaCheckWithHtmlRejectTipsServletFilter; private QuotaCheckServletFilter quotaCheckWithHtmlRejectTipsServletFilter;
private QuotaCheckServletFilter quotaCheckWithRateLimiterLimitedFallbackFilter; private QuotaCheckServletFilter quotaCheckWithRateLimiterLimitedFallbackFilter;
private PolarisRateLimiterLimitedFallback polarisRateLimiterLimitedFallback; private PolarisRateLimiterLimitedFallback polarisRateLimiterLimitedFallback;
@BeforeAll
static void beforeAll() {
expressionLabelUtilsMockedStatic = mockStatic(SpringWebExpressionLabelUtils.class);
when(SpringWebExpressionLabelUtils.resolve(any(ServerWebExchange.class), anySet()))
.thenReturn(Collections.singletonMap("RuleLabelResolver", "RuleLabelResolver"));
mockedApplicationContextAwareUtils = Mockito.mockStatic(ApplicationContextAwareUtils.class);
mockedApplicationContextAwareUtils.when(() -> ApplicationContextAwareUtils.getProperties(anyString()))
.thenReturn("unit-test");
}
@AfterAll
static void afterAll() {
mockedApplicationContextAwareUtils.close();
expressionLabelUtilsMockedStatic.close();
}
@BeforeEach @BeforeEach
void setUp() { void setUp() throws InvalidProtocolBufferException {
MetadataContext.LOCAL_NAMESPACE = "TEST"; MetadataContext.LOCAL_NAMESPACE = "TEST";
LimitAPI limitAPI = mock(LimitAPI.class); LimitAPI limitAPI = mock(LimitAPI.class);
@ -133,15 +109,21 @@ public class QuotaCheckServletFilterTest {
polarisRateLimitWithHtmlRejectTipsProperties.setRejectRequestTips("<h1>RejectRequestTips提示消息</h1>"); polarisRateLimitWithHtmlRejectTipsProperties.setRejectRequestTips("<h1>RejectRequestTips提示消息</h1>");
polarisRateLimitWithHtmlRejectTipsProperties.setRejectHttpCode(419); polarisRateLimitWithHtmlRejectTipsProperties.setRejectHttpCode(419);
RateLimitRuleLabelResolver rateLimitRuleLabelResolver = mock(RateLimitRuleLabelResolver.class); ServiceRuleManager serviceRuleManager = mock(ServiceRuleManager.class);
when(rateLimitRuleLabelResolver.getExpressionLabelKeys(anyString(), anyString())).thenReturn(Collections.emptySet());
RateLimitProto.Rule.Builder ratelimitRuleBuilder = RateLimitProto.Rule.newBuilder();
InputStream inputStream = QuotaCheckServletFilterTest.class.getClassLoader().getResourceAsStream("ratelimit.json");
String json = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)).lines().collect(Collectors.joining(""));
JsonFormat.parser().ignoringUnknownFields().merge(json, ratelimitRuleBuilder);
RateLimitProto.Rule rateLimitRule = ratelimitRuleBuilder.build();
RateLimitProto.RateLimit rateLimit = RateLimitProto.RateLimit.newBuilder().addRules(rateLimitRule).build();
when(serviceRuleManager.getServiceRateLimitRule(anyString(), anyString())).thenReturn(rateLimit);
this.quotaCheckServletFilter = new QuotaCheckServletFilter(limitAPI, labelResolver, polarisRateLimitProperties, rateLimitRuleLabelResolver, null); RateLimitRuleArgumentServletResolver rateLimitRuleArgumentServletResolver = new RateLimitRuleArgumentServletResolver(serviceRuleManager, labelResolver);
this.quotaCheckWithHtmlRejectTipsServletFilter = new QuotaCheckServletFilter( this.quotaCheckServletFilter = new QuotaCheckServletFilter(limitAPI, polarisRateLimitProperties, rateLimitRuleArgumentServletResolver, null);
limitAPI, labelResolver, polarisRateLimitWithHtmlRejectTipsProperties, rateLimitRuleLabelResolver, null); this.quotaCheckWithHtmlRejectTipsServletFilter = new QuotaCheckServletFilter(limitAPI, polarisRateLimitWithHtmlRejectTipsProperties, rateLimitRuleArgumentServletResolver, null);
polarisRateLimiterLimitedFallback = new JsonPolarisRateLimiterLimitedFallback(); this.polarisRateLimiterLimitedFallback = new JsonPolarisRateLimiterLimitedFallback();
this.quotaCheckWithRateLimiterLimitedFallbackFilter = new QuotaCheckServletFilter( this.quotaCheckWithRateLimiterLimitedFallbackFilter = new QuotaCheckServletFilter(limitAPI, polarisRateLimitWithHtmlRejectTipsProperties, rateLimitRuleArgumentServletResolver, polarisRateLimiterLimitedFallback);
limitAPI, labelResolver, polarisRateLimitWithHtmlRejectTipsProperties, rateLimitRuleLabelResolver, polarisRateLimiterLimitedFallback);
} }
@Test @Test
@ -167,40 +149,6 @@ public class QuotaCheckServletFilterTest {
quotaCheckWithRateLimiterLimitedFallbackFilter.init(); quotaCheckWithRateLimiterLimitedFallbackFilter.init();
} }
@Test
public void testGetRuleExpressionLabels() {
try {
Method getCustomResolvedLabels = QuotaCheckServletFilter.class.getDeclaredMethod("getCustomResolvedLabels", HttpServletRequest.class);
getCustomResolvedLabels.setAccessible(true);
// Mock request
MockHttpServletRequest request = new MockHttpServletRequest();
// labelResolver != null
Map<String, String> result = (Map<String, String>) getCustomResolvedLabels.invoke(quotaCheckServletFilter, request);
assertThat(result.size()).isEqualTo(1);
assertThat(result.get("ServletResolver")).isEqualTo("ServletResolver");
// throw exception
PolarisRateLimiterLabelServletResolver exceptionLabelResolver = request1 -> {
throw new RuntimeException("Mock exception.");
};
quotaCheckServletFilter = new QuotaCheckServletFilter(null, exceptionLabelResolver, null, null, null);
result = (Map<String, String>) getCustomResolvedLabels.invoke(quotaCheckServletFilter, request);
assertThat(result.size()).isEqualTo(0);
// labelResolver == null
quotaCheckServletFilter = new QuotaCheckServletFilter(null, null, null, null, null);
result = (Map<String, String>) getCustomResolvedLabels.invoke(quotaCheckServletFilter, request);
assertThat(result.size()).isEqualTo(0);
getCustomResolvedLabels.setAccessible(false);
}
catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
fail("Exception encountered.", e);
}
}
@Test @Test
public void testDoFilterInternal() { public void testDoFilterInternal() {
// Create mock FilterChain // Create mock FilterChain
@ -210,34 +158,39 @@ public class QuotaCheckServletFilterTest {
// Mock request // Mock request
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
quotaCheckServletFilter.init(); quotaCheckServletFilter.init();
quotaCheckWithHtmlRejectTipsServletFilter.init();
try { try {
// Pass // Pass
MetadataContext.LOCAL_SERVICE = "TestApp1"; MetadataContext.LOCAL_SERVICE = "TestApp1";
quotaCheckServletFilter.doFilterInternal(request, response, filterChain); MockHttpServletResponse testApp1Response = new MockHttpServletResponse();
quotaCheckServletFilter.doFilterInternal(request, testApp1Response, filterChain);
// Unirate waiting 1000ms // Unirate waiting 1000ms
MetadataContext.LOCAL_SERVICE = "TestApp2"; MetadataContext.LOCAL_SERVICE = "TestApp2";
MockHttpServletResponse testApp2Response = new MockHttpServletResponse();
long startTimestamp = System.currentTimeMillis(); long startTimestamp = System.currentTimeMillis();
quotaCheckServletFilter.doFilterInternal(request, response, filterChain); quotaCheckServletFilter.doFilterInternal(request, testApp2Response, filterChain);
assertThat(System.currentTimeMillis() - startTimestamp).isGreaterThanOrEqualTo(1000L); assertThat(System.currentTimeMillis() - startTimestamp).isGreaterThanOrEqualTo(1000L);
// Rate limited // Rate limited
MetadataContext.LOCAL_SERVICE = "TestApp3"; MetadataContext.LOCAL_SERVICE = "TestApp3";
quotaCheckServletFilter.doFilterInternal(request, response, filterChain); MockHttpServletResponse testApp3Response = new MockHttpServletResponse();
assertThat(response.getStatus()).isEqualTo(419); quotaCheckServletFilter.doFilterInternal(request, testApp3Response, filterChain);
assertThat(response.getContentAsString()).isEqualTo("RejectRequestTips提示消息"); assertThat(testApp3Response.getStatus()).isEqualTo(419);
assertThat(testApp3Response.getContentAsString()).isEqualTo("RejectRequestTips提示消息");
quotaCheckWithHtmlRejectTipsServletFilter.doFilterInternal(request, response, filterChain); MockHttpServletResponse testApp3Response2 = new MockHttpServletResponse();
assertThat(response.getStatus()).isEqualTo(419); quotaCheckWithHtmlRejectTipsServletFilter.doFilterInternal(request, testApp3Response2, filterChain);
assertThat(response.getContentAsString()).isEqualTo("RejectRequestTips提示消息"); assertThat(testApp3Response2.getStatus()).isEqualTo(419);
assertThat(testApp3Response2.getContentAsString()).isEqualTo("<h1>RejectRequestTips提示消息</h1>");
// Exception // Exception
MockHttpServletResponse testApp4Response = new MockHttpServletResponse();
MetadataContext.LOCAL_SERVICE = "TestApp4"; MetadataContext.LOCAL_SERVICE = "TestApp4";
quotaCheckServletFilter.doFilterInternal(request, response, filterChain); quotaCheckServletFilter.doFilterInternal(request, testApp4Response, filterChain);
} }
catch (ServletException | IOException e) { catch (ServletException | IOException e) {
fail("Exception encountered.", e); fail("Exception encountered.", e);
@ -252,31 +205,34 @@ public class QuotaCheckServletFilterTest {
// Mock request // Mock request
MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletRequest request = new MockHttpServletRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
quotaCheckWithRateLimiterLimitedFallbackFilter.init(); quotaCheckWithRateLimiterLimitedFallbackFilter.init();
try { try {
// Pass // Pass
MetadataContext.LOCAL_SERVICE = "TestApp1"; MetadataContext.LOCAL_SERVICE = "TestApp1";
quotaCheckWithRateLimiterLimitedFallbackFilter.doFilterInternal(request, response, filterChain); MockHttpServletResponse testApp1response = new MockHttpServletResponse();
quotaCheckWithRateLimiterLimitedFallbackFilter.doFilterInternal(request, testApp1response, filterChain);
// Unirate waiting 1000ms // Unirate waiting 1000ms
MetadataContext.LOCAL_SERVICE = "TestApp2"; MetadataContext.LOCAL_SERVICE = "TestApp2";
MockHttpServletResponse testApp2response = new MockHttpServletResponse();
long startTimestamp = System.currentTimeMillis(); long startTimestamp = System.currentTimeMillis();
quotaCheckWithRateLimiterLimitedFallbackFilter.doFilterInternal(request, response, filterChain); quotaCheckWithRateLimiterLimitedFallbackFilter.doFilterInternal(request, testApp2response, filterChain);
assertThat(System.currentTimeMillis() - startTimestamp).isGreaterThanOrEqualTo(1000L); assertThat(System.currentTimeMillis() - startTimestamp).isGreaterThanOrEqualTo(1000L);
// Rate limited // Rate limited
MetadataContext.LOCAL_SERVICE = "TestApp3"; MetadataContext.LOCAL_SERVICE = "TestApp3";
MockHttpServletResponse testApp3response = new MockHttpServletResponse();
String contentType = new MediaType(polarisRateLimiterLimitedFallback.mediaType(), polarisRateLimiterLimitedFallback.charset()).toString(); String contentType = new MediaType(polarisRateLimiterLimitedFallback.mediaType(), polarisRateLimiterLimitedFallback.charset()).toString();
quotaCheckWithRateLimiterLimitedFallbackFilter.doFilterInternal(request, response, filterChain); quotaCheckWithRateLimiterLimitedFallbackFilter.doFilterInternal(request, testApp3response, filterChain);
assertThat(response.getStatus()).isEqualTo(polarisRateLimiterLimitedFallback.rejectHttpCode()); assertThat(testApp3response.getStatus()).isEqualTo(polarisRateLimiterLimitedFallback.rejectHttpCode());
assertThat(response.getContentAsString()).isEqualTo(polarisRateLimiterLimitedFallback.rejectTips()); assertThat(testApp3response.getContentAsString()).isEqualTo(polarisRateLimiterLimitedFallback.rejectTips());
assertThat(response.getContentType()).isEqualTo(contentType); assertThat(testApp3response.getContentType()).isEqualTo(contentType);
// Exception // Exception
MetadataContext.LOCAL_SERVICE = "TestApp4"; MetadataContext.LOCAL_SERVICE = "TestApp4";
quotaCheckWithRateLimiterLimitedFallbackFilter.doFilterInternal(request, response, filterChain); MockHttpServletResponse testApp4response = new MockHttpServletResponse();
quotaCheckWithRateLimiterLimitedFallbackFilter.doFilterInternal(request, testApp4response, filterChain);
} }
catch (ServletException | IOException e) { catch (ServletException | IOException e) {
fail("Exception encountered.", e); fail("Exception encountered.", e);

@ -0,0 +1,109 @@
/*
* Tencent is pleased to support the open source community by making Spring Cloud Tencent available.
*
* Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the BSD 3-Clause License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://opensource.org/licenses/BSD-3-Clause
*
* Unless required by applicable law or agreed to in writing, software distributed
* under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
* CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package com.tencent.cloud.polaris.ratelimit.resolver;
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.util.JsonFormat;
import com.tencent.cloud.common.metadata.MetadataContext;
import com.tencent.cloud.polaris.context.ServiceRuleManager;
import com.tencent.cloud.polaris.ratelimit.filter.QuotaCheckServletFilterTest;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLabelReactiveResolver;
import com.tencent.polaris.ratelimit.api.rpc.Argument;
import com.tencent.polaris.specification.api.v1.traffic.manage.RateLimitProto;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.test.context.junit.jupiter.SpringExtension;
import org.springframework.web.server.ServerWebExchange;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ExtendWith(SpringExtension.class)
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT,
classes = RateLimitRuleArgumentReactiveResolverTest.TestApplication.class,
properties = {
"spring.cloud.polaris.namespace=Test", "spring.cloud.polaris.service=TestApp"
})
public class RateLimitRuleArgumentReactiveResolverTest {
private final PolarisRateLimiterLabelReactiveResolver labelResolver =
exchange -> Collections.singletonMap("xxx", "xxx");
private RateLimitRuleArgumentReactiveResolver rateLimitRuleArgumentReactiveResolver;
@BeforeEach
void setUp() throws InvalidProtocolBufferException {
MetadataContext.LOCAL_NAMESPACE = "TEST";
ServiceRuleManager serviceRuleManager = mock(ServiceRuleManager.class);
RateLimitProto.Rule.Builder ratelimitRuleBuilder = RateLimitProto.Rule.newBuilder();
InputStream inputStream = QuotaCheckServletFilterTest.class.getClassLoader().getResourceAsStream("ratelimit.json");
String json = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)).lines().collect(Collectors.joining(""));
JsonFormat.parser().ignoringUnknownFields().merge(json, ratelimitRuleBuilder);
RateLimitProto.Rule rateLimitRule = ratelimitRuleBuilder.build();
RateLimitProto.RateLimit rateLimit = RateLimitProto.RateLimit.newBuilder().addRules(rateLimitRule).build();
when(serviceRuleManager.getServiceRateLimitRule(anyString(), anyString())).thenReturn(rateLimit);
this.rateLimitRuleArgumentReactiveResolver = new RateLimitRuleArgumentReactiveResolver(serviceRuleManager, labelResolver);
}
@Test
public void testGetRuleArguments() {
// Mock request
MetadataContext.LOCAL_SERVICE = "Test";
// Mock request
MockServerHttpRequest request = MockServerHttpRequest.get("http://127.0.0.1:8080/test")
.remoteAddress(new InetSocketAddress("127.0.0.1", 8080))
.header("xxx", "xxx")
.queryParam("yyy", "yyy")
.build();
ServerWebExchange exchange = MockServerWebExchange.from(request);
Set<Argument> arguments = rateLimitRuleArgumentReactiveResolver.getArguments(exchange, MetadataContext.LOCAL_NAMESPACE, MetadataContext.LOCAL_SERVICE);
Set<Argument> exceptRes = new HashSet<>();
exceptRes.add(Argument.buildMethod("GET"));
exceptRes.add(Argument.buildHeader("xxx", "xxx"));
exceptRes.add(Argument.buildQuery("yyy", "yyy"));
exceptRes.add(Argument.buildCallerIP("127.0.0.1"));
exceptRes.add(Argument.buildCustom("xxx", "xxx"));
assertThat(arguments).isEqualTo(exceptRes);
}
@SpringBootApplication
protected static class TestApplication {
}
}

@ -0,0 +1,100 @@
/*
* Tencent is pleased to support the open source community by making Spring Cloud Tencent available.
*
* Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the BSD 3-Clause License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://opensource.org/licenses/BSD-3-Clause
*
* Unless required by applicable law or agreed to in writing, software distributed
* under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
* CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/
package com.tencent.cloud.polaris.ratelimit.resolver;
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.util.JsonFormat;
import com.tencent.cloud.common.metadata.MetadataContext;
import com.tencent.cloud.polaris.context.ServiceRuleManager;
import com.tencent.cloud.polaris.ratelimit.filter.QuotaCheckServletFilterTest;
import com.tencent.cloud.polaris.ratelimit.spi.PolarisRateLimiterLabelServletResolver;
import com.tencent.polaris.ratelimit.api.rpc.Argument;
import com.tencent.polaris.specification.api.v1.traffic.manage.RateLimitProto;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.test.context.junit.jupiter.SpringExtension;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ExtendWith(SpringExtension.class)
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT,
classes = RateLimitRuleArgumentServletResolverTest.TestApplication.class,
properties = {
"spring.cloud.polaris.namespace=Test", "spring.cloud.polaris.service=TestApp"
})
public class RateLimitRuleArgumentServletResolverTest {
private final PolarisRateLimiterLabelServletResolver labelResolver =
exchange -> Collections.singletonMap("xxx", "xxx");
private RateLimitRuleArgumentServletResolver rateLimitRuleArgumentServletResolver;
@BeforeEach
void setUp() throws InvalidProtocolBufferException {
MetadataContext.LOCAL_NAMESPACE = "TEST";
ServiceRuleManager serviceRuleManager = mock(ServiceRuleManager.class);
RateLimitProto.Rule.Builder ratelimitRuleBuilder = RateLimitProto.Rule.newBuilder();
InputStream inputStream = QuotaCheckServletFilterTest.class.getClassLoader().getResourceAsStream("ratelimit.json");
String json = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)).lines().collect(Collectors.joining(""));
JsonFormat.parser().ignoringUnknownFields().merge(json, ratelimitRuleBuilder);
RateLimitProto.Rule rateLimitRule = ratelimitRuleBuilder.build();
RateLimitProto.RateLimit rateLimit = RateLimitProto.RateLimit.newBuilder().addRules(rateLimitRule).build();
when(serviceRuleManager.getServiceRateLimitRule(anyString(), anyString())).thenReturn(rateLimit);
this.rateLimitRuleArgumentServletResolver = new RateLimitRuleArgumentServletResolver(serviceRuleManager, labelResolver);
}
@Test
public void testGetRuleArguments() {
// Mock request
MetadataContext.LOCAL_SERVICE = "Test";
MockHttpServletRequest request = new MockHttpServletRequest(null, "GET", "/xxx");
request.setParameter("yyy", "yyy");
request.addHeader("xxx", "xxx");
Set<Argument> arguments = rateLimitRuleArgumentServletResolver.getArguments(request, MetadataContext.LOCAL_NAMESPACE, MetadataContext.LOCAL_SERVICE);
Set<Argument> exceptRes = new HashSet<>();
exceptRes.add(Argument.buildMethod("GET"));
exceptRes.add(Argument.buildHeader("xxx", "xxx"));
exceptRes.add(Argument.buildQuery("yyy", "yyy"));
exceptRes.add(Argument.buildCallerIP("127.0.0.1"));
exceptRes.add(Argument.buildCustom("xxx", "xxx"));
assertThat(arguments).isEqualTo(exceptRes);
}
@SpringBootApplication
protected static class TestApplication {
}
}

@ -17,6 +17,8 @@
package com.tencent.cloud.polaris.ratelimit.utils; package com.tencent.cloud.polaris.ratelimit.utils;
import java.util.HashSet;
import com.tencent.polaris.api.plugin.ratelimiter.QuotaResult; import com.tencent.polaris.api.plugin.ratelimiter.QuotaResult;
import com.tencent.polaris.ratelimit.api.core.LimitAPI; import com.tencent.polaris.ratelimit.api.core.LimitAPI;
import com.tencent.polaris.ratelimit.api.rpc.QuotaRequest; import com.tencent.polaris.ratelimit.api.rpc.QuotaRequest;
@ -66,28 +68,28 @@ public class QuotaCheckUtilsTest {
public void testGetQuota() { public void testGetQuota() {
// Pass // Pass
String serviceName = "TestApp1"; String serviceName = "TestApp1";
QuotaResponse quotaResponse = QuotaCheckUtils.getQuota(limitAPI, null, serviceName, 1, null, null); QuotaResponse quotaResponse = QuotaCheckUtils.getQuota(limitAPI, null, serviceName, 1, new HashSet<>(), null);
assertThat(quotaResponse.getCode()).isEqualTo(QuotaResultCode.QuotaResultOk); assertThat(quotaResponse.getCode()).isEqualTo(QuotaResultCode.QuotaResultOk);
assertThat(quotaResponse.getWaitMs()).isEqualTo(0); assertThat(quotaResponse.getWaitMs()).isEqualTo(0);
assertThat(quotaResponse.getInfo()).isEqualTo("QuotaResultOk"); assertThat(quotaResponse.getInfo()).isEqualTo("QuotaResultOk");
// Unirate waiting 1000ms // Unirate waiting 1000ms
serviceName = "TestApp2"; serviceName = "TestApp2";
quotaResponse = QuotaCheckUtils.getQuota(limitAPI, null, serviceName, 1, null, null); quotaResponse = QuotaCheckUtils.getQuota(limitAPI, null, serviceName, 1, new HashSet<>(), null);
assertThat(quotaResponse.getCode()).isEqualTo(QuotaResultCode.QuotaResultOk); assertThat(quotaResponse.getCode()).isEqualTo(QuotaResultCode.QuotaResultOk);
assertThat(quotaResponse.getWaitMs()).isEqualTo(1000); assertThat(quotaResponse.getWaitMs()).isEqualTo(1000);
assertThat(quotaResponse.getInfo()).isEqualTo("QuotaResultOk"); assertThat(quotaResponse.getInfo()).isEqualTo("QuotaResultOk");
// Rate limited // Rate limited
serviceName = "TestApp3"; serviceName = "TestApp3";
quotaResponse = QuotaCheckUtils.getQuota(limitAPI, null, serviceName, 1, null, null); quotaResponse = QuotaCheckUtils.getQuota(limitAPI, null, serviceName, 1, new HashSet<>(), null);
assertThat(quotaResponse.getCode()).isEqualTo(QuotaResultCode.QuotaResultLimited); assertThat(quotaResponse.getCode()).isEqualTo(QuotaResultCode.QuotaResultLimited);
assertThat(quotaResponse.getWaitMs()).isEqualTo(0); assertThat(quotaResponse.getWaitMs()).isEqualTo(0);
assertThat(quotaResponse.getInfo()).isEqualTo("QuotaResultLimited"); assertThat(quotaResponse.getInfo()).isEqualTo("QuotaResultLimited");
// Exception // Exception
serviceName = "TestApp4"; serviceName = "TestApp4";
quotaResponse = QuotaCheckUtils.getQuota(limitAPI, null, serviceName, 1, null, null); quotaResponse = QuotaCheckUtils.getQuota(limitAPI, null, serviceName, 1, new HashSet<>(), null);
assertThat(quotaResponse.getCode()).isEqualTo(QuotaResultCode.QuotaResultOk); assertThat(quotaResponse.getCode()).isEqualTo(QuotaResultCode.QuotaResultOk);
assertThat(quotaResponse.getWaitMs()).isEqualTo(0); assertThat(quotaResponse.getWaitMs()).isEqualTo(0);
assertThat(quotaResponse.getInfo()).isEqualTo("get quota failed"); assertThat(quotaResponse.getInfo()).isEqualTo("get quota failed");

@ -0,0 +1,95 @@
{
"id": "f7560cce829e4d4c8556a6be63539af5",
"service": "xxx",
"namespace": "default",
"subset": {},
"priority": 0,
"resource": "QPS",
"type": "GLOBAL",
"labels": {},
"amounts": [
{
"maxAmount": 1,
"validDuration": "1s",
"precision": null,
"startAmount": null,
"minAmount": null
}
],
"action": "REJECT",
"disable": false,
"report": null,
"ctime": "2022-12-11 21:56:59",
"mtime": "2023-03-10 15:40:33",
"revision": "6eec6f416bee40ecbf664c93add61358",
"service_token": null,
"adjuster": null,
"regex_combine": true,
"amountMode": "GLOBAL_TOTAL",
"failover": "FAILOVER_LOCAL",
"cluster": null,
"method": {
"type": "EXACT",
"value": "/xxx",
"value_type": "TEXT"
},
"arguments": [
{
"type": "CALLER_SERVICE",
"key": "default",
"value": {
"type": "EXACT",
"value": "xxx",
"value_type": "TEXT"
}
},
{
"type": "HEADER",
"key": "xxx",
"value": {
"type": "EXACT",
"value": "xxx",
"value_type": "TEXT"
}
},
{
"type": "QUERY",
"key": "yyy",
"value": {
"type": "EXACT",
"value": "yyy",
"value_type": "TEXT"
}
},
{
"type": "METHOD",
"key": "$method",
"value": {
"type": "EXACT",
"value": "GET",
"value_type": "TEXT"
}
},
{
"type": "CALLER_IP",
"key": "$caller_ip",
"value": {
"type": "EXACT",
"value": "127.0.0.1",
"value_type": "TEXT"
}
},
{
"type": "CUSTOM",
"key": "xxx",
"value": {
"type": "EXACT",
"value": "xxx",
"value_type": "TEXT"
}
}
],
"name": "xxx",
"etime": "2023-03-10 15:40:33",
"max_queue_delay": 30
}

@ -55,7 +55,7 @@ public final class MetadataConstant {
/** /**
* Order of filter. * Order of filter.
*/ */
public static final int WEB_FILTER_ORDER = Ordered.HIGHEST_PRECEDENCE + 13; public static final int WEB_FILTER_ORDER = Ordered.HIGHEST_PRECEDENCE + 9;
/** /**
* Order of MetadataFirstFeignInterceptor. * Order of MetadataFirstFeignInterceptor.
@ -83,6 +83,11 @@ public final class MetadataConstant {
*/ */
public static final String CUSTOM_DISPOSABLE_METADATA = "SCT-CUSTOM-DISPOSABLE-METADATA"; public static final String CUSTOM_DISPOSABLE_METADATA = "SCT-CUSTOM-DISPOSABLE-METADATA";
/**
* Custom Disposable Metadata.
*/
public static final String DEFAULT_DISPOSABLE_METADATA = "SCT-DEFAULT-DISPOSABLE-METADATA";
/** /**
* System Metadata. * System Metadata.
*/ */
@ -93,4 +98,19 @@ public final class MetadataConstant {
*/ */
public static final String METADATA_CONTEXT = "SCT-METADATA-CONTEXT"; public static final String METADATA_CONTEXT = "SCT-METADATA-CONTEXT";
} }
public static class DefaultMetadata {
/**
* Default Metadata Source Service Namespace Key.
*/
public static final String DEFAULT_METADATA_SOURCE_SERVICE_NAMESPACE = "source_service_namespace";
/**
* Default Metadata Source Service Name Key.
*/
public static final String DEFAULT_METADATA_SOURCE_SERVICE_NAME = "source_service_name";
}
} }

@ -121,13 +121,13 @@ public final class MetadataContextHolder {
/** /**
* Save metadata map to thread local. * Save metadata map to thread local.
* @param dynamicTransitiveMetadata custom metadata collection * @param dynamicTransitiveMetadata custom metadata collection
* @param dynamicDisposableMetadata custom disposable metadata connection * @param dynamicCustomDisposableMetadata custom disposable metadata collection
* @param dynamicDefaultDisposableMetadata default disposable metadata collection
*/ */
public static void init(Map<String, String> dynamicTransitiveMetadata, Map<String, String> dynamicDisposableMetadata) { public static void init(Map<String, String> dynamicTransitiveMetadata, Map<String, String> dynamicCustomDisposableMetadata, Map<String, String> dynamicDefaultDisposableMetadata) {
// Init ThreadLocal. // Init ThreadLocal.
MetadataContextHolder.remove(); MetadataContextHolder.remove();
MetadataContext metadataContext = MetadataContextHolder.get(); MetadataContext metadataContext = MetadataContextHolder.get();
// Save transitive metadata to ThreadLocal. // Save transitive metadata to ThreadLocal.
if (!CollectionUtils.isEmpty(dynamicTransitiveMetadata)) { if (!CollectionUtils.isEmpty(dynamicTransitiveMetadata)) {
Map<String, String> staticTransitiveMetadata = metadataContext.getTransitiveMetadata(); Map<String, String> staticTransitiveMetadata = metadataContext.getTransitiveMetadata();
@ -135,13 +135,20 @@ public final class MetadataContextHolder {
mergedTransitiveMetadata.putAll(staticTransitiveMetadata); mergedTransitiveMetadata.putAll(staticTransitiveMetadata);
mergedTransitiveMetadata.putAll(dynamicTransitiveMetadata); mergedTransitiveMetadata.putAll(dynamicTransitiveMetadata);
metadataContext.setTransitiveMetadata(Collections.unmodifiableMap(mergedTransitiveMetadata)); metadataContext.setTransitiveMetadata(Collections.unmodifiableMap(mergedTransitiveMetadata));
}
Map<String, String> mergedDisposableMetadata = new HashMap<>(dynamicDisposableMetadata); if (!CollectionUtils.isEmpty(dynamicCustomDisposableMetadata) || !CollectionUtils.isEmpty(dynamicDefaultDisposableMetadata)) {
metadataContext.setUpstreamDisposableMetadata(Collections.unmodifiableMap(mergedDisposableMetadata)); Map<String, String> mergedUpstreamDisposableMetadata = new HashMap<>();
if (!CollectionUtils.isEmpty(dynamicCustomDisposableMetadata)) {
mergedUpstreamDisposableMetadata.putAll(dynamicCustomDisposableMetadata);
}
if (!CollectionUtils.isEmpty(dynamicDefaultDisposableMetadata)) {
mergedUpstreamDisposableMetadata.putAll(dynamicDefaultDisposableMetadata);
}
metadataContext.setUpstreamDisposableMetadata(Collections.unmodifiableMap(mergedUpstreamDisposableMetadata));
}
Map<String, String> staticDisposableMetadata = metadataContext.getFragmentContext(FRAGMENT_DISPOSABLE); Map<String, String> staticDisposableMetadata = metadataContext.getFragmentContext(FRAGMENT_DISPOSABLE);
metadataContext.setDisposableMetadata(Collections.unmodifiableMap(staticDisposableMetadata)); metadataContext.setDisposableMetadata(Collections.unmodifiableMap(staticDisposableMetadata));
}
MetadataContextHolder.set(metadataContext); MetadataContextHolder.set(metadataContext);
} }

@ -58,7 +58,7 @@ public class MetadataContextHolderTest {
customMetadata.put("a", "1"); customMetadata.put("a", "1");
customMetadata.put("b", "22"); customMetadata.put("b", "22");
customMetadata.put("c", "3"); customMetadata.put("c", "3");
MetadataContextHolder.init(customMetadata, new HashMap<>()); MetadataContextHolder.init(customMetadata, new HashMap<>(), new HashMap<>());
metadataContext = MetadataContextHolder.get(); metadataContext = MetadataContextHolder.get();
customMetadata = metadataContext.getTransitiveMetadata(); customMetadata = metadataContext.getTransitiveMetadata();
Assertions.assertThat(customMetadata.get("a")).isEqualTo("1"); Assertions.assertThat(customMetadata.get("a")).isEqualTo("1");

Loading…
Cancel
Save