Skip to content
微信扫码关注公众号

Spring Boot 中使用自定义 Validator 校验两个参数的大小关系

  1. 创建一个自定义的注解 @RangeCompare

    java
    import javax.validation.Constraint;
    import javax.validation.Payload;
    import java.lang.annotation.*;
    
    /**
     * Class-level constraint requiring that the bean property named by {@link #from()} is less
     * than or equal to the bean property named by {@link #to()}.
     * The check is performed by {@link RangeCompareValidator}.
     *
     * <p>The annotation is repeatable, so several ranges can be checked on the same class; the
     * repeated annotations are collected into {@link RangeCompare.List}.
     */
    @Target({ElementType.TYPE, ElementType.ANNOTATION_TYPE})
    @Retention(RetentionPolicy.RUNTIME)
    @Constraint(validatedBy = {RangeCompareValidator.class})
    @Documented
    @Repeatable(RangeCompare.List.class)
    public @interface RangeCompare {

        /** Violation message; the default reads "start value must be less than or equal to end value". */
        String message() default "起始值必须小于或等于结束值";

        /** Validation groups this constraint belongs to. */
        Class<?>[] groups() default {};

        /** Payload attached to the constraint. */
        Class<? extends Payload>[] payload() default {};

        /** Name of the bean property holding the start (lower) value. */
        String from();

        /** Name of the bean property holding the end (upper) value. */
        String to();

        /**
         * Container for several {@link RangeCompare} constraints on the same element. It can be
         * written explicitly, or the compiler creates it when {@code @RangeCompare} is repeated.
         */
        @Target({ElementType.TYPE, ElementType.ANNOTATION_TYPE})
        @Retention(RetentionPolicy.RUNTIME)
        @Documented
        public @interface List {
            RangeCompare[] value();
        }
    }

    其中最关键的是 @Constraint 注解：它通过 validatedBy 属性指定执行该约束校验逻辑的验证器类（这里是 RangeCompareValidator）。

    @Constraint 注解的源码:

    java
    /**
     * Marks an annotation type as a constraint annotation.
     *
     * <p>{@code validatedBy} names the {@link ConstraintValidator} classes that contain the
     * validation logic for the constraint. The target is {@code ANNOTATION_TYPE}, so this can
     * only be placed on annotation declarations.
     */
    @Documented
    @Target({ ANNOTATION_TYPE })
    @Retention(RUNTIME)
    public @interface Constraint {
        // Validator implementations that check the annotated constraint.
        Class<? extends ConstraintValidator<?, ?>>[] validatedBy();
    }
  2. 创建 @Constraint 注解中指定的约束验证器 RangeCompareValidator

    java
    import org.springframework.beans.BeanWrapper;
    import org.springframework.beans.BeanWrapperImpl;

    import javax.validation.ConstraintValidator;
    import javax.validation.ConstraintValidatorContext;
    import java.math.BigDecimal;
    import java.math.BigInteger;
    import java.time.LocalDate;
    import java.time.LocalDateTime;
    import java.time.LocalTime;
    
    /**
     * Validator for {@link RangeCompare}: checks that the property named by {@code from} is less
     * than or equal to the property named by {@code to} on the validated bean.
     *
     * <p>Supported property types are {@link Number} (any mix of subtypes), {@link LocalDate},
     * {@link LocalDateTime} and {@link LocalTime}. A {@code null} bean or a {@code null}
     * property value is treated as valid; presence checks are left to {@code @NotNull}.
     */
    public class RangeCompareValidator implements ConstraintValidator<RangeCompare, Object> {

        private String from;
        private String to;

        @Override
        public void initialize(RangeCompare constraint) {
            from = constraint.from();
            to = constraint.to();
        }

        /**
         * Compares the two configured properties of {@code value}.
         *
         * @return {@code true} if {@code value} or either property is {@code null}, or if the
         *     from-value is not greater than the to-value
         * @throws IllegalArgumentException if the two properties are not a supported pair of types
         */
        @Override
        public boolean isValid(Object value, ConstraintValidatorContext context) {
            // BeanWrapperImpl rejects a null target object, so short-circuit before wrapping.
            if (value == null) {
                return true;
            }

            BeanWrapper beanWrapper = new BeanWrapperImpl(value);
            Object fromValue = beanWrapper.getPropertyValue(from);
            Object toValue = beanWrapper.getPropertyValue(to);

            if (fromValue == null || toValue == null) {
                return true;
            }

            if (fromValue instanceof Number && toValue instanceof Number) {
                return isNumberInOrder((Number) fromValue, (Number) toValue);
            }

            if (fromValue instanceof LocalDate && toValue instanceof LocalDate) {
                return !((LocalDate) fromValue).isAfter((LocalDate) toValue);
            }

            if (fromValue instanceof LocalDateTime && toValue instanceof LocalDateTime) {
                return !((LocalDateTime) fromValue).isAfter((LocalDateTime) toValue);
            }

            if (fromValue instanceof LocalTime && toValue instanceof LocalTime) {
                return !((LocalTime) fromValue).isAfter((LocalTime) toValue);
            }

            throw new IllegalArgumentException("只支持数字或日期类型的比较");
        }

        /**
         * Returns whether {@code lower <= upper}. The comparison is exact for integral types,
         * BigInteger, BigDecimal and finite floating-point values. Comparing through
         * {@code doubleValue()} would lose precision for long values above 2^53 and for big numbers.
         */
        private static boolean isNumberInOrder(Number lower, Number upper) {
            BigDecimal exactLower = toExactDecimal(lower);
            BigDecimal exactUpper = toExactDecimal(upper);
            if (exactLower != null && exactUpper != null) {
                return exactLower.compareTo(exactUpper) <= 0;
            }
            // NaN, infinities or unknown Number subtypes: keep plain double semantics
            // (NaN is never in order; infinities order as expected).
            return lower.doubleValue() <= upper.doubleValue();
        }

        /** Converts a number to an exact BigDecimal, or returns null when that is not possible. */
        private static BigDecimal toExactDecimal(Number number) {
            if (number instanceof BigDecimal) {
                return (BigDecimal) number;
            }
            if (number instanceof BigInteger) {
                return new BigDecimal((BigInteger) number);
            }
            if (number instanceof Long || number instanceof Integer
                    || number instanceof Short || number instanceof Byte) {
                return BigDecimal.valueOf(number.longValue());
            }
            if (number instanceof Double || number instanceof Float) {
                double d = number.doubleValue();
                // valueOf uses Double.toString, which preserves the ordering of finite doubles.
                return Double.isFinite(d) ? BigDecimal.valueOf(d) : null;
            }
            return null;
        }
    }

    ConstraintValidator 接口定义如下:

    java
    package javax.validation;
    
    import java.lang.annotation.Annotation;
    
    import javax.validation.constraintvalidation.SupportedValidationTarget;
    
    /**
     * Contract for the logic that checks a constraint annotation {@code A} against a value of
     * type {@code T}.
     *
     * @param <A> the constraint annotation type handled by the validator
     * @param <T> the type of the value being validated
     */
    public interface ConstraintValidator<A extends Annotation, T> {

        // Receives the constraint annotation instance so its attributes can be read;
        // the default implementation does nothing.
        default void initialize(A constraintAnnotation) {}

        // Returns whether the given value satisfies the constraint.
        boolean isValid(T value, ConstraintValidatorContext context);
    }

    ConstraintValidator 接口有两个泛型参数：

    • A extends Annotation：该验证器所处理的约束注解类型（本例为 RangeCompare）
    • T：被验证值的类型（@RangeCompare 标注在类上，被验证的是整个参数对象，因此这里取 Object）
  3. 添加全局的异常处理器

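    使用 @Validated 校验 @RequestBody 参数失败时，Spring MVC 会抛出 MethodArgumentNotValidException，由下面标注了 @ExceptionHandler(MethodArgumentNotValidException.class) 的方法统一处理，这样就不必在每个接口中单独处理校验结果了。由于 @RangeCompare 是类级别约束，它产生的是全局错误（global error）而不是字段错误，所以处理器同时收集了 field errors 和 global errors。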

    java
    import lombok.extern.slf4j.Slf4j;
    import org.springframework.context.support.DefaultMessageSourceResolvable;
    import org.springframework.http.HttpStatus;
    import org.springframework.web.bind.MethodArgumentNotValidException;
    import org.springframework.web.bind.annotation.ControllerAdvice;
    import org.springframework.web.bind.annotation.ExceptionHandler;
    import org.springframework.web.bind.annotation.ResponseBody;
    import org.springframework.web.bind.annotation.ResponseStatus;
    
    import java.util.stream.Collectors;
    import java.util.stream.Stream;
    
    /**
     * Central handler for request-body validation failures. It turns every binding error into
     * a single comma-separated message and returns it as an error ResponseInfo with HTTP 200.
     */
    @ControllerAdvice
    @Slf4j
    public class GlobalExceptionAdvice {

        /**
         * Collects field errors ("field:message") followed by global/class-level errors
         * ("message"), logs them, and wraps them in an error response.
         */
        @ExceptionHandler(MethodArgumentNotValidException.class)
        @ResponseStatus(HttpStatus.OK)
        @ResponseBody
        public ResponseInfo<Object> handleMethodArgumentNotValidException(MethodArgumentNotValidException ex) {
            // Per-field violations, prefixed with the offending field name.
            Stream<String> fieldMessages = ex.getBindingResult()
                    .getFieldErrors()
                    .stream()
                    .map(fieldError -> fieldError.getField() + ":" + fieldError.getDefaultMessage());

            // Object-level violations, such as class-level constraints like @RangeCompare.
            Stream<String> globalMessages = ex.getBindingResult()
                    .getGlobalErrors()
                    .stream()
                    .map(DefaultMessageSourceResolvable::getDefaultMessage);

            String errors = Stream.concat(fieldMessages, globalMessages)
                    .collect(Collectors.joining(","));

            log.warn("参数错误[{}]", errors);
            return ResponseInfo.buildError(errors);
        }
    }
  4. 参数类上添加 @RangeCompare 注解

    java
    import io.swagger.annotations.ApiModelProperty;
    import lombok.Data;
    
    import java.time.LocalDate;
    
    /**
     * Request parameter carrying a date range. The class-level {@link RangeCompare} constraint
     * rejects a {@code fromDate} that falls after {@code toDate}; equal dates are accepted.
     */
    @Data
    @RangeCompare(from = "fromDate", to = "toDate", message = "起始日期必须小于或等于结束日期")
    public class TimeRangeParam {

        /** Start date of the range. */
        @ApiModelProperty(value = "起始日期")
        private LocalDate fromDate;

        /** End date of the range. */
        @ApiModelProperty(value = "结束日期")
        private LocalDate toDate;
    }
  5. 接口参数上添加 @Validated 注解

    java
    // Example endpoint: @Validated on the @RequestBody parameter triggers validation of
    // TimeRangeParam, including its class-level @RangeCompare check.
    @PostMapping(value = "search")
    public ResponseInfo<Object> search(@Validated @RequestBody TimeRangeParam req) {
        // do something
        // NOTE(review): placeholder body; a ResponseInfo<Object> must be returned for this to compile.
    }