1
0
Code Issues Pull Requests Packages Projects Releases Wiki Activity GitHub Gitee

Gateway添加通过IP限流

This commit is contained in:
程序员小墨 2023-04-24 19:01:57 +08:00
parent 2f469aec14
commit 2d82571303
2 changed files with 97 additions and 0 deletions

View File

@ -0,0 +1,79 @@
package com.cxyxiaomo.epp.gateway.Factory;
import com.google.common.util.concurrent.RateLimiter;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.stereotype.Component;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
@Component
// 一个自定义的限流过滤器工厂
public class RateLimitByIpGatewayFilterFactory extends AbstractGatewayFilterFactory<RateLimitByIpGatewayFilterFactory.Config> {
// 用于存储IP地址和对应的计数器
private static final Map<String, RateLimiter> RATE_LIMITER_CACHE = new ConcurrentHashMap<>();
public RateLimitByIpGatewayFilterFactory() {
super(Config.class);
}
@Override
public GatewayFilter apply(Config config) {
return (exchange, chain) -> {
// 获取请求的IP地址
ServerHttpRequest request = exchange.getRequest();
String ip = request.getRemoteAddress().getAddress().getHostAddress();
// 根据IP地址获取对应的限流器
RateLimiter rateLimiter = RATE_LIMITER_CACHE.get(ip);
if (rateLimiter == null) {
// 如果不存在则创建一个新的限流器并放入缓存中
rateLimiter = RateLimiter.create(config.getRate());
RATE_LIMITER_CACHE.put(ip, rateLimiter);
}
// 判断请求是否被限流
if (rateLimiter.tryAcquire(config.getPermits())) {
// 如果没有被限流则放行
return chain.filter(exchange);
} else {
System.out.println("限流ip: " + ip);
// 如果被限流则返回429状态码Too Many Requests
ServerHttpResponse response = exchange.getResponse();
response.setStatusCode(HttpStatus.TOO_MANY_REQUESTS);
return response.setComplete();
}
};
}
// 配置类用于接收配置参数
public static class Config {
// 每秒允许的请求数
private double rate;
// 每次请求需要的令牌数
private int permits;
public double getRate() {
return rate;
}
public void setRate(double rate) {
this.rate = rate;
}
public int getPermits() {
return permits;
}
public void setPermits(int permits) {
this.permits = permits;
}
}
}

View File

@ -40,20 +40,38 @@ spring:
predicates: predicates:
- Path=/user/** - Path=/user/**
- Method=GET,POST - Method=GET,POST
filters: # 路由过滤器,使用自定义的限流过滤器工厂
- name: RateLimitByIp # 设置每秒允许5个请求每次请求需要1个令牌
args:
rate: 5.0
permits: 1
- id: access - id: access
uri: lb://microservice-provider-access uri: lb://microservice-provider-access
predicates: predicates:
- Path=/access/** - Path=/access/**
- Method=GET,POST - Method=GET,POST
filters: # 路由过滤器,使用自定义的限流过滤器工厂
- name: RateLimitByIp # 设置每秒允许5个请求每次请求需要1个令牌
args:
rate: 5.0
permits: 1
- id: access-websocket - id: access-websocket
uri: lb:ws://microservice-provider-access uri: lb:ws://microservice-provider-access
predicates: predicates:
- Path=/access/websocket/** - Path=/access/websocket/**
- id: shop - id: shop
uri: lb://microservice-provider-shop uri: lb://microservice-provider-shop
predicates: predicates:
- Path=/shop/** - Path=/shop/**
- Method=GET,POST - Method=GET,POST
filters: # 路由过滤器,使用自定义的限流过滤器工厂
- name: RateLimitByIp # 设置每秒允许5个请求每次请求需要1个令牌
args:
rate: 5.0
permits: 1
- id: test1 - id: test1
uri: lb://microservice-provider-test uri: lb://microservice-provider-test