实现一个自定义的protoc插件
October 1, 2023
我们使用protobuf+grpc技术栈来开发微服务时,会要使用相关protoc插件来生成相关代码。有时可能会需要自定义一些插件,本文就来实现一个自定义的protoc插件。
新旧接口的说明 #
以前开发protoc插件时,需要实现generator接口(github.com/golang/protobuf/protoc-gen-go/generator),现在网上有不少稍老一些资料也是这样介绍的。但是实际上,这个接口已经被废弃了,现在要开发插件,应该使用的是"google.golang.org/protobuf/compiler/protogen"包。我们开发插件时,必须要开发一个如下签名的函数:
func(*Plugin) error
代码实现 #
废话不多说,这里直接上代码main.go
package main
import (
"google.golang.org/protobuf/compiler/protogen"
"strconv"
"strings"
)
func myPlugin(p *protogen.Plugin) error {
// 插件的代码在这里实现
for _, f := range p.Files {
if !f.Generate {
continue
}
generateFile(p, f)
}
return nil
}
func generateFile(p *protogen.Plugin, f *protogen.File) {
g := p.NewGeneratedFile(f.GeneratedFilenamePrefix+".pb.myplugin.go", f.GoImportPath)
g.P("// Code generated by protoc-gen-myplugin. DO NOT EDIT.")
g.P()
g.P("package ", f.GoPackageName)
g.P()
g.P("import (")
g.P(" \"errors\"")
g.P(" \"strings\"")
g.P(" \"strconv\"")
g.P(" \"regexp\"")
g.P(")")
g.P()
g.P(validatorTpl)
for _, service := range f.Services {
for _, method := range service.Methods {
// 生成一个validator,对方法入参中的每个字段进行校验
g.P("// Validate 参数校验")
g.P("func (x *", method.Input.GoIdent.GoName, ") Validate() error {")
for i, field := range method.Input.Fields {
g.P("validator" + strconv.Itoa(i) + " := NewValidator(`" + strings.TrimSpace(field.Comments.Leading.String()) + "`)")
g.P("if err := validator" + strconv.Itoa(i) + ".Validate(x." + field.GoName + ");err!= nil {")
g.P(" return err")
g.P("}")
g.P()
}
g.P(" return nil")
g.P("}")
g.P()
}
}
g.P()
}
const validatorTpl = `type Validator struct {
fieldDesc string // 字段描述,用于提示
fieldLengthLt int // 字段长度最小值
fieldLengthGt int // 字段长度最大值
fieldType string // 特殊字段类型,如mobile,使用内置的方法进行校验
}
// NewValidator 初始化
func NewValidator(fieldComment string) *Validator {
v := &Validator{}
// 过滤前缀和后面的空格
fieldComment = strings.TrimPrefix(fieldComment, "// ")
fieldComment = strings.TrimSpace(fieldComment)
fields := strings.Split(fieldComment, " ")
// 如果字段描述中包含了must,则表示该字段必填
for _, field := range fields {
columns := strings.Split(field, ":")
if len(columns) != 2 {
continue
}
switch columns[0] {
case "type":
v.fieldType = columns[1]
case "desc":
v.fieldDesc = columns[1]
case "length":
lengths := strings.Split(columns[1], "-")
if len(lengths) != 2 {
continue
}
lt, _ := strconv.ParseInt(lengths[0], 10, 64)
gt, _ := strconv.ParseInt(lengths[1], 10, 64)
v.fieldLengthGt = int(gt)
v.fieldLengthLt = int(lt)
}
if columns[0] == "desc" {
v.fieldDesc = columns[1]
continue
}
}
return v
}
// Validate 校验
func (v *Validator) Validate(fieldValue interface{}) error {
// 判断字段类型
switch fieldValue.(type) {
case string:
if v.fieldType == "mobile" {
return v.validateMobile(fieldValue.(string))
}
return v.validateStringLength(fieldValue.(string))
case int64, uint64, int32, uint32:
// todo
}
return nil
}
func (v *Validator) validateStringLength(fieldValue string) error {
if len([]rune(fieldValue)) > v.fieldLengthGt {
return errors.New(v.fieldDesc + "长度超出最大值-" + fieldValue)
}
if len([]rune(fieldValue)) < v.fieldLengthLt {
return errors.New(v.fieldDesc + "长度低于最小值-" + fieldValue)
}
return nil
}
const mobileReg = "^1[3456789]\\d{9}$"
func (v *Validator) validateMobile(fieldValue string) error {
regM := mobileReg
pattern := regexp.MustCompile(regM)
if !pattern.MatchString(fieldValue) {
return errors.New(v.fieldDesc + "格式不正确")
}
return nil
}
`
func main() {
protogen.Options{}.Run(myPlugin)
}
- type: 字段类型,这里指特殊类型,如mobile(手机号),使用内置的方法进行校验
- desc: 字段描述,用于提示 ,如 {desc} 不能为空
- length: 字段长度,如 1-10,表示字段长度在1-10之间
可以看到,我们定义了一个myPlugin函数(注意函数类型),并在main方法中将其传入protogen.Options{}.Run方法中。
编译插件 #
为了方便,我们写一个简单的Makefile文件来完成编译,安装操作:
install:
go build -o protoc-gen-myplugin ./
mv protoc-gen-myplugin /Users/gq/go/bin/
执行make install命令,将会生成protoc-gen-myplugin文件,并将其移动到$GOPATH/bin目录下。
使用插件 #
我们再另外新建一个项目bufdemo,来测试一下生成的插件。 注意:我们这里使用buf来调用相关插件,并生成go代码。
编写proto文件 #
新建一个user.proto文件user.proto
syntax = "proto3";
package pb;
option go_package = "bufdemo/pb";
service UserService {
// 添加用户
rpc CreateUser (CreateUserRequest) returns (CreateUserResponse) {}
}
message CreateUserRequest {
// desc:姓名 length:2-20
string name = 1;
// desc:手机号码 type:mobile
string mobile = 2;
}
message CreateUserResponse {
uint32 id = 2;
}
修改buf.gen.yml
#
version: v1
plugins:
- plugin: go
out: pb
opt:
- paths=source_relative
- plugin: go-grpc
out: pb
opt:
- paths=source_relative
- plugin: myplugin
out: pb
opt:
- paths=source_relative
生成代码 #
执行buf generate命令,将会生成相关代码:
文件内容为:user.pb.myplugin.go
// Code generated by protoc-gen-myplugin. DO NOT EDIT.
package pb
import (
"errors"
"strings"
"strconv"
"regexp"
)
type Validator struct {
fieldDesc string // 字段描述,用于提示
fieldLengthLt int // 字段长度最小值
fieldLengthGt int // 字段长度最大值
fieldType string // 特殊字段类型,如mobile,使用内置的方法进行校验
}
// NewValidator 初始化
func NewValidator(fieldComment string) *Validator {
v := &Validator{}
// 过滤前缀和后面的空格
fieldComment = strings.TrimPrefix(fieldComment, "// ")
fieldComment = strings.TrimSpace(fieldComment)
fields := strings.Split(fieldComment, " ")
// 如果字段描述中包含了must,则表示该字段必填
for _, field := range fields {
columns := strings.Split(field, ":")
if len(columns) != 2 {
continue
}
switch columns[0] {
case "type":
v.fieldType = columns[1]
case "desc":
v.fieldDesc = columns[1]
case "length":
lengths := strings.Split(columns[1], "-")
if len(lengths) != 2 {
continue
}
lt, _ := strconv.ParseInt(lengths[0], 10, 64)
gt, _ := strconv.ParseInt(lengths[1], 10, 64)
v.fieldLengthGt = int(gt)
v.fieldLengthLt = int(lt)
}
if columns[0] == "desc" {
v.fieldDesc = columns[1]
continue
}
}
return v
}
// Validate 校验
func (v *Validator) Validate(fieldValue interface{}) error {
// 判断字段类型
switch fieldValue.(type) {
case string:
if v.fieldType == "mobile" {
return v.validateMobile(fieldValue.(string))
}
return v.validateStringLength(fieldValue.(string))
case int64, uint64, int32, uint32:
// todo
}
return nil
}
func (v *Validator) validateStringLength(fieldValue string) error {
if len([]rune(fieldValue)) > v.fieldLengthGt {
return errors.New(v.fieldDesc + "长度超出最大值-" + fieldValue)
}
if len([]rune(fieldValue)) < v.fieldLengthLt {
return errors.New(v.fieldDesc + "长度低于最小值-" + fieldValue)
}
return nil
}
const mobileReg = "^1[3456789]\\d{9}$"
func (v *Validator) validateMobile(fieldValue string) error {
regM := mobileReg
pattern := regexp.MustCompile(regM)
if !pattern.MatchString(fieldValue) {
return errors.New(v.fieldDesc + "格式不正确")
}
return nil
}
// Validate 参数校验
func (x *CreateUserRequest) Validate() error {
validator0 := NewValidator(`// desc:姓名 length:2-20`)
if err := validator0.Validate(x.Name); err != nil {
return err
}
validator1 := NewValidator(`// desc:手机号码 type:mobile`)
if err := validator1.Validate(x.Mobile); err != nil {
return err
}
return nil
}
编写服务端文件 #
新建server目录,并在server目录下新建一个server.go文件,内容如下:server.go
package main
import (
"bufdemo/pb"
"context"
"google.golang.org/genproto/googleapis/rpc/code"
"google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"log"
"net"
)
type UserServiceImpl struct {
pb.UnimplementedUserServiceServer
}
func (u *UserServiceImpl) CreateUser(ctx context.Context, request *pb.CreateUserRequest) (*pb.CreateUserResponse, error) {
resp := &pb.CreateUserResponse{}
if err := request.Validate(); err != nil {
return nil, status.Error(codes.Code(code.Code_INVALID_ARGUMENT), err.Error())
}
resp.Id = 1
return resp, nil
}
func main() {
lis, err := net.Listen("tcp", ":8091")
if err != nil {
log.Fatalf("failed to listen:%v", err)
}
s := grpc.NewServer()
pb.RegisterUserServiceServer(s, &UserServiceImpl{})
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve:%v", err)
}
}
运行服务端:
go run server/server.go
使用postman测试 #
我们这里使用postman来测试一下:
- 测试1:姓名长度不正确:
测试1

- 测试2:手机号码格式不正确:
测试2

可以看到,我们的校验器已经生效了。