实现一个自定义的protoc插件

实现一个自定义的protoc插件

October 1, 2023
微服务
Go, protobuf, grpc

我们使用protobuf+grpc技术栈来开发微服务时,会要使用相关protoc插件来生成相关代码。有时可能会需要自定义一些插件,本文就来实现一个自定义的protoc插件。

新旧接口的说明 #

以前开发protoc插件时,需要实现generator接口(github.com/golang/protobuf/protoc-gen-go/generator),现在网上有不少稍老一些资料也是这样介绍的。但是实际上,这个接口已经被废弃了,现在要开发插件,应该使用的是"google.golang.org/protobuf/compiler/protogen"包。我们开发插件时,必须要开发一个如下签名的函数:

func(*Plugin) error

代码实现 #

废话不多说,这里直接上代码

main.go
package main

import (
	"google.golang.org/protobuf/compiler/protogen"
	"strconv"
	"strings"
)

func myPlugin(p *protogen.Plugin) error {
	// 插件的代码在这里实现
	for _, f := range p.Files {
		if !f.Generate {
			continue
		}

		generateFile(p, f)
	}
	return nil
}

func generateFile(p *protogen.Plugin, f *protogen.File) {
	g := p.NewGeneratedFile(f.GeneratedFilenamePrefix+".pb.myplugin.go", f.GoImportPath)
	g.P("// Code generated by protoc-gen-myplugin. DO NOT EDIT.")
	g.P()
	g.P("package ", f.GoPackageName)

	g.P()

	g.P("import (")
	g.P(" \"errors\"")
	g.P(" \"strings\"")
	g.P(" \"strconv\"")
	g.P(" \"regexp\"")
	g.P(")")
	g.P()

	g.P(validatorTpl)

	for _, service := range f.Services {
		for _, method := range service.Methods {
			// 生成一个validator,对方法入参中的每个字段进行校验
			g.P("// Validate 参数校验")
			g.P("func (x *", method.Input.GoIdent.GoName, ") Validate() error {")

			for i, field := range method.Input.Fields {
				g.P("validator" + strconv.Itoa(i) + " := NewValidator(`" + strings.TrimSpace(field.Comments.Leading.String()) + "`)")

				g.P("if err := validator" + strconv.Itoa(i) + ".Validate(x." + field.GoName + ");err!= nil {")
				g.P("  return err")
				g.P("}")
				g.P()
			}
			g.P("  return nil")

			g.P("}")
			g.P()
		}
	}

	g.P()
}

const validatorTpl = `type Validator struct {
	fieldDesc     string // 字段描述,用于提示
	fieldLengthLt int    // 字段长度最小值
	fieldLengthGt int    // 字段长度最大值
	fieldType 	  string // 特殊字段类型,如mobile,使用内置的方法进行校验
}

// NewValidator 初始化
func NewValidator(fieldComment string) *Validator {
	v := &Validator{}
	// 过滤前缀和后面的空格
	fieldComment = strings.TrimPrefix(fieldComment, "// ")
	fieldComment = strings.TrimSpace(fieldComment)

	fields := strings.Split(fieldComment, " ")

	// 如果字段描述中包含了must,则表示该字段必填
	for _, field := range fields {
		columns := strings.Split(field, ":")

		if len(columns) != 2 {
			continue
		}

		switch columns[0] {
		case "type":
			v.fieldType = columns[1]
		case "desc":
			v.fieldDesc = columns[1]
		case "length":
			lengths := strings.Split(columns[1], "-")
			if len(lengths) != 2 {
				continue
			}

			lt, _ := strconv.ParseInt(lengths[0], 10, 64)
			gt, _ := strconv.ParseInt(lengths[1], 10, 64)
			v.fieldLengthGt = int(gt)
			v.fieldLengthLt = int(lt)
		}

		if columns[0] == "desc" {
			v.fieldDesc = columns[1]
			continue
		}
	}

	return v
}

// Validate 校验
func (v *Validator) Validate(fieldValue interface{}) error {
	// 判断字段类型
	switch fieldValue.(type) {
	case string:
		if v.fieldType == "mobile" {
			return v.validateMobile(fieldValue.(string))
		}
		return v.validateStringLength(fieldValue.(string))
	case int64, uint64, int32, uint32:
		// todo
	}

	return nil
}

func (v *Validator) validateStringLength(fieldValue string) error {

	if len([]rune(fieldValue)) > v.fieldLengthGt {
		return errors.New(v.fieldDesc + "长度超出最大值-" + fieldValue)
	}

	if len([]rune(fieldValue)) < v.fieldLengthLt {
		return errors.New(v.fieldDesc + "长度低于最小值-" + fieldValue)
	}

	return nil
}

const mobileReg = "^1[3456789]\\d{9}$"

func (v *Validator) validateMobile(fieldValue string) error {
	regM := mobileReg
	pattern := regexp.MustCompile(regM)
	if !pattern.MatchString(fieldValue) {
		return errors.New(v.fieldDesc + "格式不正确")
	}

	return nil
}
`

func main() {
	protogen.Options{}.Run(myPlugin)
}
在上面,我们实现了一个简单的校验器,可以对方法入参中的每个字段进行校验。可以在字段的注释中添加校验内容:

  • type: 字段类型,这里指特殊类型,如mobile(手机号),使用内置的方法进行校验
  • desc: 字段描述,用于提示 ,如 {desc} 不能为空
  • length: 字段长度,如 1-10,表示字段长度在1-10之间

可以看到,我们定义了一个myPlugin函数(注意函数类型),并在main方法中将其传入protogen.Options{}.Run方法中。

编译插件 #

为了方便,我们写一个简单的Makefile文件来完成编译,安装操作:

install:
	go build -o protoc-gen-myplugin ./
	mv protoc-gen-myplugin /Users/gq/go/bin/

执行make install命令,将会生成protoc-gen-myplugin文件,并将其移动到$GOPATH/bin目录下。

使用插件 #

我们再另外新建一个项目bufdemo,来测试一下生成的插件。 注意:我们这里使用buf来调用相关插件,并生成go代码。

编写proto文件 #

新建一个user.proto文件

user.proto
syntax = "proto3";


package pb;

option go_package = "bufdemo/pb";

service UserService {
  // 添加用户
  rpc CreateUser (CreateUserRequest) returns (CreateUserResponse) {}
}


message CreateUserRequest {
  // desc:姓名 length:2-20
  string name = 1;
  // desc:手机号码 type:mobile
  string mobile = 2;
}

message CreateUserResponse {
  uint32 id = 2;
}

修改buf.gen.yml #

version: v1
plugins:
  - plugin: go
    out: pb
    opt:
      - paths=source_relative
  - plugin: go-grpc
    out: pb
    opt:
      - paths=source_relative
  - plugin: myplugin
    out: pb
    opt:
      - paths=source_relative

生成代码 #

执行buf generate命令,将会生成相关代码: 生成文件 文件内容为:

user.pb.myplugin.go
// Code generated by protoc-gen-myplugin. DO NOT EDIT.

package pb

import (
	"errors"
	"strings"
	"strconv"
	"regexp"
)

type Validator struct {
	fieldDesc     string // 字段描述,用于提示
	fieldLengthLt int    // 字段长度最小值
	fieldLengthGt int    // 字段长度最大值
	fieldType     string // 特殊字段类型,如mobile,使用内置的方法进行校验
}

// NewValidator 初始化
func NewValidator(fieldComment string) *Validator {
	v := &Validator{}
	// 过滤前缀和后面的空格
	fieldComment = strings.TrimPrefix(fieldComment, "// ")
	fieldComment = strings.TrimSpace(fieldComment)

	fields := strings.Split(fieldComment, " ")

	// 如果字段描述中包含了must,则表示该字段必填
	for _, field := range fields {
		columns := strings.Split(field, ":")

		if len(columns) != 2 {
			continue
		}

		switch columns[0] {
		case "type":
			v.fieldType = columns[1]
		case "desc":
			v.fieldDesc = columns[1]
		case "length":
			lengths := strings.Split(columns[1], "-")
			if len(lengths) != 2 {
				continue
			}

			lt, _ := strconv.ParseInt(lengths[0], 10, 64)
			gt, _ := strconv.ParseInt(lengths[1], 10, 64)
			v.fieldLengthGt = int(gt)
			v.fieldLengthLt = int(lt)
		}

		if columns[0] == "desc" {
			v.fieldDesc = columns[1]
			continue
		}
	}

	return v
}

// Validate 校验
func (v *Validator) Validate(fieldValue interface{}) error {
	// 判断字段类型
	switch fieldValue.(type) {
	case string:
		if v.fieldType == "mobile" {
			return v.validateMobile(fieldValue.(string))
		}
		return v.validateStringLength(fieldValue.(string))
	case int64, uint64, int32, uint32:
		// todo
	}

	return nil
}

func (v *Validator) validateStringLength(fieldValue string) error {

	if len([]rune(fieldValue)) > v.fieldLengthGt {
		return errors.New(v.fieldDesc + "长度超出最大值-" + fieldValue)
	}

	if len([]rune(fieldValue)) < v.fieldLengthLt {
		return errors.New(v.fieldDesc + "长度低于最小值-" + fieldValue)
	}

	return nil
}

const mobileReg = "^1[3456789]\\d{9}$"

func (v *Validator) validateMobile(fieldValue string) error {
	regM := mobileReg
	pattern := regexp.MustCompile(regM)
	if !pattern.MatchString(fieldValue) {
		return errors.New(v.fieldDesc + "格式不正确")
	}

	return nil
}

// Validate 参数校验
func (x *CreateUserRequest) Validate() error {
	validator0 := NewValidator(`// desc:姓名 length:2-20`)
	if err := validator0.Validate(x.Name); err != nil {
		return err
	}

	validator1 := NewValidator(`// desc:手机号码 type:mobile`)
	if err := validator1.Validate(x.Mobile); err != nil {
		return err
	}

	return nil
}

编写服务端文件 #

新建server目录,并在server目录下新建一个server.go文件,内容如下:

server.go
package main

import (
	"bufdemo/pb"
	"context"
	"google.golang.org/genproto/googleapis/rpc/code"
	"google.golang.org/grpc"
	codes "google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"log"
	"net"
)

type UserServiceImpl struct {
	pb.UnimplementedUserServiceServer
}

func (u *UserServiceImpl) CreateUser(ctx context.Context, request *pb.CreateUserRequest) (*pb.CreateUserResponse, error) {
	resp := &pb.CreateUserResponse{}

	if err := request.Validate(); err != nil {
		return nil, status.Error(codes.Code(code.Code_INVALID_ARGUMENT), err.Error())
	}

	resp.Id = 1
	return resp, nil
}

func main() {
	lis, err := net.Listen("tcp", ":8091")

	if err != nil {
		log.Fatalf("failed to listen:%v", err)
	}

	s := grpc.NewServer()

	pb.RegisterUserServiceServer(s, &UserServiceImpl{})

	if err = s.Serve(lis); err != nil {
		log.Fatalf("failed to serve:%v", err)
	}
}
在上面的方法中,我们调用了Validate方法,对入参进行了校验。

运行服务端:

go run server/server.go

使用postman测试 #

我们这里使用postman来测试一下:

  • 测试1:姓名长度不正确:
    测试1
    测试1
  • 测试2:手机号码格式不正确:
    测试2
    测试2

可以看到,我们的校验器已经生效了。