C++ Protobuf实现接口参数自动校验详解

2023-05-16 20:05:43 接口 校验 详解

1、背景

c++做业务发开的同学是否还在不厌其烦的编写大量if-else模块来做接口参数校验呢?当接口字段数量多大几十个,这样的参数校验代码都能多达上百行,甚至超过了接口业务逻辑的代码体量,而且随着业务迭代,接口增加了新的字段,又不得不再加几个if-else,对于有Java、python开发经历的同学,对这种原始的参数校验方法必定是嗤之以鼻。今天,我们就模拟Java里面通过注解实现参数校验的方式来针对C++ protobuf接口实现一个更加方便、快捷的参数校验自动工具

2、方案简介

实现基本思路主要用到两个核心技术点:protobuf字段属性扩展和反射机制。

首先针对常用的协议字段数据类型(int32、int64、uint32、uint64、float、double、string、array、enum)定义了一套最常用的字段校验规则,如下表:

每个校验规则的protobuf定义如下:

// int32类型校验规则
message Int32Rule {
    oneof lt_rule {
        int32 lt = 1;
    }
    oneof lte_rule {
        int32 lte = 2;
    }
    oneof gt_rule {
        int32 gt = 3;
    }
    oneof gte_rule {
        int32 gte = 4;
    }
    repeated int32 in = 5;
    repeated int32 not_in = 6;
}

// int64类型校验规则
message Int64Rule {
    oneof lt_rule {
        int64 lt = 1;
    }
    oneof lte_rule {
        int64 lte = 2;
    }
    oneof gt_rule {
        int64 gt = 3;
    }
    oneof gte_rule {
        int64 gte = 4;
    }
    repeated int64 in = 5;
    repeated int64 not_in = 6;
}

// uint32类型校验规则
message UInt32Rule {
    oneof lt_rule {
        uint32 lt = 1;
    }
    oneof lte_rule {
        uint32 lte = 2;
    }
    oneof gt_rule {
        uint32 gt = 3;
    }
    oneof gte_rule {
        uint32 gte = 4;
    }
    repeated uint32 in = 5;
    repeated uint32 not_in = 6;
}

// uint64类型校验规则
message UInt64Rule {
    oneof lt_rule {
        uint64 lt = 1;
    }
    oneof lte_rule {
        uint64 lte = 2;
    }
    oneof gt_rule {
        uint64 gt = 3;
    }
    oneof gte_rule {
        uint64 gte = 4;
    }
    repeated uint64 in = 5;
    repeated uint64 not_in = 6;
}

// float类型校验规则
message FloatRule {
    oneof lt_rule {
        float lt = 1;
    }
    oneof lte_rule {
        float lte = 2;
    }
    oneof gt_rule {
        float gt = 3;
    }
    oneof gte_rule {
        float gte = 4;
    }
    repeated float in = 5;
    repeated float not_in = 6;
}

// double类型校验规则
message DoubleRule {
    oneof lt_rule {
        double lt = 1;
    }
    oneof lte_rule {
        double lte = 2;
    }
    oneof gt_rule {
        double gt = 3;
    }
    oneof gte_rule {
        double gte = 4;
    }
    repeated double in = 5;
    repeated double not_in = 6;
}

// string类型校验规则
message StringRule {
    bool not_empty = 1;
    oneof min_len_rule {
        uint32 min_len = 2;
    }
    oneof max_len_rule {
        uint32 max_len = 3;
    }
    string regex_pattern = 4;
}

// enum类型校验规则
message EnumRule {
    repeated int32 in = 1;
}

// array(数组)类型校验规则
message ArrayRule {
    bool not_empty = 1;
    oneof min_len_rule {
        uint32 min_len = 2;
    }
    oneof max_len_rule {
        uint32 max_len = 3;
    }
}

注意:校验规则中一些字段通过oneof关键字包装了一层,主要是因为protobuf3中全部字段都默认是optional的,即即使不显示设置其值,protobuf也会给它一个默认值,如数值类型的一般默认值就是0,这样当某个规则的值(如lt)为0的时候,我们无法确定是没有设置值还是就是设置的0,加了oneof后可以通过oneof字段的xxx_case方法来判断对应值是否有人为设定。

上述规则被划分为4大类:数值类规则(Int32Rule、Int64Rule、UInt32Rule、UInt64Rule、FloatRule、DoubleRule)、字符串类规则(StringRule)、枚举类规则(EnumRule)、数组类规则(ArrayRule), 每一类后续都会有一个对应的校验器(参数校验算法)。

然后,拓展protobuf字段属性(Google.protobuf.FieldOptions),将字段校验规则拓展为字段属性之一。如下图:扩展字段属性名为Rule, 其类型为ValidateRules,其具体校验规则通过oneof关键字限定至多为上述9种校验规则之一(针对某一个字段,其类型唯一,从而其校验规则也是确定的)。

// 校验规则(oneof取上述字段类型校验规则之一)
message ValidateRules {
    oneof rule {
        
        Int32Rule int32  = 1;
        Int64Rule int64  = 2;
        UInt32Rule uint32  = 3;
        UInt64Rule uint64  = 4;
        FloatRule float = 5;
        DoubleRule double = 6;
        StringRule string = 7;


        
        EnumRule enum = 8;
        ArrayRule array = 9;
    }
}

// 拓展默认字段属性, 将ValidateRules设置为字段属性
extend google.protobuf.FieldOptions {
    ValidateRules Rule = 10000;
}

上述校验规则和字段属性扩展定义在validator.proto文件中,使用时通过import导入该proto文件便可以使用上述扩展字段属性用于定义字段,如:

说明: 上述接口定义中,通过扩展字段属性validator.Rule(其内容为上述定义9中类型校验规则之一)限制了用户年龄age字段值必须小于等于(lte)150;名字name字段不能为空且长度不能大于32;手机号字段phone不能为空且必须满足指定的手机号正则表达式规则;邮件字段允许为空(默认)但如果有传入值的话则必须满足对应邮件正则表达式规则;others数组字段不允许为空,且长度不小于2。

有了上述接口字段定义后,需要校验的字段都已经带上了validator.Rule属性,其中已包含了对应字段的校验规则,接下来需要实现一个参数自动校验算法, 基本思路就是通过反射逐个获取待校验Message结构体中各个字段值及其字段属性中校验规则validator.Rule,然后逐一匹配字段值是否满足每一项规则定义,不满足则返回FALSE;对于嵌套结构体类型则做递归校验,算法流程及实现如下:

#pragma once

#include <google/protobuf/message.h>
#include <butil/logging.h>
#include <regex>
#include <algorithm>
#include <sstream>
#include "proto/validator.pb.h"

namespace validator {

using namespace google::protobuf;


typedef float float_;
typedef double double_;


#define NumericalValidator(pb_cpptype, method_type, value_type)                                    \
    case google::protobuf::FieldDescriptor::CPPTYPE_##pb_cpptype: {                                \
        if (validate_rules.has_##value_type()) {                                                   \
            const method_type##Rule& rule = validate_rules.value_type();                           \
            value_type value              = reflection->Get##method_type(message, field);          \
            if ((rule.lt_rule_case() && value >= rule.lt()) ||                                     \
                (rule.lte_rule_case() && value > rule.lte()) ||                                    \
                (rule.gt_rule_case() && value <= rule.gt()) ||                                     \
                (rule.gte_rule_case() && value < rule.gte())) {                                    \
                std::ostringstream os;                                                             \
                os << field->full_name() << " value out of range.";                                \
                return {false, os.str()};                                                          \
            }                                                                                      \
            if ((!rule.in().empty() &&                                                             \
                 std::find(rule.in().begin(), rule.in().end(), value) == rule.in().end()) ||       \
                (!rule.not_in().empty() &&                                                         \
                 std::find(rule.not_in().begin(), rule.not_in().end(), value) !=                   \
                     rule.not_in().end())) {                                                       \
                std::ostringstream os;                                                             \
                os << field->full_name() << " value not allowed.";                                 \
                return {false, os.str()};                                                          \
            }                                                                                      \
        }                                                                                          \
        break;                                                                                     \
    }


#define StringValidator(pb_cpptype, method_type, value_type)                                       \
    case google::protobuf::FieldDescriptor::CPPTYPE_##pb_cpptype: {                                \
        if (validate_rules.has_##value_type()) {                                                   \
            const method_type##Rule& rule = validate_rules.value_type();                           \
            const value_type& value       = reflection->Get##method_type(message, field);          \
            if (rule.not_empty() && value.empty()) {                                               \
                std::ostringstream os;                                                             \
                os << field->full_name() << " can not be empty.";                                  \
                return {false, os.str()};                                                          \
            }                                                                                      \
            if ((rule.min_len_rule_case() && value.length() < rule.min_len()) ||                   \
                (rule.max_len_rule_case() && value.length() > rule.max_len())) {                   \
                std::ostringstream os;                                                             \
                os << field->full_name() << " length out of range.";                               \
                return {false, os.str()};                                                          \
            }                                                                                      \
            if (!value.empty() && !rule.regex_pattern().empty()) {                                 \
                std::regex ex(rule.regex_pattern());                                               \
                if (!regex_match(value, ex)) {                                                     \
                    std::ostringstream os;                                                         \
                    os << field->full_name() << " fORMat invalid.";                                \
                    return {false, os.str()};                                                      \
                }                                                                                  \
            }                                                                                      \
        }                                                                                          \
        break;                                                                                     \
    }


#define EnumValidator(pb_cpptype, method_type, value_type)                                          \
    case google::protobuf::FieldDescriptor::CPPTYPE_##pb_cpptype: {                                 \
        if (validate_rules.has_##value_type()) {                                                    \
            const method_type##Rule& rule = validate_rules.value_type();                            \
            int value                     = reflection->Get##method_type(message, field)->number(); \
            if (!rule.in().empty() &&                                                               \
                std::find(rule.in().begin(), rule.in().end(), value) == rule.in().end()) {          \
                std::ostringstream os;                                                              \
                os << field->full_name() << " value not allowed.";                                  \
                return {false, os.str()};                                                           \
            }                                                                                       \
        }                                                                                           \
        break;                                                                                      \
    }


#define ArrayValidator()                                                                           \
    uint32 arr_len = (uint32)reflection->FieldSize(message, field);                                \
    if (validate_rules.has_array()) {                                                              \
        const ArrayRule& rule = validate_rules.array();                                            \
        if (rule.not_empty() && arr_len == 0) {                                                    \
            std::ostringstream os;                                                                 \
            os << field->full_name() << " can not be empty.";                                      \
            return {false, os.str()};                                                              \
        }                                                                                          \
        if ((rule.min_len() != 0 && arr_len < rule.min_len()) ||                                   \
            (rule.max_len() != 0 && arr_len > rule.max_len())) {                                   \
            std::ostringstream os;                                                                 \
            os << field->full_name() << " length out of range.";                                   \
            return {false, os.str()};                                                              \
        }                                                                                          \
    }                                                                                              \
                                                                                                   \
                       \
    if (field_type == FieldDescriptor::CPPTYPE_MESSAGE) {                                          \
        for (uint32 i = 0; i < arr_len; i++) {                                                     \
            const Message& sub_message = reflection->GetRepeatedMessage(message, field, i);        \
            ValidateResult&& result    = Validate(sub_message);                                    \
            if (!result.is_valid) {                                                                \
                return result;                                                                     \
            }                                                                                      \
        }                                                                                          \
    }


#define MessageValidator()                                                                         \
    case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: {                                     \
        const Message& sub_message = reflection->GetMessage(message, field);                       \
        ValidateResult&& result    = Validate(sub_message);                                        \
        if (!result.is_valid) {                                                                    \
            return result;                                                                         \
        }                                                                                          \
        break;                                                                                     \
    }
    
class ValidatorUtil {
public:
    struct ValidateResult {
        bool is_valid;
        std::string msg;
    };

    static ValidateResult Validate(const Message& message) {
        const Descriptor* descriptor = message.GetDescriptor();
        const Reflection* reflection = message.GetReflection();

        for (int i = 0; i < descriptor->field_count(); i++) {
            const FieldDescriptor* field        = descriptor->field(i);
            FieldDescriptor::CppType field_type = field->cpp_type();
            const ValidateRules& validate_rules = field->options().GetExtension(validator::Rule);

            if (field->is_repeated()) {
                // 数组类型校验
                ArrayValidator();
            } else {
                // 非数组类型,直接调用对应类型校验器
                switch (field_type) {
                    NumericalValidator(INT32, Int32, int32);
                    NumericalValidator(INT64, Int64, int64);
                    NumericalValidator(UINT32, UInt32, uint32);
                    NumericalValidator(UINT64, UInt64, uint64);
                    NumericalValidator(FLOAT, Float, float_);
                    NumericalValidator(DOUBLE, Double, double_);
                    StringValidator(STRING, String, string);
                    EnumValidator(ENUM, Enum, enum_);
		    MessageValidator();
                    default:
                        break;
                }
            }
        }
        return {true, ""};
    }
};

} // namespace validator

3、 使用

整个算法实现相当轻量,规则定义不到200行,算法实现(也即规则解析)不到200行。使用方法也非常简便,只需要在业务proto中import导入validator.proto即可以使用规则定义,然后在业务接口代码中include<validator_util.h>即可使用规则校验工具类对接口参数做自动校验, 以后接口参数校验只需要下面几行就行了(终于不用再写一大堆if_else了)如下:

4、测试

以上就是C++ Protobuf实现接口参数自动校验详解的详细内容,更多关于C++ Protobuf接口参数校验的资料请关注其它相关文章!

相关文章