/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include // @manual=//thrift/lib/python/client:omni_client__cython-lib #include #include #include using namespace apache::thrift; using namespace thrift::python::client; using namespace thrift::python::test; const std::string kTestHeaderKey = "headerKey"; const std::string kTestHeaderValue = "headerValue"; /** * A simple Scaffold service that will be used to test the Thrift OmniClient. 
*/
class TestServiceHandler : virtual public apache::thrift::ServiceHandler {
 public:
  /** Returns num1 + num2. */
  int add(int num1, int num2) override { return num1 + num2; }

  /** Oneway endpoint; intentionally does nothing. */
  void oneway() override {}

  /**
   * Copies the value of the request header named `*key` into `value`.
   * The lookup uses `at()`, so a request without that header throws.
   */
  void readHeader(std::string& value, std::unique_ptr key) override {
    value = getRequestContext()->getHeader()->getHeaders().at(*key);
  }

  /**
   * Streams SimpleResponse values "f" through "t" (inclusive), then ends the
   * stream with an error:
   *  - std::logic_error("negative_number_detected") when f < 0 (used to
   *    exercise an exception type that is not an ArithmeticException);
   *  - otherwise ArithmeticException("throw_from_inside_stream").
   *
   * Throws ArithmeticException("my_magic_arithmetic_exception") up front,
   * before any stream is created, when t < f.
   */
  ServerStream nums(int f, int t) override {
    if (t < f) {
      ArithmeticException e;
      e.msg_ref() = "my_magic_arithmetic_exception";
      throw e;
    }
    // f and t are captured by value: the generator body only runs as the
    // client pulls items, i.e. after nums() has returned.
    return folly::coro::co_invoke(
        [f, t]() -> folly::coro::AsyncGenerator {
          for (int i = f; i <= t; ++i) {
            SimpleResponse r;
            r.value() = std::to_string(i);
            co_yield std::move(r);
          }
          if (f < 0) {
            throw std::logic_error("negative_number_detected");
          }
          ArithmeticException e;
          e.msg_ref() = "throw_from_inside_stream";
          throw e;
        });
  }

  /**
   * Responds with the sum of the integers in [f, t] and streams each of them
   * as a SimpleResponse string; the stream then completes normally.
   *
   * Throws ArithmeticException("my_magic_arithmetic_exception") when t < f.
   */
  ResponseAndServerStream sumAndNums(int f, int t) override {
    if (t < f) {
      ArithmeticException e;
      e.msg_ref() = "my_magic_arithmetic_exception";
      throw e;
    }
    return {
        // Closed-form arithmetic series: (first + last) * count / 2.
        (f + t) * (t - f + 1) / 2,
        folly::coro::co_invoke(
            [f, t]() -> folly::coro::AsyncGenerator {
              for (int i = f; i <= t; ++i) {
                SimpleResponse r;
                r.value() = std::to_string(i);
                co_yield std::move(r);
              }
            }),
    };
  }

  /**
   * Sink endpoint: replies with "initial" immediately. The sink consumer
   * ignores whatever the client streams in and finishes with "final".
   */
  ResponseAndSinkConsumer dumbSink(std::unique_ptr request) override {
    (void)request;
    SinkConsumer consumer{
        // Captures nothing on purpose: the consumer runs after dumbSink() has
        // returned, so a by-reference capture of any local here would dangle.
        [](folly::coro::AsyncGenerator /* gen: never read */)
            -> folly::coro::Task {
          SimpleResponse response;
          response.value_ref() = "final";
          co_return response;
        },
        1}; // NOTE(review): presumably the sink buffer size — confirm.
    SimpleResponse response;
    response.value_ref() = "initial";
    return {std::move(response), std::move(consumer)};
  }
};

/**
 * Small event-handler to know when a server is ready.
*/ class ServerReadyEventHandler : public server::TServerEventHandler { public: void preServe(const folly::SocketAddress* address) override { port_ = address->getPort(); baton_.post(); } int32_t waitForPortAssignment() { baton_.wait(); return port_; } private: folly::Baton<> baton_; int32_t port_; }; std::unique_ptr createServer( std::shared_ptr processorFactory, uint16_t& port) { auto server = std::make_unique(); server->setPort(0); server->setInterface(std::move(processorFactory)); server->setNumIOWorkerThreads(1); server->setNumCPUWorkerThreads(1); server->setQueueTimeout(std::chrono::milliseconds(0)); server->setIdleTimeout(std::chrono::milliseconds(0)); server->setTaskExpireTime(std::chrono::milliseconds(0)); server->setStreamExpireTime(std::chrono::milliseconds(0)); auto eventHandler = std::make_shared(); server->setServerEventHandler(eventHandler); server->setup(); // Get the port that the server has bound to port = eventHandler->waitForPortAssignment(); return server; } class OmniClientTest : public ::testing::Test { protected: void SetUp() override { // Startup the test server. server_ = createServer(std::make_shared(), serverPort_); } void TearDown() override { // Stop the server and wait for it to complete. if (server_) { server_->cleanUp(); server_.reset(); } } template void connectToServer( folly::Function(OmniClient&)> callMe) { constexpr protocol::PROTOCOL_TYPES prot = std::is_same_v ? 
protocol::T_BINARY_PROTOCOL : protocol::T_COMPACT_PROTOCOL; folly::coro::blockingWait([this, &callMe]() -> folly::coro::Task { CHECK_GT(serverPort_, 0) << "Check if the server has started already"; folly::Executor* executor = co_await folly::coro::co_current_executor; auto channel = PooledRequestChannel::newChannel( executor, ioThread_, [this](folly::EventBase& evb) { auto chan = apache::thrift::RocketClientChannel::newChannel( folly::AsyncSocket::UniquePtr( new folly::AsyncSocket(&evb, "::1", serverPort_))); chan->setProtocolId(prot); chan->setTimeout(500 /* ms */); return chan; }, prot); OmniClient client(std::move(channel)); co_await callMe(client); }()); } // Send a request and compare the results to the expected value. template void testSendHeaders( const std::string& service, const std::string& function, const Request& req, const std::unordered_map& headers, const Result& expected, const RpcKind rpcKind = RpcKind::SINGLE_REQUEST_SINGLE_RESPONSE, const bool clearEventHandlers = false) { connectToServer([=](OmniClient& client) -> folly::coro::Task { if (clearEventHandlers) { client.clearEventHandlers(); } std::string args = S::template serialize(req); auto data = apache::thrift::MethodMetadata::Data( function, apache::thrift::FunctionQualifier::Unspecified); auto resp = co_await client.semifuture_send( service, function, args, std::move(data), headers, {}, co_await folly::coro::co_current_executor, rpcKind); testContains(std::move(*resp.buf.value()), expected); }); } template void testSend( const std::string& service, const std::string& function, const Request& req, const Result& expected, const RpcKind rpcKind = RpcKind::SINGLE_REQUEST_SINGLE_RESPONSE, const bool clearEventHandlers = false) { testSendHeaders( service, function, req, {}, expected, rpcKind, clearEventHandlers); } // Send a request and compare the results to the expected value. 
template void testOnewaySendHeaders( const std::string& service, const std::string& function, const Request& req, const std::unordered_map& headers = {}) { connectToServer([=](OmniClient& client) -> folly::coro::Task { std::string args = S::template serialize(req); auto data = apache::thrift::MethodMetadata::Data( function, apache::thrift::FunctionQualifier::Unspecified); client.oneway_send(service, function, args, std::move(data), headers, {}); co_return; }); } template void testContains(folly::IOBuf buf, const T& expected) { std::string expectedStr = S::template serialize(expected); std::string result = buf.to(); // Contains instead of equals because of the envelope around the response. EXPECT_THAT(result, testing::HasSubstr(expectedStr)); } template void testSendStream( const std::string& service, const std::string& function, const Request& req, folly::Function(OmniClientResponseWithHeaders&&)> onResponse) { connectToServer( [&](OmniClient& client) mutable -> folly::coro::Task { std::string args = S::template serialize(req); auto data = apache::thrift::MethodMetadata::Data( function, apache::thrift::FunctionQualifier::Unspecified); co_await onResponse(co_await client.semifuture_send( service, function, args, std::move(data), {}, {}, co_await folly::coro::co_current_executor, RpcKind::SINGLE_REQUEST_STREAMING_RESPONSE)); }); } protected: std::unique_ptr server_; folly::EventBase* eb_ = folly::EventBaseManager::get()->getEventBase(); uint16_t serverPort_{0}; std::shared_ptr ioThread_{ std::make_shared()}; }; TEST_F(OmniClientTest, AddTestFailsWithBadEventHandler) { AddRequest request; request.num1_ref() = 1; request.num2_ref() = 41; addHandler(); EXPECT_THROW( { testSend("TestService", "add", request, 42); testSend("TestService", "add", request, 42); }, folly::BadExpectedAccess); } TEST_F(OmniClientTest, AddTestPassesWhenBadEventHandlerIsCleared) { AddRequest request; request.num1_ref() = 1; request.num2_ref() = 41; addHandler(); testSend( "TestService", "add", 
request, 42, RpcKind::SINGLE_REQUEST_SINGLE_RESPONSE, true); testSend( "TestService", "add", request, 42, RpcKind::SINGLE_REQUEST_SINGLE_RESPONSE, true); } TEST_F(OmniClientTest, AddTest) { AddRequest request; request.num1_ref() = 1; request.num2_ref() = 41; testSend("TestService", "add", request, 42); testSend("TestService", "add", request, 42); } TEST_F(OmniClientTest, OnewayTest) { EmptyRequest request; testOnewaySendHeaders("TestService", "oneway", request); testOnewaySendHeaders("TestService", "oneway", request); } TEST_F(OmniClientTest, ReadHeaderTest) { ReadHeaderRequest request; request.key() = kTestHeaderKey; testSendHeaders( "TestService", "readHeader", request, {{kTestHeaderKey, kTestHeaderValue}}, kTestHeaderValue); } TEST_F(OmniClientTest, SinkRequestTest) { EmptyRequest request; SimpleResponse response; response.value_ref() = "initial"; testSend("TestService", "dumbSink", request, response, RpcKind::SINK); } TEST_F(OmniClientTest, StreamNumsTest) { NumsRequest request; request.f() = 2; request.t() = 4; testSendStream( "TestService", "nums", request, [this](OmniClientResponseWithHeaders&& resp) -> folly::coro::Task { auto gen = std::move(*resp.stream).toAsyncGenerator(); for (int i = 2; i <= 4; ++i) { auto val = co_await gen.next(); EXPECT_TRUE(val); testContains(std::move(*val), std::to_string(i)); } auto val = co_await gen.next(); testContains( std::move(*val), std::string{"throw_from_inside_stream"}); }); } TEST_F(OmniClientTest, StreamNumsUndeclaredExceptionTest) { NumsRequest request; request.f() = -1; request.t() = 4; testSendStream( "TestService", "nums", request, [this](OmniClientResponseWithHeaders&& resp) -> folly::coro::Task { auto gen = std::move(*resp.stream).toAsyncGenerator(); for (int i = -1; i <= 4; ++i) { auto val = co_await gen.next(); EXPECT_TRUE(val); testContains(std::move(*val), std::to_string(i)); } EXPECT_THROW(co_await gen.next(), TApplicationException); }); } TEST_F(OmniClientTest, StreamSumAndNumsTest) { NumsRequest request; 
request.f() = 2; request.t() = 4; testSendStream( "TestService", "sumAndNums", request, [this](OmniClientResponseWithHeaders&& resp) -> folly::coro::Task { testContains( std::move(*resp.buf.value()), 9); auto gen = std::move(*resp.stream).toAsyncGenerator(); for (int i = 2; i <= 4; ++i) { auto val = co_await gen.next(); EXPECT_TRUE(val); testContains(std::move(*val), std::to_string(i)); } EXPECT_FALSE(co_await gen.next()); }); } TEST_F(OmniClientTest, StreamSumAndNumsExceptionTest) { NumsRequest request; request.f() = 4; request.t() = 2; testSendStream( "TestService", "sumAndNums", request, [this](OmniClientResponseWithHeaders&& resp) -> folly::coro::Task { testContains( std::move(*resp.buf.value()), std::string{"my_magic_arithmetic_exception"}); auto gen = std::move(*resp.stream).toAsyncGenerator(); EXPECT_FALSE(co_await gen.next()); }); }