/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package main import ( "context" "crypto/sha256" "errors" "fmt" "net" "os" "os/signal" "syscall" "github.com/golang/glog" thrift_any "thrift/conformance/any" "thrift/conformance/conformance" "thrift/conformance/patch_data" "thrift/conformance/protocol" "thrift/conformance/serialization" "thrift/lib/go/thrift" "thrift/lib/thrift/protocol_detail" "thrift/test/testset" enum "thrift/test/testset/Enum" ) // Registry initializer func type type registryInitializerFuncType = func() any // typeRegistry is a registry from thrift_uri or its hash Thrift initializer function. // typeRegistry is used to serialize and deserialize thrift.Any. type typeRegistry struct { names map[string]registryInitializerFuncType hash32 map[[32]byte]registryInitializerFuncType hash16 map[[16]byte]registryInitializerFuncType hash8 map[[8]byte]registryInitializerFuncType } func newTypeRegistry() *typeRegistry { return &typeRegistry{ names: make(map[string]registryInitializerFuncType), hash32: make(map[[32]byte]registryInitializerFuncType), hash16: make(map[[16]byte]registryInitializerFuncType), hash8: make(map[[8]byte]registryInitializerFuncType), } } const thriftURIPrefix = "fbthrift://" // RegisterType is called by the generated RegisterTypes function in thrift packages. // Only types with a thrift_uri is registered. func (r *typeRegistry) RegisterType(name string, initializer registryInitializerFuncType) { r.names[name] = initializer h := sha256.New() h.Write([]byte(thriftURIPrefix + name)) hash := h.Sum(nil) hash32 := *(*[32]byte)(hash[0:32]) hash16 := *(*[16]byte)(hash[0:16]) hash8 := *(*[8]byte)(hash[0:8]) r.hash32[hash32] = initializer r.hash16[hash16] = initializer r.hash8[hash8] = initializer } // LoadInitializerWithName loads initializer from the type registry for deserialization given the thrift_uri name. func (r *typeRegistry) LoadInitializerWithName(name string) (registryInitializerFuncType, error) { initializer, ok := r.names[name] if !ok { return nil, fmt.Errorf("load from registry error: %s is not registered", name) } return initializer, nil } // LoadInitializerWithHash loads initializer from the type registry for deserialization given the hashed thrift_uri name. // The hashed thrift_uri can either be of length 8, 16 or 32. func (r *typeRegistry) LoadInitializerWithHash(hash []byte) (registryInitializerFuncType, error) { var initializer registryInitializerFuncType var ok bool if len(hash) == 8 { hash8 := *(*[8]byte)(hash[0:8]) initializer, ok = r.hash8[hash8] if !ok { return nil, fmt.Errorf("load from hash8 registry error: %s is not registered", string(hash)) } } if len(hash) == 16 { hash16 := *(*[16]byte)(hash[0:16]) initializer, ok = r.hash16[hash16] if !ok { return nil, fmt.Errorf("load from hash16 registry error: %s is not registered", string(hash)) } } if len(hash) == 32 { hash32 := *(*[32]byte)(hash[0:32]) initializer, ok = r.hash32[hash32] if !ok { return nil, fmt.Errorf("load from hash32 registry error: %s is not registered", string(hash)) } } return initializer, nil } func main() { // Catch SIGTERM/SIGKILL sigc := make(chan os.Signal, 1) signal.Notify(sigc, syscall.SIGTERM, syscall.SIGINT, ) // Register all types from the testset registry := newTypeRegistry() testset.RegisterTypes(registry) enum.RegisterTypes(registry) protocol_detail.RegisterTypes(registry) // Startup thrift server handler := &dataConformanceServiceHandler{registry} proc := conformance.NewConformanceServiceProcessor(handler) ts, addr, err := newServer( proc, // Ports must be dynamically allocated to prevent any conflicts. // Allocating a free port is usually done by setting the port number as zero. // Operating system should assign a free port to the application. "[::]:0", ) if err != nil { glog.Fatalf("failed to start server: %v", err) } fmt.Println(addr.(*net.TCPAddr).Port) ctx, cancel := context.WithCancel(context.Background()) go func() { err := ts.ServeContext(ctx) if err != nil { glog.Fatalf("failed to start server") } }() <-sigc cancel() os.Exit(0) } func newServer(processor thrift.Processor, addr string) (thrift.Server, net.Addr, error) { socket, err := thrift.NewServerSocket(addr) if err != nil { return nil, nil, err } return thrift.NewSimpleServer(processor, socket, thrift.TransportIDHeader), socket.Addr(), nil } type dataConformanceServiceHandler struct { registry *typeRegistry } func (h *dataConformanceServiceHandler) RoundTrip(ctx context.Context, roundTripRequest *serialization.RoundTripRequest) (*serialization.RoundTripResponse, error) { requestValue := roundTripRequest.GetValue() if requestValue == nil { return nil, errors.New("unsupported RoundTrip roundTripRequest.Value = nil") } obj, err := deserialize(h.registry, requestValue) if err != nil { return nil, err } target := getTargetProtocol(roundTripRequest) data, err := serialize(obj, target) if err != nil { return nil, err } return newRoundTripResponse(newResponse(requestValue, data)), nil } // newRoundTripResponse wraps the response thrift.Any inside a RoundTripResponse. func newRoundTripResponse(response *thrift_any.Any) *serialization.RoundTripResponse { resp := serialization.NewRoundTripResponse(). SetValue(response) return resp } // newResponse creates a new response Any from the request Any using new serialized data. func newResponse(request *thrift_any.Any, data []byte) *thrift_any.Any { respAny := thrift_any.NewAny(). SetData(data). SetCustomProtocol(request.CustomProtocol). SetProtocol(request.Protocol). SetType(request.Type) return respAny } // serialize serializes a thrift.Struct with a target protocol to be stored inside a thrift.Any. func serialize(obj thrift.Struct, protoc *protocol.ProtocolStruct) ([]byte, error) { s, err := newSerializer(protoc) if err != nil { return nil, err } return s.Write(obj) } // loadStruct loads a thrift.Struct from the typeRegistry for a given thrift.Any. // Any specifies the thrift.Struct to load either with a thrift_uri stored in the Type field // Or with a hashed version of thrift_uri stored in TypeHashPrefixSha2_256. func loadStruct(registry *typeRegistry, value *thrift_any.Any) (thrift.Struct, error) { var initializer registryInitializerFuncType var err error if value.IsSetType() { typ := value.GetType() initializer, err = registry.LoadInitializerWithName(typ) if err != nil { return nil, err } } else if value.IsSetTypeHashPrefixSha2_256() { hash := value.GetTypeHashPrefixSha2_256() initializer, err = registry.LoadInitializerWithHash(hash) if err != nil { return nil, err } } anyObj := initializer() structObj, ok := anyObj.(thrift.Struct) if !ok { return nil, fmt.Errorf("deserialize currently only supports thrift.Struct and not %T", anyObj) } return structObj, nil } // deserialize deserializes the data stored inside a thrift.Any value. func deserialize(registry *typeRegistry, value *thrift_any.Any) (thrift.Struct, error) { obj, err := loadStruct(registry, value) if err != nil { return nil, err } protoc := getProtocol(value) d, err := newDeserializer(protoc) if err != nil { return nil, err } if err := d.Read(obj, value.GetData()); err != nil { return nil, err } return obj, nil } // newSerializer initializes the appropriate serializer for the specific protocol. func newSerializer(protoc *protocol.ProtocolStruct) (*thrift.Serializer, error) { switch protoc.GetStandard() { case protocol.StandardProtocol_Custom: case protocol.StandardProtocol_Binary: return thrift.NewSerializer(), nil case protocol.StandardProtocol_Compact: return thrift.NewCompactSerializer(), nil case protocol.StandardProtocol_Json: return thrift.NewSimpleJSONSerializer(), nil case protocol.StandardProtocol_SimpleJson: return thrift.NewJSONSerializer(), nil } // default value in case the protocol is unknown, as seen in the java implementation of conformance tests. return thrift.NewCompactSerializer(), nil } // newDeserializer initializes the appropriate deserializer for the specific protocol. func newDeserializer(protoc *protocol.ProtocolStruct) (*thrift.Deserializer, error) { switch protoc.GetStandard() { case protocol.StandardProtocol_Custom: case protocol.StandardProtocol_Binary: return thrift.NewDeserializer(), nil case protocol.StandardProtocol_Compact: return thrift.NewCompactDeserializer(), nil case protocol.StandardProtocol_Json: return thrift.NewSimpleJSONDeserializer(), nil case protocol.StandardProtocol_SimpleJson: return thrift.NewJSONDeserializer(), nil } // default value in case the protocol is unknown, as seen in the java implementation of conformance tests. return thrift.NewCompactDeserializer(), nil } // getTargetProtocol returns a consistent target protocol in the ProtocolStruct, whether the target protocol was set or not. // In the case that the target protocol is taken from the value set inside the Any value. // Worst case the Compact protocol is returned as a default. // "Any request and encodes it back to the RoundTripResponse using target protocol - if not empty - or using the protocol in Any." func getTargetProtocol(request *serialization.RoundTripRequest) *protocol.ProtocolStruct { if request.IsSetTargetProtocol() { return request.GetTargetProtocol() } // default value in case the protocol is not specified, as seen in the java implementation of conformance tests. if !request.IsSetValue() { return &protocol.ProtocolStruct{ Standard: protocol.StandardProtocol_Compact, Custom: nil, } } value := request.GetValue() return getProtocol(value) } // getProtocol creates a ProtocolStruct from the values stored inside thrift.Any. func getProtocol(value *thrift_any.Any) *protocol.ProtocolStruct { return &protocol.ProtocolStruct{ Standard: value.GetProtocol(), Custom: value.CustomProtocol, } } func (h *dataConformanceServiceHandler) Patch(ctx context.Context, request *patch_data.PatchOpRequest) (_r *patch_data.PatchOpResponse, err error) { return nil, errors.New("patch is not supported") }