拦截器包含服务端拦截器客户端拦截器

原理:

遍历 ops.interceptors数组,以.interceptors[0]为头,然后传入将数组后面的interceptors依次递归写到下一个handler

实现方式:

  1. option设置了数组连接器,可以向数组拦截器中添加自己的自定义拦截器!
  2. 但grpc使用拦截器,最终使用的是将数组拦截器串起来的链式拦截器。

实现方式:

● 判断拦截器数组长度
● 若大于1,将拦截器数组合并成一个链式拦截器
重点:本拦截器处理完,将handler设置成数组的下一个拦截器!

层层套娃的结构
如:
function1(){
function2(){
function3(){
function4(){}
}
}
}

  1. // chainUnaryClientInterceptors chains all unary client interceptors into one.
  2. func chainUnaryClientInterceptors(cc *ClientConn) {
  3. interceptors := cc.dopts.chainUnaryInts
  4. // Prepend dopts.unaryInt to the chaining interceptors if it exists, since unaryInt will
  5. // be executed before any other chained interceptors.
  6. if cc.dopts.unaryInt != nil {
  7. interceptors = append([]UnaryClientInterceptor{cc.dopts.unaryInt}, interceptors...)
  8. }
  9. var chainedInt UnaryClientInterceptor
  10. if len(interceptors) == 0 {
  11. chainedInt = nil
  12. } else if len(interceptors) == 1 {
  13. chainedInt = interceptors[0]
  14. } else {
  15. chainedInt = func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
  16. return interceptors[0](ctx, method, req, reply, cc, getChainUnaryInvoker(interceptors, 0, invoker), opts...)
  17. }
  18. }
  19. cc.dopts.unaryInt = chainedInt
  20. }
  21. func getChainUnaryInvoker(interceptors []UnaryClientInterceptor, curr int, finalInvoker UnaryInvoker) UnaryInvoker {
  22. if curr == len(interceptors)-1 {
  23. return finalInvoker
  24. }
  25. return func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error {
  26. return interceptors[curr+1](ctx, method, req, reply, cc, getChainUnaryInvoker(interceptors, curr+1, finalInvoker), opts...)
  27. }
  28. }

使用方法:

客户端:

    // Set up a connection to the server.
    conn, err := grpc.Dial(*addr, grpc.WithTransportCredentials(creds),
        grpc.WithUnaryInterceptor(unaryInterceptor),
        grpc.WithStreamInterceptor(streamInterceptor),
        grpc.WithBlock())
    if err != nil {
        log.Fatalf("did not connect: %v", err)
    }

服务端

// 服务端使用拦截器
s := grpc.NewServer(
    grpc.Creds(creds),
    grpc.UnaryInterceptor(unaryInterceptor),
    grpc.StreamInterceptor(streamInterceptor))