grpc 官方提供了对 oauth2 认证鉴权的实现 demo,放在 examples 目录的 features 目录的 authentication 目录下,我们来看一下源码实现

    server

    server 端源码实现如下:

    server 端先调用了 tls 包下的 LoadX509KeyPair,通过 server 的公钥和私钥生成了一个 Certificate 结构体来保存证书信息。然后注册了一个校验 token 的方法到拦截器中,并将证书信息设置到 serverOption 中,构造 server 的时候层层透传进去,最终会被设置到 Server 里面 ServerOptions 结构中的 credentials.TransportCredentials 和 UnaryServerInterceptor 中。

    我们来看看这两个结构什么时候会被调用,先梳理调用链路,在 s.Serve ——> s.handleRawConn ——> s.serveStreams ——> s.handleStream ——> s.processUnaryRPC 方法中有一行

    可以看到调用了 md.Handler 方法,将 s.opts.unaryInt 这个结构传入了进去。s.opts.unaryInt 就是我们之前注册的 UnaryServerInterceptor 拦截器。md 是一个 MethodDesc 这个结构,包括了 MethodName 和 Handler

    1. type MethodDesc struct {
    2. MethodName string
    3. Handler methodHandler
    4. }

    这里会取出我们之前注册进去的结构,还记得我们介绍 helloworld 时 RegisterService 吗?至于如何取出 MethodName,源码中的设计非常复杂,经过了层层包装,这里不是本节重点就不赘述了。

    1. func RegisterGreeterServer(s *grpc.Server, srv GreeterServer) {
    2. s.RegisterService(&_Greeter_serviceDesc, srv)
    3. }
    4. var _Greeter_serviceDesc = grpc.ServiceDesc{
    5. ServiceName: "helloworld.Greeter",
    6. HandlerType: (*GreeterServer)(nil),
    7. Methods: []grpc.MethodDesc{
    8. {
    9. MethodName: "SayHello",
    10. Handler: _Greeter_SayHello_Handler,
    11. },
    12. },
    13. Streams: []grpc.StreamDesc{},
    14. Metadata: "helloworld.proto",
    15. }

    我们看到 md.Handler 其实是 _Greeter_SayHello_Handler 这个结构,它也是在 pb 文件中生成的。

    1. func _Greeter_SayHello_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
    2. in := new(HelloRequest)
    3. if err := dec(in); err != nil {
    4. return nil, err
    5. }
    6. if interceptor == nil {
    7. return srv.(GreeterServer).SayHello(ctx, in)
    8. }
    9. info := &grpc.UnaryServerInfo{
    10. Server: srv,
    11. FullMethod: "/helloworld.Greeter/SayHello",
    12. }
    13. handler := func(ctx context.Context, req interface{}) (interface{}, error) {
    14. return srv.(GreeterServer).SayHello(ctx, req.(*HelloRequest))
    15. }
    16. return interceptor(ctx, in, info, handler)
    17. }

    这里调用了我们传入的 interceptor 方法。回到我们的调用:

    1. reply, appErr := md.Handler(srv.server, ctx, df, s.opts.unaryInt)
    1. opts := []grpc.ServerOption{
    2. // The following grpc.ServerOption adds an interceptor for all unary
    3. // RPCs. To configure an interceptor for streaming RPCs, see:
    4. grpc.UnaryInterceptor(ensureValidToken),
    5. // Enable TLS for all incoming connections.
    6. }
    7. s := grpc.NewServer(opts...)

    看 grpc.UnaryInterceptor 这个方法,其实是将 ensureValidToken 这个函数赋值给了 s.opts.unaryInt

    所以之前我们执行的这一行

    1. return interceptor(ctx, in, info, handler)

    其实是执行了 ensureValidToken 这个函数,这个函数就是我们在 server 端定义的 token 校验的函数。先取出我们传入的 metadata 数据,然后校验 token

    1. func ensureValidToken(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
    2. md, ok := metadata.FromIncomingContext(ctx)
    3. if !ok {
    4. return nil, errMissingMetadata
    5. }
    6. // The keys within metadata.MD are normalized to lowercase.
    7. // See: https://godoc.org/google.golang.org/grpc/metadata#New
    8. if !valid(md["authorization"]) {
    9. return nil, errInvalidToken
    10. }
    11. // Continue execution of handler after ensuring a valid token.
    12. return handler(ctx, req)
    13. }

    校验完 token 后,最终执行了 handler(ctx, req)

    1. handler := func(ctx context.Context, req interface{}) (interface{}, error) {
    2. return srv.(GreeterServer).SayHello(ctx, req.(*HelloRequest))
    3. }
    4. return interceptor(ctx, in, info, handler)

    可以看到最终其实执行了 GreeterServer 的 SayHello 这个函数,也就是我们在 main 函数中定义的,这个函数就是我们在 server 端定义的提供 SayHello 给客户端回消息的函数。

    1. // SayHello implements helloworld.GreeterServer
    2. func (s *server) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) {
    3. log.Printf("Received: %v", in.Name)
    4. return &pb.HelloReply{Message: "Hello " + in.Name}, nil
    5. }

    这里还可以额外说一下,md.Handler 执行完之后,其实 reply 就是 SayHello 的回包。

    1. reply, appErr := md.Handler(srv.server, ctx, df, s.opts.unaryInt)

    获取到回包之后 server 执行了 sendResponse 方法,将回包发送给 client,这个方法我们之前已经剖析过了,最终会调用 http2Server 的 Write 方法。

    1. if err := s.sendResponse(t, stream, reply, cp, opts, comp); err != nil {

    看到这里,server 端对 token 的校验在哪里执行的我们已经清楚了。假如还没有被绕晕,那么恭喜你!可以继续完成 client 的挑战了。

    client

    可以看到 client 首先通过 NewOauthAccess 方法生成了包含 token 信息的 PerRPCCredentials 结构

    1. func NewOauthAccess(token *oauth2.Token) credentials.PerRPCCredentials {
    2. return oauthAccess{token: *token}
    3. }

    然后再将 PerRPCCredentials 通过 grpc.WithPerRPCCredentials(perRPC) 添加到了到了 client 的 DialOptions 中的 transport.ConnectOptions 结构中的 [] credentials.PerRPCCredentials 结构中。

    那么这个结构什么时候被使用呢,我们来看看。先梳理下调用链 ,在 client 调用的 Invoke ——> invoke ——> newClientStream ——> cs.newAttemptLocked ——> cs.cc.getTransport ——> pick ——> acw.getAddrConn().getReadyTransport() ——> ac.connect() ——> ac.resetTransport() ——> ac.tryAllAddrs ——> ac.createTransport ——> transport.NewClientTransport ——> newHTTP2Client 这个方法里面,有这么一段代码,先取出 []credentials.PerRPCCredentials 中的所有 PerRPCCredentials 添加到 perRPCCreds 中。

    1. transportCreds := opts.TransportCredentials
    2. perRPCCreds := opts.PerRPCCredentials
    3. if b := opts.CredsBundle; b != nil {
    4. if t := b.TransportCredentials(); t != nil {
    5. transportCreds = t
    6. }
    7. if t := b.PerRPCCredentials(); t != nil {
    8. perRPCCreds = append(perRPCCreds, t)
    9. }
    10. }

    然后再将 perRPCCreds 赋值给 http2Client 的 perRPCCreds 属性

    1. t := &http2Client{
    2. ...
    3. perRPCCreds: perRPCCreds,
    4. ...
    5. }

    那么 perRPCCreds 属性什么时候被用呢?来继续跟踪,newClientStream 方法中有一段代码

    1. op := func(a *csAttempt) error { return a.newStream() }
    2. if err := cs.withRetry(op, func() { cs.bufferForRetryLocked(0, op) }); err != nil {
    3. cs.finish(err)
    4. return nil, err
    5. }

    这里调用了 csAttempt 的 newStream ——> a.t.NewStream (http2Client 的 NewStream) ——> createHeaderFields ——> getTrAuthData 方法

    1. func (t *http2Client) getTrAuthData(ctx context.Context, audience string) (map[string]string, error) {
    2. if len(t.perRPCCreds) == 0 {
    3. return nil, nil
    4. }
    5. authData := map[string]string{}
    6. for _, c := range t.perRPCCreds {
    7. data, err := c.GetRequestMetadata(ctx, audience)
    8. if err != nil {
    9. if _, ok := status.FromError(err); ok {
    10. return nil, err
    11. }
    12. return nil, status.Errorf(codes.Unauthenticated, "transport: %v", err)
    13. }
    14. for k, v := range data {
    15. // Capital header names are illegal in HTTP/2.
    16. k = strings.ToLower(k)
    17. authData[k] = v
    18. }
    19. }
    20. return authData, nil
    21. }

    这个方法,通过调用 GetRequestMetadata 取出 token 信息,这里会调用 oauth 的 GetRequestMetadata 方法 ,按照指定格式拼装成一个 map[string]string{} 的形式

    1. func (s *serviceAccount) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
    2. s.mu.Lock()
    3. defer s.mu.Unlock()
    4. if !s.t.Valid() {
    5. var err error
    6. s.t, err = s.config.TokenSource(ctx).Token()
    7. if err != nil {
    8. return nil, err
    9. }
    10. }
    11. return map[string]string{
    12. "authorization": s.t.Type() + " " + s.t.AccessToken,

    然后将以 map[string]string{} 的形式组装成一个 string map 返回,如下:

    至此,client 和 server 的数据流转过程被打通