/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include #include #include #include #include #include #include #include #include using namespace apache::thrift; using namespace ::testing; namespace { using TransportType = Cpp2ConnContext::TransportType; std::unique_ptr createHTTP2RoutingHandler( ThriftServer& server) { auto h2Options = std::make_unique(); h2Options->threads = static_cast(server.getNumIOWorkerThreads()); h2Options->idleTimeout = server.getIdleTimeout(); h2Options->shutdownOn = {SIGINT, SIGTERM}; return std::make_unique( std::move(h2Options), server.getThriftProcessor(), server); } enum class ClientCallbackKind { CORO, SYNC }; struct TestHandler : apache::thrift::ServiceHandler { folly::coro::Task co_noop() override { co_return; } folly::coro::Task> co_echo( std::unique_ptr str) override { if (*str == "throw") { throw std::runtime_error("You asked for it!"); } co_return std::move(str); } }; class ClientInterface { public: explicit ClientInterface( std::unique_ptr> client) : client_(std::move(client)) {} virtual ~ClientInterface() = default; virtual folly::coro::Task echo(std::string str) = 0; virtual folly::coro::Task noop() = 0; protected: std::unique_ptr> client_; }; class CoroClientInterface : public ClientInterface { public: using ClientInterface::ClientInterface; folly::coro::Task echo(std::string str) override { co_return co_await client_->co_echo(std::move(str)); } folly::coro::Task noop() override { co_await client_->co_noop(); co_return; } }; class SyncClientInterface : public ClientInterface { public: using ClientInterface::ClientInterface; folly::coro::Task echo(std::string str) override { std::string ret; client_->sync_echo(ret, std::move(str)); co_return ret; } folly::coro::Task noop() override { client_->sync_noop(); co_return; } }; class ClientInterceptorTestP : public ::testing::TestWithParam< std::tuple> { public: TransportType transportType() const { return std::get<0>(GetParam()); } ClientCallbackKind clientCallbackType() const { return std::get<1>(GetParam()); } private: void SetUp() override { runner = std::make_unique( std::make_shared()); if (transportType() == TransportType::HTTP2) { auto& thriftServer = runner->getThriftServer(); thriftServer.addRoutingHandler(createHTTP2RoutingHandler(thriftServer)); } } std::unique_ptr runner; ScopedServerInterfaceThread::MakeChannelFunc channelFor( TransportType transportType) { return [transportType]( folly::AsyncSocket::UniquePtr socket) -> RequestChannel::Ptr { switch (transportType) { case TransportType::HEADER: return HeaderClientChannel::newChannel(std::move(socket)); case TransportType::ROCKET: return RocketClientChannel::newChannel(std::move(socket)); case TransportType::HTTP2: { auto channel = HTTPClientChannel::newHTTP2Channel(std::move(socket)); channel->setProtocolId(protocol::T_COMPACT_PROTOCOL); return channel; } default: throw std::logic_error{"Unreachable!"}; } }; } std::shared_ptr makeChannel() { return runner ->newClient>( nullptr, channelFor(transportType())) ->getChannelShared(); } public: std::unique_ptr makeClient( std::shared_ptr>> interceptors) { auto client = std::make_unique>( makeChannel(), std::move(interceptors)); switch (clientCallbackType()) { case ClientCallbackKind::CORO: return std::make_unique(std::move(client)); case ClientCallbackKind::SYNC: return std::make_unique(std::move(client)); default: throw std::logic_error{"Unknown client callback type!"}; } } }; template std::shared_ptr>> makeInterceptorsList(InterceptorPtrs&&... interceptors) { auto list = std::make_shared>>(); (list->emplace_back(std::forward(interceptors)), ...); return list; } template struct NamedClientInterceptor : public ClientInterceptor { explicit NamedClientInterceptor(std::string name) : name_(std::move(name)) {} std::string getName() const override { return name_; } private: std::string name_; }; class ClientInterceptorCountWithRequestState : public NamedClientInterceptor { public: using RequestState = int; using NamedClientInterceptor::NamedClientInterceptor; folly::coro::Task> onRequest( RequestInfo) override { onRequestCount++; co_return 1; } folly::coro::Task onResponse( RequestState* requestState, ResponseInfo) override { onResponseCount += *requestState; co_return; } int onRequestCount = 0; int onResponseCount = 0; }; class ClientInterceptorThatThrowsOnRequest : public ClientInterceptorCountWithRequestState { public: using ClientInterceptorCountWithRequestState:: ClientInterceptorCountWithRequestState; folly::coro::Task> onRequest( RequestInfo requestInfo) override { co_await ClientInterceptorCountWithRequestState::onRequest( std::move(requestInfo)); throw std::runtime_error("Oh no!"); } }; class ClientInterceptorThatThrowsOnResponse : public ClientInterceptorCountWithRequestState { public: using ClientInterceptorCountWithRequestState:: ClientInterceptorCountWithRequestState; folly::coro::Task onResponse( RequestState* requestState, ResponseInfo responseInfo) override { co_await ClientInterceptorCountWithRequestState::onResponse( requestState, std::move(responseInfo)); throw std::runtime_error("Oh no!"); } }; } // namespace CO_TEST_P(ClientInterceptorTestP, Basic) { auto interceptor = std::make_shared("Interceptor1"); auto client = makeClient(makeInterceptorsList(interceptor)); co_await client->echo("foo"); EXPECT_EQ(interceptor->onRequestCount, 1); EXPECT_EQ(interceptor->onResponseCount, 1); co_await client->noop(); EXPECT_EQ(interceptor->onRequestCount, 2); EXPECT_EQ(interceptor->onResponseCount, 2); } CO_TEST_P(ClientInterceptorTestP, OnRequestException) { auto interceptor1 = std::make_shared("Interceptor1"); auto interceptor2 = std::make_shared("Interceptor2"); auto interceptor3 = std::make_shared("Interceptor3"); auto client = makeClient( makeInterceptorsList(interceptor1, interceptor2, interceptor3)); EXPECT_THROW( { try { co_await client->noop(); } catch (const ClientInterceptorException& ex) { EXPECT_EQ(ex.causes().size(), 2); EXPECT_EQ(ex.causes()[0].sourceInterceptorName, "Interceptor1"); EXPECT_EQ(ex.causes()[1].sourceInterceptorName, "Interceptor3"); EXPECT_THAT(ex.what(), HasSubstr("[Interceptor1]")); EXPECT_THAT(ex.what(), Not(HasSubstr("Interceptor2"))); EXPECT_THAT(ex.what(), HasSubstr("[Interceptor3]")); EXPECT_THAT(ex.what(), HasSubstr("ClientInterceptor::onRequest")); throw; } }, ClientInterceptorException); EXPECT_EQ(interceptor1->onRequestCount, 1); EXPECT_EQ(interceptor2->onRequestCount, 1); EXPECT_EQ(interceptor3->onRequestCount, 1); EXPECT_EQ(interceptor1->onResponseCount, 0); EXPECT_EQ(interceptor2->onResponseCount, 0); EXPECT_EQ(interceptor3->onResponseCount, 0); } CO_TEST_P(ClientInterceptorTestP, IterationOrder) { int seq = 1; class ClientInterceptorRecordingExecutionSequence : public NamedClientInterceptor { public: using RequestState = folly::Unit; ClientInterceptorRecordingExecutionSequence(std::string name, int& seq) : NamedClientInterceptor(std::move(name)), seq_(seq) {} folly::coro::Task> onRequest( RequestInfo) override { onRequestSeq = seq_++; co_return std::nullopt; } folly::coro::Task onResponse(RequestState*, ResponseInfo) override { onResponseSeq = seq_++; co_return; } int onRequestSeq = 0; int onResponseSeq = 0; private: int& seq_; }; auto interceptor1 = std::make_shared( "Interceptor1", seq); auto interceptor2 = std::make_shared( "Interceptor2", seq); auto client = makeClient(makeInterceptorsList(interceptor1, interceptor2)); co_await client->noop(); EXPECT_EQ(interceptor1->onRequestSeq, 1); EXPECT_EQ(interceptor2->onRequestSeq, 2); EXPECT_EQ(interceptor2->onResponseSeq, 3); EXPECT_EQ(interceptor1->onResponseSeq, 4); } CO_TEST_P(ClientInterceptorTestP, OnResponseException) { auto interceptor1 = std::make_shared("Interceptor1"); auto interceptor2 = std::make_shared("Interceptor2"); auto interceptor3 = std::make_shared("Interceptor3"); auto client = makeClient( makeInterceptorsList(interceptor1, interceptor2, interceptor3)); EXPECT_THROW( { try { co_await client->noop(); } catch (const ClientInterceptorException& ex) { EXPECT_EQ(ex.causes().size(), 2); EXPECT_EQ(ex.causes()[0].sourceInterceptorName, "Interceptor3"); EXPECT_EQ(ex.causes()[1].sourceInterceptorName, "Interceptor1"); EXPECT_THAT(ex.what(), HasSubstr("[Interceptor1]")); EXPECT_THAT(ex.what(), Not(HasSubstr("Interceptor2"))); EXPECT_THAT(ex.what(), HasSubstr("[Interceptor3]")); EXPECT_THAT(ex.what(), HasSubstr("ClientInterceptor::onResponse")); throw; } }, ClientInterceptorException); EXPECT_EQ(interceptor1->onRequestCount, 1); EXPECT_EQ(interceptor2->onRequestCount, 1); EXPECT_EQ(interceptor3->onRequestCount, 1); EXPECT_EQ(interceptor1->onResponseCount, 1); EXPECT_EQ(interceptor2->onResponseCount, 1); EXPECT_EQ(interceptor3->onResponseCount, 1); } CO_TEST_P( ClientInterceptorTestP, OnResponseExceptionSwallowsApplicationException) { auto interceptor = std::make_shared("Interceptor1"); auto client = makeClient(makeInterceptorsList(interceptor)); EXPECT_THROW( { try { co_await client->echo("throw"); } catch (const apache::thrift::ClientInterceptorException& ex) { EXPECT_THAT(std::string(ex.what()), HasSubstr("[Interceptor1]")); throw; } }, apache::thrift::ClientInterceptorException); EXPECT_EQ(interceptor->onRequestCount, 1); EXPECT_EQ(interceptor->onResponseCount, 1); } INSTANTIATE_TEST_SUITE_P( ClientInterceptorTestP, ClientInterceptorTestP, Combine( Values( TransportType::HEADER, TransportType::ROCKET, TransportType::HTTP2), Values(ClientCallbackKind::CORO, ClientCallbackKind::SYNC)), [](const TestParamInfo& info) { const auto transportType = [](TransportType value) -> std::string_view { switch (value) { case TransportType::HEADER: return "HEADER"; case TransportType::ROCKET: return "ROCKET"; case TransportType::HTTP2: return "HTTP2"; default: throw std::logic_error{"Unreachable!"}; } }; const auto clientCallbackType = [](ClientCallbackKind value) -> std::string_view { switch (value) { case ClientCallbackKind::CORO: return "CORO"; case ClientCallbackKind::SYNC: return "SYNC"; default: throw std::logic_error{"Unreachable!"}; } }; return fmt::format( "{}___{}", transportType(std::get<0>(info.param)), clientCallbackType(std::get<1>(info.param))); });