Skip to content

Spring Boot - Filter 示例

🏷️ Spring Boot

Filter 示例

java
import com.mokasz.zy.server.game.app.core.filter.BodyReaderHttpServletRequestWrapper;
import com.mokasz.zy.server.game.app.core.filter.BodyRewriterHttpServletRequestWrapper;
import com.mokasz.zy.server.game.app.core.utils.HttpUtil;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang.StringUtils;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
import org.springframework.util.AntPathMatcher;
import org.springframework.util.PathMatcher;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.SortedMap;

/**
 * Sample servlet filter.
 *
 * <p>For requests whose path matches one of {@link #FILTER_URLS}, the body is buffered so it can
 * be read more than once, the request parameters are validated, and (optionally) the body is
 * replaced with a rewritten/decrypted version before the request continues down the chain. All
 * other requests, including CORS preflight {@code OPTIONS} requests, pass through untouched.
 */
@Component
@Order(0)
@RequiredArgsConstructor
@Slf4j
public class SampleFilter implements Filter {
    /**
     * Ant-style patterns ({@code **} wildcard supported) of the request paths this filter
     * processes, relative to the web application (i.e. without the context path).
     * The example below matches every path starting with {@code /api/test/}.
     */
    private static final List<String> FILTER_URLS = Arrays.asList(
            "/api/test/**"
    );

    /** Matcher shared by all requests instead of being allocated on every call. */
    private static final PathMatcher PATH_MATCHER = new AntPathMatcher();

    /** Sample switch: when {@code true}, the request body is replaced by the decrypted parameters. */
    private boolean isNeedDecrypt = false;

    @Override
    public void init(FilterConfig filterConfig) {
        log.info("init SampleFilter");
    }

    /**
     * Buffers, validates and optionally rewrites matching HTTP requests; every other request is
     * passed on unchanged.
     *
     * @throws IOException      if the rest of the filter chain fails with an I/O error
     * @throws ServletException if the rest of the filter chain fails
     */
    @Override
    public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain)
            throws IOException, ServletException {
        // Not an HTTP exchange: there is nothing this filter can inspect, so pass it through
        // instead of failing with a ClassCastException.
        if (!(req instanceof HttpServletRequest) || !(res instanceof HttpServletResponse)) {
            chain.doFilter(req, res);
            return;
        }
        HttpServletRequest request = (HttpServletRequest) req;
        HttpServletResponse response = (HttpServletResponse) res;

        if (!shouldFilter(request)) {
            chain.doFilter(request, response);
            return;
        }

        // The servlet input stream can be consumed only once. Buffer the body in a wrapper so
        // that both this filter and the downstream handlers can read it.
        HttpServletRequest requestWrapper = new BodyReaderHttpServletRequestWrapper(request);
        SortedMap<String, String> allParams = HttpUtil.getAllParams(requestWrapper);

        // Reject invalid requests. Bad input is a client error, hence 400 rather than 500.
        if (isWrongRequest(allParams)) {
            response.setStatus(HttpServletResponse.SC_BAD_REQUEST);
            return;
        }

        if (isNeedDecrypt) {
            // Example of putting decrypted content back into the request: replace its body.
            String decryptResult = decryptParams(allParams);
            BodyRewriterHttpServletRequestWrapper rewriteWrapper =
                    new BodyRewriterHttpServletRequestWrapper(requestWrapper, decryptResult);
            chain.doFilter(rewriteWrapper, response);
        } else {
            chain.doFilter(requestWrapper, response);
        }
    }

    /**
     * Placeholder: decrypts the request parameters and returns the new request body.
     * Currently always returns an empty string.
     */
    private String decryptParams(SortedMap<String, String> allParams) {
        // TODO: decrypt parameters
        return StringUtils.EMPTY;
    }

    /**
     * Placeholder: validates the request parameters. Currently accepts every request.
     *
     * @return {@code true} if the request must be rejected
     */
    private boolean isWrongRequest(SortedMap<String, String> allParams) {
        // TODO: validate parameters
        return false;
    }

    /**
     * Decides whether this filter should process the given request.
     *
     * @return {@code true} if the request is not a preflight and its path matches a pattern
     */
    private boolean shouldFilter(HttpServletRequest request) {
        // For cross-origin calls some browsers first send an OPTIONS (preflight) request;
        // skip it so it is not rejected by this filter's processing.
        if ("OPTIONS".equalsIgnoreCase(request.getMethod())) {
            return false;
        }
        String path = getPathWithinApplication(request);
        for (String pattern : FILTER_URLS) {
            if (PATH_MATCHER.match(pattern, path)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the request path relative to the web application.
     *
     * <p>It is built from the servlet path and path info. Unlike
     * {@link HttpServletRequest#getRequestURI()}, these are decoded by the container and exclude
     * the context path. Patterns therefore match regardless of the context path the application
     * is deployed under.
     */
    private static String getPathWithinApplication(HttpServletRequest request) {
        String servletPath = request.getServletPath();
        String pathInfo = request.getPathInfo();
        String base = servletPath == null ? "" : servletPath;
        return pathInfo == null ? base : base + pathInfo;
    }

    @Override
    public void destroy() {
        log.info("destroy SampleFilter");
    }
}

自定义的 HttpServletRequestWrapper

BodyReaderHttpServletRequestWrapper

这个类摘自 xuanweiyao 的博客(原文地址已找不到),用于把请求体缓存到内存中,使其可以被多次读取。

java
import java.io.*;
import java.nio.charset.Charset;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;

/**
 * Request wrapper that caches the request body in memory so it can be read more than once.
 * This is needed because the underlying servlet input stream can be consumed only once.
 *
 * <p>The body is kept as the exact bytes received. Downstream readers therefore see the same
 * content and length the client sent, including line breaks and non-UTF-8 bytes.
 *
 * <p>NOTE(review): {@code getParameter*} still delegates to the wrapped request, whose input
 * stream has already been drained by this wrapper. For
 * {@code application/x-www-form-urlencoded} bodies the container may no longer be able to parse
 * form parameters. Verify this against the container in use.
 *
 * @author xuanweiyao
 * @date 10:03 2019/5/30
 */
public class BodyReaderHttpServletRequestWrapper extends HttpServletRequestWrapper {

    /** Charset used when the request declares no (valid) character encoding. */
    private static final Charset DEFAULT_CHARSET = Charset.forName("UTF-8");

    /** Raw request body, exactly as read from the wrapped request. */
    private final byte[] body;

    /**
     * Wraps {@code request} and reads its entire body into memory.
     *
     * @param request the request to wrap; its input stream is fully consumed
     * @throws UncheckedIOException if the body cannot be read
     */
    public BodyReaderHttpServletRequestWrapper(HttpServletRequest request) {
        super(request);
        try {
            body = readAllBytes(request.getInputStream());
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to read request body", e);
        }
    }

    /**
     * Reads the remaining body of {@code request} and decodes it as UTF-8.
     * Line breaks are preserved.
     *
     * @param request the request whose input stream is consumed
     * @return the body as a string
     * @throws UncheckedIOException if the body cannot be read
     */
    public String getBodyString(final ServletRequest request) {
        try {
            return new String(readAllBytes(request.getInputStream()), DEFAULT_CHARSET);
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to read request body", e);
        }
    }

    /**
     * Copies the remaining content of {@code inputStream} into an in-memory stream.
     *
     * @param inputStream the stream to drain
     * @return a stream over the copied bytes
     * @throws UncheckedIOException if the stream cannot be read
     */
    public InputStream cloneInputStream(ServletInputStream inputStream) {
        try {
            return new ByteArrayInputStream(readAllBytes(inputStream));
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to copy input stream", e);
        }
    }

    /**
     * Returns a reader over the cached body. It decodes with the request's declared character
     * encoding, falling back to UTF-8, rather than the platform default charset.
     */
    @Override
    public BufferedReader getReader() {
        return new BufferedReader(new InputStreamReader(getInputStream(), resolveCharset()));
    }

    /** Returns a fresh stream over the cached body; may be called any number of times. */
    @Override
    public ServletInputStream getInputStream() {
        return new ByteArrayServletInputStream(body);
    }

    /** Charset declared by the request, or UTF-8 when absent or not recognised. */
    private Charset resolveCharset() {
        String encoding = getCharacterEncoding();
        if (encoding != null) {
            try {
                return Charset.forName(encoding);
            } catch (IllegalArgumentException ignored) {
                // Unknown or malformed charset name: use the default below.
            }
        }
        return DEFAULT_CHARSET;
    }

    /** Reads {@code in} to its end; does not close it (the container owns the servlet stream). */
    private static byte[] readAllBytes(InputStream in) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        byte[] buffer = new byte[4096];
        int n;
        while ((n = in.read(buffer)) != -1) {
            out.write(buffer, 0, n);
        }
        return out.toByteArray();
    }

    /** Blocking {@link ServletInputStream} over an in-memory byte array. */
    private static final class ByteArrayServletInputStream extends ServletInputStream {

        private final ByteArrayInputStream delegate;

        ByteArrayServletInputStream(byte[] data) {
            this.delegate = new ByteArrayInputStream(data);
        }

        @Override
        public int read() {
            return delegate.read();
        }

        @Override
        public int read(byte[] b, int off, int len) {
            return delegate.read(b, off, len);
        }

        @Override
        public int available() {
            return delegate.available();
        }

        @Override
        public boolean isFinished() {
            return delegate.available() == 0;
        }

        @Override
        public boolean isReady() {
            // All data is already in memory, so a read never blocks.
            return true;
        }

        @Override
        public void setReadListener(ReadListener readListener) {
            // Non-blocking (async) reads are not supported; consumers read synchronously.
        }
    }
}

BodyRewriterHttpServletRequestWrapper

这是参照上面的代码编写的、用于替换请求体的 wrapper,适用于需要改写请求内容(例如把解密后的参数重新放回请求)的场景。

java
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.InputStreamReader;
import java.nio.charset.Charset;

/**
 * Request wrapper that replaces the request body with new content, e.g. a decrypted payload.
 *
 * <p>The new body is encoded as UTF-8, and {@link #getReader()} decodes it with that same
 * charset. {@link #getContentLength()} and {@link #getContentLengthLong()} report the length of
 * the new body, not the original one.
 *
 * <p>NOTE(review): the raw {@code Content-Length}/{@code Content-Type} headers returned by
 * {@code getHeader*} still come from the original request. Confirm that downstream consumers do
 * not rely on them.
 */
public class BodyRewriterHttpServletRequestWrapper extends HttpServletRequestWrapper {

    /** Charset used to encode the replacement body and to decode it in {@link #getReader()}. */
    private static final Charset BODY_CHARSET = Charset.forName("UTF-8");

    /** Replacement body bytes (UTF-8). */
    private final byte[] body;

    /**
     * Wraps {@code request}, replacing its body with {@code newBody}.
     *
     * @param request the request to wrap
     * @param newBody the new body text; must not be {@code null}
     * @throws NullPointerException if {@code newBody} is {@code null}
     */
    public BodyRewriterHttpServletRequestWrapper(HttpServletRequest request, String newBody) {
        super(request);
        body = newBody.getBytes(BODY_CHARSET);
    }

    /** Returns a reader over the new body, decoded with the charset it was encoded with. */
    @Override
    public BufferedReader getReader() {
        return new BufferedReader(new InputStreamReader(getInputStream(), BODY_CHARSET));
    }

    /** Returns a fresh stream over the new body; may be called any number of times. */
    @Override
    public ServletInputStream getInputStream() {
        return new ByteArrayServletInputStream(body);
    }

    /** Length of the replacement body in bytes (the original length would now be wrong). */
    @Override
    public int getContentLength() {
        return body.length;
    }

    /** Length of the replacement body in bytes. */
    @Override
    public long getContentLengthLong() {
        return body.length;
    }

    /** Blocking {@link ServletInputStream} over an in-memory byte array. */
    private static final class ByteArrayServletInputStream extends ServletInputStream {

        private final ByteArrayInputStream delegate;

        ByteArrayServletInputStream(byte[] data) {
            this.delegate = new ByteArrayInputStream(data);
        }

        @Override
        public int read() {
            return delegate.read();
        }

        @Override
        public int read(byte[] b, int off, int len) {
            return delegate.read(b, off, len);
        }

        @Override
        public int available() {
            return delegate.available();
        }

        @Override
        public boolean isFinished() {
            return delegate.available() == 0;
        }

        @Override
        public boolean isReady() {
            // All data is already in memory, so a read never blocks.
            return true;
        }

        @Override
        public void setReadListener(ReadListener readListener) {
            // Non-blocking (async) reads are not supported; consumers read synchronously.
        }
    }
}

Filter 执行顺序

可以通过 @Order(0) 注解修改 filter 的执行顺序。

未设置时默认的 order 值为 Integer.MAX_VALUE (即 2147483647)。

order 值的取值范围为 Integer.MIN_VALUE 到 Integer.MAX_VALUE。

order 的值越小,优先级越高,执行顺序越靠前。

order 的值相同时,这些 filter 之间的先后顺序没有明确保证。从实际遇到的情况来看,顺序会在应用启动时确定(通常取决于 Bean 的注册顺序),之后在本次运行期间保持不变。因此,对先后顺序有要求的 filter 应显式设置不同的 order 值。