/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Tests for composing several AsyncProcessorFactory instances behind one
// multiplexing factory: handler/metadata aggregation, method-name conflict
// precedence, wildcard fallbacks, per-factory RequestContext, and interactions.
//
// NOTE(review): this copy appears to have lost its `#include` targets and most
// `<...>` template arguments (e.g. `std::shared_ptr multiplex(`,
// `std::vector> services`, `std::make_shared()`). Restore them from upstream
// before building; the type names in the comments below are hedged for the
// same reason.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

namespace apache::thrift::test {

using namespace ::testing;

using MethodMetadata = AsyncProcessorFactory::MethodMetadata;
using MethodMetadataMap = AsyncProcessorFactory::MethodMetadataMap;
using WildcardMethodMetadataMap =
    AsyncProcessorFactory::WildcardMethodMetadataMap;
using CreateMethodMetadataResult =
    AsyncProcessorFactory::CreateMethodMetadataResult;

namespace {

// Handlers returning distinct fixed integers, so a test can tell from the
// returned value which handler actually served a call. ConflictsHandler
// overlaps SecondHandler on four() and ThirdHandler on five(), but returns
// 444/555 instead of 4/5.
class FirstHandler : public apache::thrift::ServiceHandler {
  int one() override { return 1; }
  int two() override { return 2; }
};

class SecondHandler : public apache::thrift::ServiceHandler {
  int three() override { return 3; }
  int four() override { return 4; }
};

class ThirdHandler : public apache::thrift::ServiceHandler {
  int five() override { return 5; }
  int six() override { return 6; }
};

class ConflictsHandler : public apache::thrift::ServiceHandler {
  int four() override { return 444; }
  int five() override { return 555; }
};

class MultiplexAsyncProcessorTest : public Test {
 public:
  // Wraps `services` in one shared multiplexing factory. The element/result
  // types are elided in this copy (presumably AsyncProcessorFactory /
  // MultiplexAsyncProcessorFactory -- TODO confirm).
  std::shared_ptr multiplex(
      std::vector> services) {
    return std::make_shared(
        std::move(services));
  }
};

class MultiplexAsyncProcessorServerTest : public MultiplexAsyncProcessorTest {
 public:
  // Builds a runner around the multiplexed factory; tests obtain clients from
  // it via `runner->newClient()`. The runner type is elided in this copy
  // (presumably an in-process scoped server thread -- TODO confirm).
  std::unique_ptr runMultiplexedServices(
      std::vector> services) {
    return std::make_unique(
        multiplex(std::move(services)));
  }
};

} // namespace

TEST_F(MultiplexAsyncProcessorTest, getServiceHandlers) {
  std::vector> services = {
      std::make_shared(),
      std::make_shared(),
      std::make_shared(),
      std::make_shared(),
  };
  auto processorFactory =
      std::make_shared(std::move(services));
  // Generated service handlers are one per service
  EXPECT_EQ(processorFactory->getServiceHandlers().size(), 4);
}

// Same count as the flat case: the handlers of a nested multiplex factory are
// reported alongside the top-level ones (3 top-level entries, one of which
// wraps two, still yield 4 handlers).
TEST_F(MultiplexAsyncProcessorTest, getServiceHandlers_Nested) {
  std::vector> services2 = {
      std::make_shared(),
      multiplex({
          std::make_shared(),
          std::make_shared(),
      }),
      std::make_shared(),
  };
  auto processorFactory =
      std::make_shared(std::move(services2));
  // Generated service handlers are one per service
  EXPECT_EQ(processorFactory->getServiceHandlers().size(), 4);
}

// Checks the aggregated service metadata: the context names the first
// service, `services` lists every composed service in order (6 entries from 5
// factories, because Third appears both as SomeService's base and standalone),
// while `metadata.services` holds 5 entries (presumably keyed by service name,
// so the duplicate collapses -- TODO confirm).
TEST_F(MultiplexAsyncProcessorTest, getServiceMetadata) {
  auto getMetadataFromService = [](AsyncProcessorFactory& service) {
    metadata::ThriftServiceMetadataResponse response;
    service.getProcessor()->getServiceMetadata(response);
    return response;
  };
  std::vector> servicesToMultiplex = {
      std::make_shared(),
      std::make_shared(),
      std::make_shared>(),
      std::make_shared(),
      std::make_shared(),
  };
  auto processorFactory = std::make_shared(
      std::move(servicesToMultiplex));
  auto response = getMetadataFromService(*processorFactory);
  LOG(INFO) << "ServiceMetadata: " << debugString(response);
  EXPECT_EQ(
      *response.context_ref()->service_info_ref()->name_ref(),
      "MultiplexAsyncProcessor.First");
  auto& services = *response.services_ref();
  EXPECT_EQ(services.size(), 6);
  EXPECT_EQ(*services[0].service_name_ref(), "MultiplexAsyncProcessor.First");
  EXPECT_EQ(*services[1].service_name_ref(), "MultiplexAsyncProcessor.Second");
  EXPECT_EQ(
      *services[2].service_name_ref(), "MultiplexAsyncProcessor.SomeService");
  // Base service of SomeService
  EXPECT_EQ(*services[3].service_name_ref(), "MultiplexAsyncProcessor.Third");
  EXPECT_EQ(
      *services[4].service_name_ref(), "MultiplexAsyncProcessor.Conflicts");
  EXPECT_EQ(*services[5].service_name_ref(),
            "MultiplexAsyncProcessor.Third");
  const auto& metadata = *response.metadata_ref();
  EXPECT_EQ(metadata.structs_ref()->size(), 1);
  EXPECT_EQ(
      metadata.structs_ref()->begin()->first,
      "MultiplexAsyncProcessor.SomeStruct");
  // All composed services are referred to
  EXPECT_EQ(metadata.services_ref()->size(), 5);
}

// Same expectations as getServiceMetadata: nesting a multiplex factory inside
// another does not change the order or content of the aggregated metadata.
TEST_F(MultiplexAsyncProcessorTest, getServiceMetadata_Nested) {
  auto getMetadataFromService = [](AsyncProcessorFactory& service) {
    metadata::ThriftServiceMetadataResponse response;
    service.getProcessor()->getServiceMetadata(response);
    return response;
  };
  std::vector> servicesToMultiplex = {
      std::make_shared(),
      std::make_shared(),
      multiplex({
          std::make_shared>(),
          std::make_shared(),
      }),
      std::make_shared(),
  };
  auto processorFactory = std::make_shared(
      std::move(servicesToMultiplex));
  auto response = getMetadataFromService(*processorFactory);
  LOG(INFO) << "ServiceMetadata: " << debugString(response);
  EXPECT_EQ(
      *response.context_ref()->service_info_ref()->name_ref(),
      "MultiplexAsyncProcessor.First");
  auto& services = *response.services_ref();
  EXPECT_EQ(services.size(), 6);
  EXPECT_EQ(*services[0].service_name_ref(), "MultiplexAsyncProcessor.First");
  EXPECT_EQ(*services[1].service_name_ref(), "MultiplexAsyncProcessor.Second");
  EXPECT_EQ(
      *services[2].service_name_ref(), "MultiplexAsyncProcessor.SomeService");
  // Base service of SomeService
  EXPECT_EQ(*services[3].service_name_ref(), "MultiplexAsyncProcessor.Third");
  EXPECT_EQ(
      *services[4].service_name_ref(), "MultiplexAsyncProcessor.Conflicts");
  EXPECT_EQ(*services[5].service_name_ref(), "MultiplexAsyncProcessor.Third");
  const auto& metadata = *response.metadata_ref();
  EXPECT_EQ(metadata.structs_ref()->size(), 1);
  EXPECT_EQ(
      metadata.structs_ref()->begin()->first,
      "MultiplexAsyncProcessor.SomeStruct");
  // All composed services are referred to
  EXPECT_EQ(metadata.services_ref()->size(), 5);
}

// Every method of every composed service is reachable through the single
// multiplexed server.
TEST_F(MultiplexAsyncProcessorServerTest, Basic) {
  auto runner = runMultiplexedServices(
      {std::make_shared(),
       std::make_shared()});
  auto client1 = runner->newClient();
  auto client2 = runner->newClient();
  EXPECT_EQ(client1->semifuture_one().get(), 1);
  EXPECT_EQ(client1->semifuture_two().get(), 2);
  EXPECT_EQ(client2->semifuture_three().get(), 3);
  EXPECT_EQ(client2->semifuture_four().get(), 4);
}

// When two composed services define the same method name, the one listed
// earlier wins (distinguished here by 4 vs 444 and 5 vs 555).
TEST_F(MultiplexAsyncProcessorServerTest, ConflictPrecedence) {
  auto runner = runMultiplexedServices(
      {std::make_shared(),
       std::make_shared(),
       std::make_shared()});
  auto client2 = runner->newClient();
  auto client3 = runner->newClient();
  EXPECT_EQ(client2->semifuture_three().get(), 3);
  // Second takes precedence
  EXPECT_EQ(client2->semifuture_four().get(), 4);
  // Conflicts takes precedence
  EXPECT_EQ(client3->semifuture_five().get(), 555);
  EXPECT_EQ(client3->semifuture_six().get(), 6);
}

// Precedence also follows list order when the conflicting services sit inside
// a nested multiplex factory.
TEST_F(MultiplexAsyncProcessorServerTest, Nested_1) {
  auto runner = runMultiplexedServices(
      {multiplex(
           {std::make_shared(), std::make_shared()}),
       std::make_shared()});
  auto client2 = runner->newClient();
  auto client3 = runner->newClient();
  EXPECT_EQ(client2->semifuture_three().get(), 3);
  // Second takes precedence
  EXPECT_EQ(client2->semifuture_four().get(), 4);
  // Conflicts takes precedence
  EXPECT_EQ(client3->semifuture_five().get(), 555);
  EXPECT_EQ(client3->semifuture_six().get(), 6);
}

TEST_F(MultiplexAsyncProcessorServerTest, Nested_2) {
  auto runner = runMultiplexedServices(
      {std::make_shared(),
       multiplex(
           {std::make_shared(), std::make_shared()}),
       std::make_shared()});
  auto client1 = runner->newClient();
  auto client2 = runner->newClient();
  auto client3 = runner->newClient();
  EXPECT_EQ(client1->semifuture_one().get(), 1);
  EXPECT_EQ(client2->semifuture_three().get(), 3);
  // Conflict takes precedence
  EXPECT_EQ(client2->semifuture_four().get(), 444);
  // Third takes precedence
  EXPECT_EQ(client3->semifuture_five().get(), 5);
  EXPECT_EQ(client3->semifuture_six().get(), 6);
}

namespace {

// folly::RequestData carrying a single int, stored under a process-wide
// static token. Tests use it to observe which RequestContext is active while a
// handler runs.
class ContextData : public folly::RequestData {
 public:
  static const folly::RequestToken& getRequestToken() {
    static folly::RequestToken token(
        "MultiplexAsyncProcessorTest - ContextData");
    return token;
  }

  explicit ContextData(int data) : data_(data) {}

  int data() const { return data_; }

  bool hasCallback() override { return false; }

  // Reads the int from folly::RequestContext::get(); CHECK-fails if no
  // ContextData is stored under getRequestToken().
  static int readFromCurrent() {
    return readFrom(*folly::RequestContext::get());
  }

 private:
  static int readFrom(const folly::RequestContext& ctx) {
    auto ctxData =
        dynamic_cast(ctx.getContextData(getRequestToken()));
    CHECK(ctxData != nullptr);
    return ctxData->data();
  }

  int data_;
};

// Tag type: ask WildcardThrowsInternalError to take its error message from the
// current request's ContextData instead of a fixed string.
struct FromCurrentContextData {};

/**
 * AsyncProcessorFactory where WildcardMethodMetadata always causes an internal
 * error with the provided message (or optionally read it from
 * folly::RequestContext).
 */
template 
class WildcardThrowsInternalError : public TProcessorFactory {
 public:
  explicit WildcardThrowsInternalError(FromCurrentContextData)
      : message_{FromCurrentContextData{}} {}
  explicit WildcardThrowsInternalError(std::string message)
      : message_{std::move(message)} {}

 private:
  using MessageVariant = std::variant;

  // Turns the wrapped factory's metadata into a wildcard map: a plain
  // MethodMetadataMap keeps its known methods and gains a wildcard entry; an
  // already-wildcard map is passed through unchanged.
  CreateMethodMetadataResult createMethodMetadata() override {
    auto metadataResult = TProcessorFactory::createMethodMetadata();
    return folly::variant_match(
        metadataResult,
        [](MethodMetadataMap& knownMethods) -> WildcardMethodMetadataMap {
          return WildcardMethodMetadataMap{
              std::make_shared<
                  const AsyncProcessorFactory::WildcardMethodMetadata>(),
              std::move(knownMethods)};
        },
        [](WildcardMethodMetadataMap& map) -> WildcardMethodMetadataMap {
          return std::move(map);
        });
  }

  // Returns a processor that answers wildcard-routed requests with an
  // INTERNAL_ERROR TApplicationException and forwards everything else to the
  // wrapped factory's processor.
  std::unique_ptr getProcessor() override {
    class Processor : public AsyncProcessor {
     public:
      void processSerializedCompressedRequestWithMetadata(
          ResponseChannelRequest::UniquePtr req,
          SerializedCompressedRequest&& serializedRequest,
          const MethodMetadata& untypedMethodMetadata,
          protocol::PROTOCOL_TYPES protocolType,
          Cpp2RequestContext* context,
          folly::EventBase* eb,
          concurrency::ThreadManager* tm) override {
        if (untypedMethodMetadata.isWildcard()) {
          // Resolve the error text: either the fixed string, or the int from
          // the current request's ContextData converted to a string.
          std::string message = folly::variant_match(
              message_,
              [](const std::string& m) { return m; },
              [](FromCurrentContextData) {
                return folly::to(ContextData::readFromCurrent());
              });
          req->sendErrorWrapped(
              folly::make_exception_wrapper(
                  TApplicationException::INTERNAL_ERROR, std::move(message)),
              "" /* errorCode */);
          return;
        }
        inner_->processSerializedCompressedRequestWithMetadata(
            std::move(req),
            std::move(serializedRequest),
            untypedMethodMetadata,
            protocolType,
            context,
            eb,
            tm);
      }

      void executeRequest(
          ServerRequest&& request,
          const AsyncProcessorFactory::MethodMetadata& methodMetadata)
          override {
        if (methodMetadata.isWildcard()) {
          // The message is resolved here, before hopping to the request's
          // EventBase, so FromCurrentContextData reads the RequestContext
          // that is current at this call.
          std::string message = folly::variant_match(
              message_,
              [](const std::string& m) { return m; },
              [](FromCurrentContextData) {
                return folly::to(ContextData::readFromCurrent());
              });
          // The error reply is sent from the request's EventBase thread.
          auto eb = detail::ServerRequestHelper::eventBase(request);
          eb->runInEventBaseThread([request = std::move(request),
                                    message = std::move(message)]() mutable {
            request.request()->sendErrorWrapped(
                folly::make_exception_wrapper(
                    TApplicationException::INTERNAL_ERROR, std::move(message)),
                "" /* errorCode */);
          });
          return;
        }
        inner_->executeRequest(std::move(request), methodMetadata);
      }

      void terminateInteraction(
          int64_t id, Cpp2ConnContext& ctx, folly::EventBase& eb) noexcept
          override {
        inner_->terminateInteraction(id, ctx, eb);
      }

      void destroyAllInteractions(
          Cpp2ConnContext& ctx, folly::EventBase& eb) noexcept override {
        inner_->destroyAllInteractions(ctx, eb);
      }

      explicit Processor(
          std::unique_ptr&& inner, const MessageVariant& message)
          : inner_(std::move(inner)), message_(message) {}

     private:
      std::unique_ptr inner_;
      // NOTE(review): a reference to the owning factory's message_, so the
      // factory must outlive every processor it hands out.
      const MessageVariant& message_;
    };
    return std::make_unique(
        TProcessorFactory::getProcessor(), message_);
  }

  MessageVariant message_;
};

} // namespace

// Methods that no composed service knows fall through to the wildcard factory
// and fail with its message; known methods are unaffected.
TEST_F(MultiplexAsyncProcessorServerTest, BasicWildcard) {
  auto runner = runMultiplexedServices(
      {std::make_shared(),
       std::make_shared>(
           "BasicWildcard")});
  auto client1 = runner->newClient();
  auto client2 = runner->newClient();
  auto client3 = runner->newClient();
  EXPECT_EQ(client1->semifuture_one().get(), 1);
  EXPECT_EQ(client1->semifuture_two().get(), 2);
  EXPECT_EQ(client2->semifuture_three().get(), 3);
  EXPECT_EQ(client2->semifuture_four().get(), 4);
  EXPECT_THAT(
      [&] { client3->semifuture_five().get(); },
      ThrowsMessage("BasicWildcard"));
}

// A wildcard factory listed first absorbs methods it does not itself know,
// even when a later factory implements them (three() fails with
// "WildcardSwallows"). The later "NeverReached" wildcard is never consulted.
TEST_F(MultiplexAsyncProcessorServerTest, WildcardSwallows) {
  auto runner = runMultiplexedServices(
      {std::make_shared>(
           "WildcardSwallows"),
       std::make_shared(),
       std::make_shared>(
           "NeverReached")});
  auto client1 = runner->newClient();
  auto client2 = runner->newClient();
  EXPECT_EQ(client1->semifuture_one().get(), 1);
  EXPECT_EQ(client1->semifuture_two().get(), 2);
  EXPECT_THAT(
      [&] { client2->semifuture_three().get(); },
      ThrowsMessage("WildcardSwallows"));
}

TEST_F(MultiplexAsyncProcessorServerTest, WildcardConflicts) {
  auto runner = runMultiplexedServices(
      {std::make_shared(),
       std::make_shared>(
           "WildcardConflicts")});
  auto client2 = runner->newClient();
  auto client3 = runner->newClient();
  // Known methods takes precedence
  EXPECT_EQ(client2->semifuture_three().get(), 3);
  EXPECT_EQ(client2->semifuture_four().get(), 4);
  EXPECT_EQ(client3->semifuture_five().get(), 555);
  EXPECT_THAT(
      [&] { client3->semifuture_six().get(); },
      ThrowsMessage("WildcardConflicts"));
}

namespace {

// Handlers that return the ContextData value of the current RequestContext,
// so a test can see whose base context was installed for each call.
class RctxFirst : public apache::thrift::ServiceHandler {
  int one() override { return ContextData::readFromCurrent(); }
  int two() override { return ContextData::readFromCurrent(); }
};

class RctxSecond : public apache::thrift::ServiceHandler {
  int three() override { return ContextData::readFromCurrent(); }
  int four() override { return ContextData::readFromCurrent(); }
};

class RctxThird : public apache::thrift::ServiceHandler {
  int five() override { return ContextData::readFromCurrent(); }
  int six() override { return ContextData::readFromCurrent(); }
};

class RctxConflicts : public apache::thrift::ServiceHandler {
  int four() override { return ContextData::readFromCurrent(); }
  int five() override { return ContextData::readFromCurrent(); }
};

// Factory wrapper whose per-request base RequestContext carries
// ContextData(kData). `kData` is presumably an int template parameter (the
// template header is elided in this copy -- TODO confirm).
template 
class WithRequestContextData : public TProcessorFactory {
 public:
  using TProcessorFactory::TProcessorFactory;

  std::shared_ptr getBaseContextForRequest(
      const MethodMetadata&) override {
    auto ctx = std::make_shared();
    ctx->setContextData(
        ContextData::getRequestToken(), std::make_unique(kData));
    return ctx;
  }
};

} // namespace

// Each method runs under the base RequestContext of the factory that owns it.
TEST_F(MultiplexAsyncProcessorServerTest, RequestContext) {
  auto runner = runMultiplexedServices(
      {std::make_shared>(),
       std::make_shared>()});
  auto client1 = runner->newClient();
  auto client2 = runner->newClient();
  EXPECT_EQ(client1->semifuture_one().get(), 1);
  EXPECT_EQ(client1->semifuture_two().get(), 1);
  EXPECT_EQ(client2->semifuture_three().get(), 2);
  EXPECT_EQ(client2->semifuture_four().get(), 2);
}

// The wildcard's error message is read from the current ContextData, so "1"
// shows that the wildcard factory's own base context is active for
// wildcard-routed requests.
TEST_F(MultiplexAsyncProcessorServerTest, RequestContextWildcard) {
  auto runner = runMultiplexedServices(
      {std::make_shared<
           WithRequestContextData, 1>>(
           FromCurrentContextData{}),
       std::make_shared>()});
  auto client1 = runner->newClient();
  auto client2 = runner->newClient();
  EXPECT_EQ(client1->semifuture_one().get(), 1);
  EXPECT_EQ(client1->semifuture_two().get(), 1);
  EXPECT_THAT(
      [&] { client2->semifuture_three().get(); }, ThrowsMessage("1"));
}

// Conflict precedence also decides which factory's RequestContext is used.
TEST_F(MultiplexAsyncProcessorServerTest, RequestContextConflictPrecedence) {
  auto runner = runMultiplexedServices(
      {std::make_shared>(),
       std::make_shared>(),
       std::make_shared>()});
  auto client2 = runner->newClient();
  auto client3 = runner->newClient();
  EXPECT_EQ(client2->semifuture_three().get(), 2);
  // Second takes precedence
  EXPECT_EQ(client2->semifuture_four().get(), 2);
  // Conflicts takes precedence
  EXPECT_EQ(client3->semifuture_five().get(), -1);
  EXPECT_EQ(client3->semifuture_six().get(), 3);
}

namespace {

// Channel factory passed to newClient() so the interaction tests use a Rocket
// client channel.
RequestChannel::Ptr makeRocketChannel(folly::AsyncSocket::UniquePtr socket) {
  return RocketClientChannel::newChannel(std::move(socket));
}

} // namespace

// Interactions of different composed services work side by side, and
// terminating one reaches only its own service's processor.
TEST_F(MultiplexAsyncProcessorServerTest, Interaction) {
  using Counter = std::atomic;

  // Forwards every call to `delegate_` and counts terminateInteraction calls.
  class TerminateInteractionTrackingProcessor : public AsyncProcessor {
   public:
    void processSerializedCompressedRequestWithMetadata(
        ResponseChannelRequest::UniquePtr req,
        SerializedCompressedRequest&& serializedRequest,
        const MethodMetadata& methodMetadata,
        protocol::PROTOCOL_TYPES protocolType,
        Cpp2RequestContext* context,
        folly::EventBase* eb,
        concurrency::ThreadManager* tm) override {
      delegate_->processSerializedCompressedRequestWithMetadata(
          std::move(req),
          std::move(serializedRequest),
          methodMetadata,
          protocolType,
          context,
          eb,
          tm);
    }

    void executeRequest(
        ServerRequest&& request,
        const AsyncProcessorFactory::MethodMetadata& methodMetadata) override {
      delegate_->executeRequest(std::move(request), methodMetadata);
    }

    // NOTE(review): `virtual` is redundant alongside `override`.
    virtual void terminateInteraction(
        int64_t id, Cpp2ConnContext& ctx, folly::EventBase& eb) noexcept
        override {
      ++numCalls_;
      delegate_->terminateInteraction(id, ctx, eb);
    }

    void processInteraction(ServerRequest&& request) override {
      delegate_->processInteraction(std::move(request));
    }

    explicit TerminateInteractionTrackingProcessor(
        std::unique_ptr&& delegate, Counter& numCalls)
        : delegate_(std::move(delegate)), numCalls_(numCalls) {}

   private:
    std::unique_ptr delegate_;
    Counter& numCalls_;
  };

  // Counts foo() calls, posts `destroyed` when its Thing1 is destroyed, and
  // wraps its processor to count terminateInteraction calls.
  class Interaction1Handler
      : public apache::thrift::ServiceHandler {
   public:
    std::unique_ptr createThing1() override {
      class Thing1 : public Thing1If {
       public:
        void foo() override { ++numCalls_; }
        explicit Thing1(Counter& numCalls, folly::Baton<>& destroyed)
            : numCalls_(numCalls), destroyed_(destroyed) {}
        ~Thing1() override { destroyed_.post(); }

       private:
        Counter& numCalls_;
        folly::Baton<>& destroyed_;
      };
      return std::make_unique(numCalls, destroyed);
    }

    std::unique_ptr getProcessor() override {
      return std::make_unique(
          apache::thrift::ServiceHandler::getProcessor(),
          numTerminateInteractionCalls);
    }

    Counter numCalls{0};
    Counter numTerminateInteractionCalls{0};
    folly::Baton<> destroyed;
  };

  // Counts bar() calls on its Thing2 interaction.
  class Interaction2Handler
      : public apache::thrift::ServiceHandler {
    std::unique_ptr createThing2() override {
      class Thing2 : public Thing2If {
       public:
        void bar() override { ++numCalls_; }
        explicit Thing2(Counter& numCalls) : numCalls_(numCalls) {}

       private:
        Counter& numCalls_;
      };
      return std::make_unique(numCalls);
    }

   public:
    Counter numCalls{0};
  };

  auto interaction1 = std::make_shared();
  auto interaction2 = std::make_shared();
  auto runner = runMultiplexedServices({interaction1, interaction2});

  auto client1 = runner->newClient(
      nullptr /* callbackExecutor */, makeRocketChannel);
  auto client2 = runner->newClient(
      nullptr /* callbackExecutor */, makeRocketChannel);

  std::optional thing1 = client1->createThing1();
  thing1->semifuture_foo().get();
  std::optional thing2 = client2->createThing2();
  thing2->semifuture_bar().get();
  EXPECT_EQ(interaction1->numCalls.load(), 1);
  EXPECT_EQ(interaction2->numCalls.load(), 1);

  // Make sure interaction gets destroyed
  thing1.reset();
  interaction1->destroyed.wait();
  EXPECT_EQ(interaction1->numTerminateInteractionCalls.load(), 1);

  // Other interactions should not be destroyed
  thing2->semifuture_bar().get();
  EXPECT_EQ(interaction2->numCalls.load(), 2);
}

TEST_F(MultiplexAsyncProcessorServerTest, InteractionConflict) {
  class Interaction1Handler
      : public apache::thrift::ServiceHandler {
   public:
    std::unique_ptr createThing1() override {
      class Thing1 : public Thing1If {
       public:
        void foo() override {}
      };
      return std::make_unique();
    }
  };

  // Declares an interaction with the same name (Thing1); its methods must
  // never be dispatched to.
  class ConflictsInteraction1Handler
      : public apache::thrift::ServiceHandler<
            apache::thrift::test2::ConflictsInteraction1> {
   public:
    std::unique_ptr createThing1() override {
      class Thing1 : public Thing1If {
       public:
        void foo() override { ADD_FAILURE() << "Should never be called"; }
        void bar() override { ADD_FAILURE() << "Should never be called"; }
      };
      return std::make_unique();
    }
  };

  auto runner = runMultiplexedServices(
      {std::make_shared(),
       std::make_shared(),
       std::make_shared>(
           "ConflictsInteraction1")});

  auto client1 = runner->newClient(
      nullptr /* callbackExecutor */, makeRocketChannel);
  auto client2 =
      runner
          ->newClient(
              nullptr /* callbackExecutor */, makeRocketChannel);

  auto thing = client1->createThing1();
  thing.semifuture_foo().get();

  auto thing2 = client2->createThing1();
  // Thing1.bar from ConflictsInteraction1 should not be in MethodMetadataMap
  // because Interaction1 already added Thing1.foo.
  EXPECT_THAT(
      [&] { thing2.semifuture_bar().get(); },
      ThrowsMessage("ConflictsInteraction1"));
}

// A multiplex built only from generated handlers (including a nested
// multiplex) reports isThriftGenerated() == true. One that includes a wrapped,
// non-generated factory reports false (wrapper type elided in this copy).
TEST_F(MultiplexAsyncProcessorTest, ThriftGenerated) {
  auto generated = multiplex(
      {std::make_shared(),
       std::make_shared(),
       multiplex({std::make_shared()})});
  auto notGenerated = multiplex(
      {std::make_shared>(
           std::make_shared()),
       std::make_shared()});
  EXPECT_TRUE(generated->isThriftGenerated());
  EXPECT_FALSE(notGenerated->isThriftGenerated());
}

} // namespace apache::thrift::test