/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// NOTE(review): in this copy, text that was enclosed in angle brackets
// (#include targets, template arguments such as those of std::unique_ptr,
// std::make_shared and makeClient) appears to have been stripped; restore
// it from upstream before building.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

using namespace apache::thrift;
using namespace ::testing;

// Fixture, test handler and helper interceptors local to this test file.
namespace {

using TransportType = Cpp2ConnContext::TransportType;

// Builds an HTTP/2 routing handler for `server`. It uses the server's IO
// worker thread count and idle timeout and is configured to shut down on
// SIGINT/SIGTERM.
std::unique_ptr createHTTP2RoutingHandler(
    ThriftServer& server) {
  auto h2Options = std::make_unique();
  h2Options->threads = static_cast(server.getNumIOWorkerThreads());
  h2Options->idleTimeout = server.getIdleTimeout();
  h2Options->shutdownOn = {SIGINT, SIGTERM};
  return std::make_unique(
      std::move(h2Options), server.getThriftProcessor(), server);
}

// Parameterized fixture. The test parameter selects the transport (HEADER,
// ROCKET or HTTP2) used by both the server and the clients of each test.
class ServiceInterceptorTestP : public ::testing::TestWithParam {
 public:
  // The transport under test for the current parameterization.
  TransportType transportType() const { return GetParam(); }

  // Starts a ScopedServerInterfaceThread serving `service`. `configureServer`
  // runs against the ThriftServer (tests use it to add modules). For HTTP2 the
  // routing handler from createHTTP2RoutingHandler is also added.
  std::unique_ptr makeServer(
      std::shared_ptr service,
      ScopedServerInterfaceThread::ServerConfigCb configureServer = {}) {
    auto runner = std::make_unique(
        std::move(service), std::move(configureServer));
    if (transportType() == TransportType::HTTP2) {
      auto& thriftServer = runner->getThriftServer();
      thriftServer.addRoutingHandler(createHTTP2RoutingHandler(thriftServer));
    }
    return runner;
  }

  // Returns a factory that wraps a connected socket in a client channel for
  // `transportType`:
  // - HEADER: HeaderClientChannel without Rocket upgrade.
  // - ROCKET: RocketClientChannel.
  // - HTTP2: HTTPClientChannel using the compact protocol.
  // Any other value throws std::logic_error.
  ScopedServerInterfaceThread::MakeChannelFunc channelFor(
      TransportType transportType) {
    return [transportType](
               folly::AsyncSocket::UniquePtr socket) -> RequestChannel::Ptr {
      switch (transportType) {
        case TransportType::HEADER:
          return HeaderClientChannel::newChannel(
              HeaderClientChannel::WithoutRocketUpgrade{}, std::move(socket));
        case TransportType::ROCKET:
          return RocketClientChannel::newChannel(std::move(socket));
        case TransportType::HTTP2: {
          auto channel = HTTPClientChannel::newHTTP2Channel(std::move(socket));
          channel->setProtocolId(protocol::T_COMPACT_PROTOCOL);
          return channel;
        }
        default:
          throw std::logic_error{"Unreachable!"};
      }
    };
  }

  // Creates a client connected to `runner` over the transport under test
  // (no callback executor is passed).
  template
  std::unique_ptr makeClient(ScopedServerInterfaceThread& runner) {
    return runner.newClient(nullptr, channelFor(transportType()));
  }
};

// Service handler used by most tests.
struct TestHandler : apache::thrift::ServiceHandler {
  // Does nothing; exercises the void-return path.
  folly::coro::Task co_noop() override { co_return; }

  // Echoes the input. The magic value "throw" makes it throw a
  // std::runtime_error instead.
  folly::coro::Task> co_echo(
      std::unique_ptr str) override {
    if (*str == "throw") {
      throw std::runtime_error("You asked for it!");
    }
    co_return std::move(str);
  }

  // Ignores its three arguments (tests inspect them via the interceptor) and
  // returns a fixed string.
  folly::coro::Task> co_requestArgs(
      std::int32_t, std::unique_ptr, std::unique_ptr) override {
    co_return std::make_unique("return value");
  }

  // Copies the request's foo/bar fields into a new response struct.
  folly::coro::Task> co_echoStruct(
      std::unique_ptr request) override {
    auto result = std::make_unique();
    result->foo() = std::move(*request->foo());
    result->bar() = std::move(*request->bar());
    co_return result;
  }

  // Callback-style echo (async_eb_ handler): completes the request directly
  // through `callback`.
  void async_eb_echo_eb(
      apache::thrift::HandlerCallbackPtr> callback,
      std::unique_ptr str) override {
    callback->result(*str);
  }

  // Starts a SampleInteraction whose echo returns its input unchanged.
  folly::coro::Task> co_createInteraction() override {
    class SampleInteractionImpl : public SampleInteractionIf {
      folly::coro::Task> co_echo(
          std::unique_ptr str) override {
        co_return std::move(str);
      }
    };
    co_return {std::make_unique()};
  }

  // Factory for SampleInteraction2; its echo returns its input unchanged.
  std::unique_ptr createSampleInteraction2() override {
    class SampleInteraction2Impl : public SampleInteraction2If {
      folly::coro::Task> co_echo(
          std::unique_ptr str) override {
        co_return std::move(str);
      }
    };
    return std::make_unique();
  }

  // Unbounded stream start, start+1, start+2, ...; it ends only when the
  // client stops consuming it.
  apache::thrift::ServerStream sync_iota(
      std::int32_t start) override {
    return folly::coro::co_invoke(
        [current = start]() mutable -> folly::coro::AsyncGenerator {
          while (true) {
            co_yield current++;
          }
        });
  }
};

using InterceptorList = std::vector>;

// A ServerModule named "TestModule" that exposes a fixed list of
// interceptors. The tests expect interceptor names qualified as
// "TestModule.<name>" in error messages.
class
TestModule : public apache::thrift::ServerModule {
 public:
  explicit TestModule(std::shared_ptr interceptor) {
    interceptors_.emplace_back(std::move(interceptor));
  }
  explicit TestModule(InterceptorList interceptors)
      : interceptors_(std::move(interceptors)) {}

  std::string getName() const override { return "TestModule"; }

  InterceptorList getServiceInterceptors() override { return interceptors_; }

 private:
  InterceptorList interceptors_;
};

// Base for test interceptors: getName() returns the name given at
// construction.
template
struct NamedServiceInterceptor : public ServiceInterceptor {
  explicit NamedServiceInterceptor(std::string name) : name_(std::move(name)) {}

  std::string getName() const override { return name_; }

 private:
  std::string name_;
};

// Counts every hook invocation. Connection and request states are both the
// int 1. onConnectionClosedCount and onResponseCount are incremented by the
// state value handed back, so they grow only if the framework returns the
// stored state.
struct ServiceInterceptorCountWithRequestState
    : public NamedServiceInterceptor {
 public:
  using ConnectionState = int;
  using RequestState = int;
  using NamedServiceInterceptor::NamedServiceInterceptor;

  std::optional onConnection(
      ConnectionInfo) noexcept override {
    onConnectionCount++;
    return 1;
  }
  void onConnectionClosed(
      ConnectionState* connectionState, ConnectionInfo) noexcept override {
    onConnectionClosedCount += *connectionState;
  }

  folly::coro::Task> onRequest(
      ConnectionState*, RequestInfo) override {
    onRequestCount++;
    co_return 1;
  }

  folly::coro::Task onResponse(
      RequestState* requestState, ConnectionState*, ResponseInfo) override {
    onResponseCount += *requestState;
    co_return;
  }

  int onConnectionCount = 0;
  int onConnectionClosedCount = 0;
  int onRequestCount = 0;
  int onResponseCount = 0;
};

// Throws a std::runtime_error from every onRequest. onResponse only counts.
struct ServiceInterceptorThrowOnRequest
    : public NamedServiceInterceptor {
 public:
  using NamedServiceInterceptor::NamedServiceInterceptor;

  folly::coro::Task> onRequest(
      folly::Unit*, RequestInfo) override {
    onRequestCount++;
    throw std::runtime_error(
        "Exception from ServiceInterceptorThrowOnRequest::onRequest");
    // Unreachable, but its presence makes this function a coroutine, so the
    // throw happens inside the coroutine body rather than synchronously at
    // call time.
    co_return std::nullopt;
  }

  folly::coro::Task onResponse(
      folly::Unit*, folly::Unit*, ResponseInfo) override {
    onResponseCount++;
    co_return;
  }

  int onRequestCount = 0;
  int onResponseCount = 0;
};

// Throws a std::runtime_error from every onResponse. onRequest only counts.
struct
ServiceInterceptorThrowOnResponse
    : public NamedServiceInterceptor {
 public:
  using NamedServiceInterceptor::NamedServiceInterceptor;

  folly::coro::Task> onRequest(
      folly::Unit*, RequestInfo) override {
    onRequestCount++;
    co_return std::nullopt;
  }

  folly::coro::Task onResponse(
      folly::Unit*, folly::Unit*, ResponseInfo) override {
    onResponseCount++;
    throw std::runtime_error(
        "Exception from ServiceInterceptorThrowOnResponse::onResponse");
    // Unreachable; keeps this function a coroutine (see
    // ServiceInterceptorThrowOnRequest::onRequest).
    co_return;
  }

  int onRequestCount = 0;
  int onResponseCount = 0;
};

// Counts hooks and, when the response carries an active exception, rethrows
// it from onResponse.
struct ServiceInterceptorRethrowActiveExceptionOnResponse
    : public NamedServiceInterceptor {
 public:
  using NamedServiceInterceptor::NamedServiceInterceptor;

  folly::coro::Task> onRequest(
      folly::Unit*, RequestInfo) override {
    onRequestCount++;
    co_return std::nullopt;
  }

  folly::coro::Task onResponse(
      folly::Unit*, folly::Unit*, ResponseInfo responseInfo) override {
    onResponseCount++;
    // Presumably selects the folly::exception_wrapper alternative of the
    // variant (template argument not visible here); throw_exception()
    // rethrows it.
    if (auto* ex = std::get_if(
            &responseInfo.resultOrActiveException)) {
      ex->throw_exception();
    }
    co_return;
  }

  int onRequestCount = 0;
  int onResponseCount = 0;
};

// Records, for every response, whether it carried a result or an exception,
// together with the runtime type of that result or exception.
struct ServiceInterceptorLogResultTypeOnResponse
    : public NamedServiceInterceptor {
 public:
  using NamedServiceInterceptor::NamedServiceInterceptor;

  folly::coro::Task onResponse(
      folly::Unit*, folly::Unit*, ResponseInfo responseInfo) override {
    results.emplace_back(folly::variant_match(
        responseInfo.resultOrActiveException,
        [](const folly::exception_wrapper& ex) -> Entry {
          return Entry{ResultKind::EXCEPTION, *ex.type()};
        },
        [](const apache::thrift::util::TypeErasedRef& result) -> Entry {
          return Entry{ResultKind::OK, result.type()};
        }));
    co_return;
  }

  enum class ResultKind {
    OK,
    EXCEPTION,
  };
  // One logged response: its kind plus the type of the value or exception.
  struct Entry {
    ResultKind kind;
    std::type_index type;

    bool operator==(const Entry& other) const {
      return std::tie(kind, type) == std::tie(other.kind, other.type);
    }
  };
  // Stream printer, so gtest can show Entry values in failure messages.
  [[maybe_unused]] friend std::ostream& operator<<(
      std::ostream& os, const Entry& entry) {
    auto kindStr = entry.kind == ResultKind::OK ? "OK" : "EXCEPTION";
    return os << "Entry(kind=" << kindStr
              << ", type=" << folly::demangle(entry.type.name()) << ")";
  }

  // Entries in the order onResponse was called.
  std::vector results;
};

} // namespace

// Each co_echo call (a coroutine handler) should trigger onRequest and
// onResponse exactly once.
CO_TEST_P(ServiceInterceptorTestP, BasicTM) {
  auto interceptor =
      std::make_shared("Interceptor1");
  auto runner =
      makeServer(std::make_shared(), [&](ThriftServer& server) {
        server.addModule(std::make_unique(interceptor));
      });

  auto client = makeClient>(*runner);
  co_await client->co_echo("");
  EXPECT_EQ(interceptor->onRequestCount, 1);
  EXPECT_EQ(interceptor->onResponseCount, 1);
  co_await client->co_echo("");
  EXPECT_EQ(interceptor->onRequestCount, 2);
  EXPECT_EQ(interceptor->onResponseCount, 2);
}

// Same as BasicTM, but through the async_eb_ (callback) handler echo_eb.
CO_TEST_P(ServiceInterceptorTestP, BasicEB) {
  auto interceptor =
      std::make_shared("Interceptor1");
  auto runner =
      makeServer(std::make_shared(), [&](ThriftServer& server) {
        server.addModule(std::make_unique(interceptor));
      });

  auto client = makeClient>(*runner);
  co_await client->co_echo_eb("");
  EXPECT_EQ(interceptor->onRequestCount, 1);
  EXPECT_EQ(interceptor->onResponseCount, 1);
  co_await client->co_echo_eb("");
  EXPECT_EQ(interceptor->onRequestCount, 2);
  EXPECT_EQ(interceptor->onResponseCount, 2);
}

// void return calls HandlerCallback::done() instead of
// HandlerCallback::result()
CO_TEST_P(ServiceInterceptorTestP, BasicVoidReturn) {
  auto interceptor1 =
      std::make_shared("Interceptor1");
  auto interceptor2 =
      std::make_shared("Interceptor2");
  auto runner =
      makeServer(std::make_shared(), [&](ThriftServer& server) {
        server.addModule(std::make_unique(
            InterceptorList{interceptor1, interceptor2}));
      });

  // HTTP2 does not support onConnection and onConnectionClosed because
  // ThriftProcessor creates & disposes the Cpp2ConnContext every request, not
  // connection.
  const auto valueIfNotHttp2 = [&](int value) -> int {
    return transportType() == TransportType::HTTP2 ? 0 : value;
  };

  {
    // Presumably a sticky client so both calls share one connection; the
    // onConnectionCount == 1 checks after the second call rely on that.
    auto client = runner->newStickyClient<
        apache::thrift::Client>(
        nullptr /* callbackExecutor */, channelFor(transportType()));
    co_await client->co_noop();
    for (auto& interceptor : {interceptor1, interceptor2}) {
      EXPECT_EQ(interceptor->onRequestCount, 1);
      EXPECT_EQ(interceptor->onResponseCount, 1);
      EXPECT_EQ(interceptor->onConnectionCount, valueIfNotHttp2(1));
      EXPECT_EQ(interceptor->onConnectionClosedCount, valueIfNotHttp2(0));
    }
    co_await client->co_noop();
    for (auto& interceptor : {interceptor1, interceptor2}) {
      EXPECT_EQ(interceptor->onRequestCount, 2);
      EXPECT_EQ(interceptor->onResponseCount, 2);
      EXPECT_EQ(interceptor->onConnectionCount, valueIfNotHttp2(1));
      EXPECT_EQ(interceptor->onConnectionClosedCount, valueIfNotHttp2(0));
    }
  }
  // Stopping the server is expected to close the connection, which fires
  // onConnectionClosed exactly once (except on HTTP2).
  runner.reset();
  for (auto& interceptor : {interceptor1, interceptor2}) {
    EXPECT_EQ(interceptor->onConnectionCount, valueIfNotHttp2(1));
    EXPECT_EQ(interceptor->onConnectionClosedCount, valueIfNotHttp2(1));
  }
}

// A move-only RequestState with side-effecting ctor/dtor: each state built in
// onRequest must be destroyed exactly once.
CO_TEST_P(ServiceInterceptorTestP, NonTrivialRequestState) {
  struct Counts {
    int construct = 0;
    int destruct = 0;
  } counts;

  // Moved-from instances hold a null counts_ and so do not count as a
  // destruction.
  struct RequestState {
    explicit RequestState(Counts& counts) : counts_(&counts) {
      counts_->construct++;
    }
    RequestState(RequestState&& other) noexcept
        : counts_(std::exchange(other.counts_, nullptr)) {}
    // NOTE(review): overwrites counts_ without recording the destruction of a
    // previously held state, and leaves counts_ null on self-move.
    RequestState& operator=(RequestState&& other) noexcept {
      counts_ = std::exchange(other.counts_, nullptr);
      return *this;
    }
    ~RequestState() {
      if (counts_) {
        counts_->destruct++;
      }
    }

   private:
    Counts* counts_;
  };

  struct ServiceInterceptorNonTrivialRequestState
      : public NamedServiceInterceptor {
   public:
    ServiceInterceptorNonTrivialRequestState(std::string name, Counts& counts)
        : NamedServiceInterceptor(std::move(name)), counts_(counts) {}

    folly::coro::Task> onRequest(
        folly::Unit*, RequestInfo) override {
      co_return RequestState(counts_);
    }

    folly::coro::Task onResponse(
        RequestState*, folly::Unit*, ResponseInfo) override {
      co_return;
    }

   private:
    Counts& counts_;
  };
  auto interceptor1 = std::make_shared(
      "Interceptor1", counts);
  auto interceptor2 = std::make_shared(
      "Interceptor2", counts);
  auto runner =
      makeServer(std::make_shared(), [&](ThriftServer& server) {
        server.addModule(std::make_unique(
            InterceptorList{interceptor1, interceptor2}));
      });

  auto client = makeClient>(*runner);
  co_await client->co_noop();
  // One state per interceptor for the single request, all destroyed.
  EXPECT_EQ(counts.construct, 2);
  EXPECT_EQ(counts.destruct, 2);
}

// onRequest runs in registration order and onResponse in reverse order.
CO_TEST_P(ServiceInterceptorTestP, IterationOrder) {
  int seq = 0;
  // Stamps each hook call with the next value of a shared counter.
  class ServiceInterceptorRecordingExecutionSequence
      : public NamedServiceInterceptor {
   public:
    using RequestState = folly::Unit;

    explicit ServiceInterceptorRecordingExecutionSequence(
        std::string name, int& seq)
        : NamedServiceInterceptor(std::move(name)), seq_(seq) {}

    folly::coro::Task> onRequest(
        folly::Unit*, RequestInfo) override {
      onRequestSeq = ++seq_;
      co_return std::nullopt;
    }

    folly::coro::Task onResponse(
        RequestState*, folly::Unit*, ResponseInfo) override {
      onResponseSeq = ++seq_;
      co_return;
    }

    int onRequestSeq = 0;
    int onResponseSeq = 0;

   private:
    int& seq_;
  };
  auto interceptor1 =
      std::make_shared(
          "Interceptor1", seq);
  auto interceptor2 =
      std::make_shared(
          "Interceptor2", seq);
  auto runner =
      makeServer(std::make_shared(), [&](ThriftServer& server) {
        server.addModule(std::make_unique(
            InterceptorList{interceptor1, interceptor2}));
      });

  auto client = makeClient>(*runner);
  co_await client->co_noop();
  EXPECT_EQ(interceptor1->onRequestSeq, 1);
  EXPECT_EQ(interceptor2->onRequestSeq, 2);
  EXPECT_EQ(interceptor2->onResponseSeq, 3);
  EXPECT_EQ(interceptor1->onResponseSeq, 4);
}

// By the time makeServer returns, each interceptor's co_onStartServing should
// have run exactly once.
TEST_P(ServiceInterceptorTestP, OnStartServing) {
  struct ServiceInterceptorCountOnStartServing
      : public NamedServiceInterceptor {
   public:
    using RequestState = folly::Unit;
    using NamedServiceInterceptor::NamedServiceInterceptor;

    folly::coro::Task co_onStartServing(InitParams) override {
      onStartServingCount++;
      co_return;
    }

    folly::coro::Task> onRequest(
        folly::Unit*, RequestInfo) override {
      co_return folly::unit;
    }

    folly::coro::Task onResponse(
        RequestState*,
        folly::Unit*, ResponseInfo) override {
      co_return;
    }

    int onStartServingCount = 0;
  };

  auto interceptor1 =
      std::make_shared("Interceptor1");
  auto interceptor2 =
      std::make_shared("Interceptor2");
  auto runner =
      makeServer(std::make_shared(), [&](ThriftServer& server) {
        server.addModule(std::make_unique(
            InterceptorList{interceptor1, interceptor2}));
      });

  for (auto& interceptor : {interceptor1, interceptor2}) {
    EXPECT_EQ(interceptor->onStartServingCount, 1);
  }
}

// Two interceptors with the same name in one module make server setup throw
// std::logic_error naming the qualified duplicate.
TEST_P(ServiceInterceptorTestP, DuplicateNameThrows) {
  auto interceptor1 =
      std::make_shared("Duplicate");
  auto interceptor2 =
      std::make_shared("Duplicate");
  EXPECT_THROW(
      {
        try {
          makeServer(
              std::make_shared(), [&](ThriftServer& server) {
                server.addModule(std::make_unique(
                    InterceptorList{interceptor1, interceptor2}));
              });
        } catch (const std::logic_error& ex) {
          EXPECT_THAT(ex.what(), HasSubstr("TestModule.Duplicate"));
          throw;
        }
      },
      std::logic_error);
}

// onRequest failures reach the client as a TApplicationException. Per the
// assertions, its message names Interceptor1 and Interceptor3 but not
// Interceptor2. Every interceptor's onRequest and onResponse still run once.
CO_TEST_P(ServiceInterceptorTestP, OnRequestException) {
  auto interceptor1 =
      std::make_shared("Interceptor1");
  auto interceptor2 =
      std::make_shared("Interceptor2");
  auto interceptor3 =
      std::make_shared("Interceptor3");
  auto runner =
      makeServer(std::make_shared(), [&](ThriftServer& server) {
        server.addModule(std::make_unique(
            InterceptorList{interceptor1, interceptor2, interceptor3}));
      });

  auto client = makeClient>(*runner);
  EXPECT_THROW(
      {
        try {
          co_await client->co_echo("");
        } catch (const apache::thrift::TApplicationException& ex) {
          EXPECT_THAT(
              std::string(ex.what()),
              HasSubstr("ServiceInterceptor::onRequest threw exceptions"));
          EXPECT_THAT(
              std::string(ex.what()),
              HasSubstr(
                  "Exception from ServiceInterceptorThrowOnRequest::onRequest"));
          EXPECT_THAT(
              std::string(ex.what()), HasSubstr("[TestModule.Interceptor1]"));
          EXPECT_THAT(
              std::string(ex.what()),
              Not(HasSubstr("[TestModule.Interceptor2]")));
          EXPECT_THAT(
              std::string(ex.what()), HasSubstr("[TestModule.Interceptor3]"));
          throw;
        }
      },
      apache::thrift::TApplicationException);
  EXPECT_EQ(interceptor1->onRequestCount, 1);
  EXPECT_EQ(interceptor1->onResponseCount, 1);
  EXPECT_EQ(interceptor2->onRequestCount, 1);
  EXPECT_EQ(interceptor2->onResponseCount, 1);
  EXPECT_EQ(interceptor3->onRequestCount, 1);
  EXPECT_EQ(interceptor3->onResponseCount, 1);
}

// onRequest failure through the async_eb_ handler path. onResponse is still
// called once.
CO_TEST_P(ServiceInterceptorTestP, OnRequestExceptionEB) {
  auto interceptor =
      std::make_shared("Interceptor1");
  auto runner =
      makeServer(std::make_shared(), [&](ThriftServer& server) {
        server.addModule(std::make_unique(interceptor));
      });

  auto client = makeClient>(*runner);
  EXPECT_THROW(
      {
        try {
          co_await client->co_echo_eb("");
        } catch (const apache::thrift::TApplicationException& ex) {
          EXPECT_THAT(
              std::string(ex.what()),
              HasSubstr("ServiceInterceptor::onRequest threw exceptions"));
          EXPECT_THAT(
              std::string(ex.what()),
              HasSubstr(
                  "Exception from ServiceInterceptorThrowOnRequest::onRequest"));
          throw;
        }
      },
      apache::thrift::TApplicationException);
  EXPECT_EQ(interceptor->onRequestCount, 1);
  EXPECT_EQ(interceptor->onResponseCount, 1);
}

// onResponse failures reach the client as a TApplicationException. Per the
// assertions, its message names Interceptor1 and Interceptor3 but not
// Interceptor2. Every interceptor's onResponse still runs once.
CO_TEST_P(ServiceInterceptorTestP, OnResponseException) {
  auto interceptor1 =
      std::make_shared("Interceptor1");
  auto interceptor2 =
      std::make_shared("Interceptor2");
  auto interceptor3 =
      std::make_shared("Interceptor3");
  auto runner =
      makeServer(std::make_shared(), [&](ThriftServer& server) {
        server.addModule(std::make_unique(
            InterceptorList{interceptor1, interceptor2, interceptor3}));
      });

  auto client = makeClient>(*runner);
  EXPECT_THROW(
      {
        try {
          co_await client->co_echo("");
        } catch (const apache::thrift::TApplicationException& ex) {
          EXPECT_THAT(
              std::string(ex.what()),
              HasSubstr("ServiceInterceptor::onResponse threw exceptions"));
          EXPECT_THAT(
              std::string(ex.what()),
              HasSubstr(
                  "Exception from ServiceInterceptorThrowOnResponse::onResponse"));
          EXPECT_THAT(
              std::string(ex.what()), HasSubstr("[TestModule.Interceptor1]"));
          EXPECT_THAT(
              std::string(ex.what()),
              Not(HasSubstr("[TestModule.Interceptor2]")));
          EXPECT_THAT(
              std::string(ex.what()), HasSubstr("[TestModule.Interceptor3]"));
          throw;
        }
      },
      apache::thrift::TApplicationException);
  EXPECT_EQ(interceptor1->onRequestCount, 1);
  EXPECT_EQ(interceptor1->onResponseCount, 1);
  EXPECT_EQ(interceptor2->onRequestCount, 1);
  EXPECT_EQ(interceptor2->onResponseCount, 1);
  EXPECT_EQ(interceptor3->onRequestCount, 1);
  EXPECT_EQ(interceptor3->onResponseCount, 1);
}

// onResponse failure through the async_eb_ handler path.
CO_TEST_P(ServiceInterceptorTestP, OnResponseExceptionEB) {
  auto interceptor1 =
      std::make_shared("Interceptor1");
  auto interceptor2 =
      std::make_shared("Interceptor2");
  auto runner =
      makeServer(std::make_shared(), [&](ThriftServer& server) {
        server.addModule(std::make_unique(
            InterceptorList{interceptor1, interceptor2}));
      });

  auto client = makeClient>(*runner);
  EXPECT_THROW(
      {
        try {
          co_await client->co_echo_eb("");
        } catch (const apache::thrift::TApplicationException& ex) {
          EXPECT_THAT(
              std::string(ex.what()),
              HasSubstr("ServiceInterceptor::onResponse threw exceptions"));
          EXPECT_THAT(
              std::string(ex.what()),
              HasSubstr(
                  "Exception from ServiceInterceptorThrowOnResponse::onResponse"));
          throw;
        }
      },
      apache::thrift::TApplicationException);
  EXPECT_EQ(interceptor1->onRequestCount, 1);
  EXPECT_EQ(interceptor1->onResponseCount, 1);
  EXPECT_EQ(interceptor2->onRequestCount, 1);
  EXPECT_EQ(interceptor2->onResponseCount, 1);
}

// If the handler takes the raw HandlerCallback out of its HandlerCallbackPtr
// via unsafeRelease(), onRequest still runs but onResponse is never invoked.
CO_TEST_P(
    ServiceInterceptorTestP,
    OnResponseBypassedForUnsafeReleasedCallback) {
  auto interceptor =
      std::make_shared("Interceptor1");

  struct TestHandlerUnsafeReleaseCallback
      : apache::thrift::ServiceHandler {
    void async_tm_echo(
        apache::thrift::HandlerCallbackPtr> callback,
        std::unique_ptr str) override {
      std::unique_ptr<
          apache::thrift::HandlerCallback>>
          releasedCallback{callback.unsafeRelease()};
      releasedCallback->result(*str);
    }
  };
  auto runner = makeServer(
      std::make_shared(),
      [&](ThriftServer& server) {
        server.addModule(std::make_unique(interceptor));
      });

  auto client = makeClient>(*runner);
  co_await client->co_echo("");
  EXPECT_EQ(interceptor->onRequestCount, 1);
  EXPECT_EQ(interceptor->onResponseCount, 0);
  co_await client->co_echo("");
  EXPECT_EQ(interceptor->onRequestCount, 2);
  EXPECT_EQ(interceptor->onResponseCount, 0);
}

// The handler throws (echo("throw")) and onResponse then throws too. The
// client sees the onResponse failure, and the handler's original message is
// still included in the error text.
CO_TEST_P(
    ServiceInterceptorTestP,
    OnResponseExceptionSwallowsApplicationException) {
  auto interceptor =
      std::make_shared(
          "Interceptor1");
  auto runner =
      makeServer(std::make_shared(), [&](ThriftServer& server) {
        server.addModule(std::make_unique(interceptor));
      });

  auto client = makeClient>(*runner);
  EXPECT_THROW(
      {
        try {
          co_await client->co_echo("throw");
        } catch (const apache::thrift::TApplicationException& ex) {
          EXPECT_THAT(
              std::string(ex.what()),
              HasSubstr("ServiceInterceptor::onResponse threw exceptions"));
          EXPECT_THAT(
              std::string(ex.what()), HasSubstr("[TestModule.Interceptor1]"));
          EXPECT_THAT(std::string(ex.what()), HasSubstr("You asked for it!"));
          throw;
        }
      },
      apache::thrift::TApplicationException);
  EXPECT_EQ(interceptor->onRequestCount, 1);
  EXPECT_EQ(interceptor->onResponseCount, 1);
}

// When onRequest fails (Interceptor1) and onResponse also fails
// (Interceptor2), the client's error message reports both failures.
CO_TEST_P(
    ServiceInterceptorTestP,
    OnResponseExceptionSwallowsOnRequestException) {
  auto interceptor1 =
      std::make_shared("Interceptor1");
  auto interceptor2 =
      std::make_shared(
          "Interceptor2");
  auto runner =
      makeServer(std::make_shared(), [&](ThriftServer& server) {
        server.addModule(std::make_unique(
            InterceptorList{interceptor1, interceptor2}));
      });

  auto client = makeClient>(*runner);
  EXPECT_THROW(
      {
        try {
          co_await client->co_noop();
        } catch (const apache::thrift::TApplicationException& ex) {
          EXPECT_THAT(
              std::string(ex.what()),
              HasSubstr("ServiceInterceptor::onRequest threw exceptions"));
          EXPECT_THAT(
              std::string(ex.what()), HasSubstr("[TestModule.Interceptor1]"));
          EXPECT_THAT(
              std::string(ex.what()),
              HasSubstr("ServiceInterceptor::onResponse threw exceptions"));
          EXPECT_THAT(
              std::string(ex.what()), HasSubstr("[TestModule.Interceptor2]"));
          throw;
        }
      },
      apache::thrift::TApplicationException);
  EXPECT_EQ(interceptor1->onRequestCount, 1);
  EXPECT_EQ(interceptor1->onResponseCount, 1);
  EXPECT_EQ(interceptor2->onRequestCount, 1);
  EXPECT_EQ(interceptor2->onResponseCount, 1);
}
// Interaction creation and calls on the interaction each count as one
// request/response. Destroying the interaction adds no further interceptor
// calls. Passes trivially on non-Rocket transports.
CO_TEST_P(ServiceInterceptorTestP, BasicInteraction) {
  if (transportType() != TransportType::ROCKET) {
    // only rocket supports interactions
    co_return;
  }
  auto interceptor1 =
      std::make_shared("Interceptor1");
  auto interceptor2 =
      std::make_shared("Interceptor2");
  auto runner =
      makeServer(std::make_shared(), [&](ThriftServer& server) {
        server.addModule(std::make_unique(
            InterceptorList{interceptor1, interceptor2}));
      });

  auto client = makeClient>(*runner);
  {
    auto interaction = co_await client->co_createInteraction();
    for (auto& interceptor : {interceptor1, interceptor2}) {
      EXPECT_EQ(interceptor->onRequestCount, 1);
      EXPECT_EQ(interceptor->onResponseCount, 1);
    }
    co_await interaction.co_echo("");
    for (auto& interceptor : {interceptor1, interceptor2}) {
      EXPECT_EQ(interceptor->onRequestCount, 2);
      EXPECT_EQ(interceptor->onResponseCount, 2);
    }
    co_await client->co_echo("");
    for (auto& interceptor : {interceptor1, interceptor2}) {
      EXPECT_EQ(interceptor->onRequestCount, 3);
      EXPECT_EQ(interceptor->onResponseCount, 3);
    }
  }
  // Counts are unchanged after the interaction goes out of scope.
  for (auto& interceptor : {interceptor1, interceptor2}) {
    EXPECT_EQ(interceptor->onRequestCount, 3);
    EXPECT_EQ(interceptor->onResponseCount, 3);
  }
}

// A streaming call counts as a single request/response, however many items
// are consumed and whether or not the stream has been closed. Passes
// trivially on non-Rocket transports.
CO_TEST_P(ServiceInterceptorTestP, BasicStream) {
  if (transportType() != TransportType::ROCKET) {
    // only rocket supports streaming
    co_return;
  }
  auto interceptor1 =
      std::make_shared("Interceptor1");
  auto interceptor2 =
      std::make_shared("Interceptor2");
  auto runner =
      makeServer(std::make_shared(), [&](ThriftServer& server) {
        server.addModule(std::make_unique(
            InterceptorList{interceptor1, interceptor2}));
      });

  auto client = makeClient>(*runner);
  {
    auto stream = (co_await client->co_iota(1)).toAsyncGenerator();
    EXPECT_EQ((co_await stream.next()).value(), 1);
    EXPECT_EQ((co_await stream.next()).value(), 2);
    for (auto& interceptor : {interceptor1, interceptor2}) {
      EXPECT_EQ(interceptor->onRequestCount, 1);
      EXPECT_EQ(interceptor->onResponseCount, 1);
    }
    // close stream
  }
  for (auto& interceptor : {interceptor1, interceptor2}) {
    EXPECT_EQ(interceptor->onRequestCount, 1);
    EXPECT_EQ(interceptor->onResponseCount, 1);
  }
}

// onRequest can read the typed request arguments by index. Per the
// assertions:
// - reading an argument as the wrong type throws std::bad_cast;
// - an index past the last argument yields no value.
CO_TEST_P(ServiceInterceptorTestP, RequestArguments) {
  struct ServiceInterceptorWithRequestArguments
      : public NamedServiceInterceptor {
   public:
    using ConnectionState = folly::Unit;
    using RequestState = folly::Unit;
    using NamedServiceInterceptor::NamedServiceInterceptor;

    folly::coro::Task> onRequest(
        ConnectionState*, RequestInfo requestInfo) override {
      argsCount = requestInfo.arguments.count();
      arg1 = requestInfo.arguments.get(0)->value();
      arg2 = requestInfo.arguments.get(1)->value();
      arg3 = requestInfo.arguments.get(2)->value();
      // NOTE(review): the value type requested here is not visible in this
      // copy; presumably it deliberately mismatches argument 2's type.
      EXPECT_THROW(
          requestInfo.arguments.get(2)->value(), std::bad_cast);
      EXPECT_FALSE(requestInfo.arguments.get(3).has_value());
      co_return std::nullopt;
    }

    // Arguments captured from the most recent request.
    std::size_t argsCount = 0;
    std::int32_t arg1 = 0;
    std::string arg2;
    test::RequestArgsStruct arg3;
  };

  auto interceptor =
      std::make_shared("Interceptor1");
  auto runner =
      makeServer(std::make_shared(), [&](ThriftServer& server) {
        server.addModule(std::make_unique(interceptor));
      });

  auto client = makeClient>(*runner);
  test::RequestArgsStruct requestArgs;
  requestArgs.foo() = 1;
  requestArgs.bar() = "hello";
  auto result = co_await client->co_requestArgs(1, "hello", requestArgs);
  EXPECT_EQ(interceptor->argsCount, 3);
  EXPECT_EQ(interceptor->arg1, 1);
  EXPECT_EQ(interceptor->arg2, "hello");
  EXPECT_EQ(interceptor->arg3, requestArgs);
}

// onRequest sees the service name and method name of each call. Interaction
// methods are reported as "<Interaction>.<method>".
CO_TEST_P(ServiceInterceptorTestP, ServiceAndMethodNames) {
  struct ServiceInterceptorCheckingServiceAndMethodNames
      : public NamedServiceInterceptor {
   public:
    using ConnectionState = folly::Unit;
    using RequestState = folly::Unit;

    ServiceInterceptorCheckingServiceAndMethodNames()
        : NamedServiceInterceptor("SomeName") {}

    // Records (serviceName, methodName); a null name is recorded as "".
    folly::coro::Task> onRequest(
        ConnectionState*, RequestInfo requestInfo) override {
      names.emplace_back(
          requestInfo.serviceName ? std::string(requestInfo.serviceName) : "",
          requestInfo.methodName ? std::string(requestInfo.methodName) : "");
      co_return std::nullopt;
    }

    using Entry = std::pair;
    std::vector names;
  };

  auto interceptor =
      std::make_shared();
  auto runner =
      makeServer(std::make_shared(), [&](ThriftServer& server) {
        server.addModule(std::make_unique(interceptor));
      });

  auto client = makeClient>(*runner);
  co_await client->co_echo("");
  co_await client->co_noop();

  std::vector expectedNames = {
      {"ServiceInterceptorTest", "echo"},
      {"ServiceInterceptorTest", "noop"},
  };
  if (transportType() == TransportType::ROCKET) {
    // only rocket supports interactions
    auto interaction = co_await client->co_createInteraction();
    co_await interaction.co_echo("");
    expectedNames.emplace_back("ServiceInterceptorTest", "createInteraction");
    expectedNames.emplace_back(
        "ServiceInterceptorTest", "SampleInteraction.echo");
  }
  EXPECT_THAT(interceptor->names, ElementsAreArray(expectedNames));
}

// ResponseInfo::resultOrActiveException carries the handler's result type on
// success and the thrown exception's type on failure, across the supported
// RPC kinds. Passes trivially on non-Rocket transports.
CO_TEST_P(ServiceInterceptorTestP, ResultOrActiveExceptionTypesAreCorrect) {
  if (transportType() != TransportType::ROCKET) {
    // only rocket supports all transport features being tested here
    co_return;
  }
  auto interceptor =
      std::make_shared(
          "Interceptor1");
  auto runner =
      makeServer(std::make_shared(), [&](ThriftServer& server) {
        server.addModule(std::make_unique(interceptor));
      });

  auto client = makeClient>(*runner);
  co_await client->co_echo("");
  co_await client->co_noop();
  {
    auto stream = (co_await client->co_iota(1)).toAsyncGenerator();
    EXPECT_EQ((co_await stream.next()).value(), 1);
    EXPECT_EQ((co_await stream.next()).value(), 2);
    // close stream
  }
  {
    auto interaction = co_await client->co_createInteraction();
    co_await interaction.co_echo("");
    // terminate interaction
  }
  {
    // Not awaited; the expected results below contain no separate entry for
    // creating this interaction, only for its echo call.
    auto interaction = client->createSampleInteraction2();
    co_await interaction.co_echo("");
    // terminate interaction
  }
  EXPECT_THROW(
      co_await client->co_echo("throw"),
      apache::thrift::TApplicationException);
  co_await client->co_echo_eb("");
  {
    test::RequestArgsStruct requestArgs;
    requestArgs.foo() = 1;
    requestArgs.bar() = "hello";
    co_await client->co_echoStruct(requestArgs);
  }

  using ResultKind = ServiceInterceptorLogResultTypeOnResponse::ResultKind;
  // One entry per onResponse, in call order (ElementsAreArray checks order).
  std::vector expectedResults = {
      // echo
      {ResultKind::OK, typeid(std::string)},
      // noop
      {ResultKind::OK, typeid(folly::Unit)},
      // iota
      {ResultKind::OK, typeid(apache::thrift::ServerStream)},
      // createInteraction
      {ResultKind::OK, typeid(folly::Unit)},
      // SampleInteraction.echo
      {ResultKind::OK, typeid(std::string)},
      // SampleInteraction2.echo
      {ResultKind::OK, typeid(std::string)},
      // echo("throw")
      {ResultKind::EXCEPTION, typeid(std::runtime_error)},
      // echo_eb
      {ResultKind::OK, typeid(std::string)},
      // echoStruct
      {ResultKind::OK, typeid(test::ResponseArgsStruct)},
  };
  EXPECT_THAT(interceptor->results, ElementsAreArray(expectedResults));
}

// Runs every ServiceInterceptorTestP test once per transport.
INSTANTIATE_TEST_SUITE_P(
    ServiceInterceptorTestP,
    ServiceInterceptorTestP,
    ::testing::Values(
        TransportType::HEADER, TransportType::ROCKET, TransportType::HTTP2));