/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include #include #include #include #include #include #include #include #include using namespace testing; using namespace apache::thrift; using namespace apache::thrift::test; using folly::AsyncSocket; class TestServiceServerMock : public apache::thrift::ServiceHandler { public: MOCK_METHOD(int32_t, echoInt, (int32_t), (override)); MOCK_METHOD( folly::SemiFuture>, semifuture_echoRequest, (std::unique_ptr), (override)); MOCK_METHOD( folly::SemiFuture>, semifuture_echoIOBufAsByteStream, (std::unique_ptr, int32_t), (override)); MOCK_METHOD( folly::SemiFuture, semifuture_noResponse, (int64_t), (override)); }; struct CalculatorHandler : apache::thrift::ServiceHandler { struct AdditionHandler : apache::thrift::ServiceHandler::AdditionIf { int acc_{0}; Point pacc_; #if FOLLY_HAS_COROUTINES folly::coro::Task co_accumulatePrimitive(int32_t a) override { acc_ += a; co_return; } folly::coro::Task co_noop() override { co_return; } folly::coro::Task co_accumulatePoint( std::unique_ptr<::apache::thrift::test::Point> a) override { *pacc_.x_ref() += *a->x_ref(); *pacc_.y_ref() += *a->y_ref(); co_return; } folly::coro::Task co_getPrimitive() override { co_return acc_; } folly::coro::Task> co_getPoint() override { co_return folly::copy_to_unique_ptr(pacc_); } #endif }; std::unique_ptr createAddition() override { return std::make_unique(); } folly::SemiFuture semifuture_addPrimitive( int32_t a, int32_t b) override { return a + b; } }; class GuardedRequestChannelTest : public Test { private: RequestChannel::Ptr createGuardedRequestChannel() { auto pooledChannel = PooledRequestChannel::newChannel( evbThread->getEventBase(), evbThread, [&](folly::EventBase& evb) { auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb, serverAddress)); auto rocketChannel = RocketClientChannel::newChannel(std::move(socket)); return rocketChannel; }); return GuardedRequestChannel::newChannel( std::move(pooledChannel)); } protected: std::shared_ptr handler{ std::make_shared()}; std::unique_ptr runner{ std::make_unique(handler)}; folly::SocketAddress serverAddress{runner->getAddress()}; std::shared_ptr evbThread{ std::make_shared()}; apache::thrift::Client testClient{createGuardedRequestChannel()}; }; TEST_F(GuardedRequestChannelTest, normalSingleRequestSuccess) { EXPECT_CALL(*handler, echoInt(_)).WillOnce(Return(1)).WillOnce(Return(2)); EXPECT_EQ(testClient.sync_echoInt(1), 1); EXPECT_EQ(testClient.sync_echoInt(2), 2); } TEST_F(GuardedRequestChannelTest, sendRequestNoResponse) { EXPECT_NO_THROW(testClient.semifuture_noResponse(1)); EXPECT_NO_THROW(testClient.semifuture_noResponse(100)); } TEST_F(GuardedRequestChannelTest, normalStreamResponseAndComplete) { EXPECT_CALL(*handler, semifuture_echoIOBufAsByteStream(_, _)) .WillOnce(Invoke([&](std::unique_ptr buf, int32_t delayMs) { auto [stream, publisher] = ServerStream::createPublisher(); folly::io::Cursor cursor(buf.get()); int8_t byte; while (cursor.tryRead(byte)) { publisher.next(byte); } std::move(publisher).complete(); return folly::makeSemiFuture(std::move(stream)) .delayed(std::chrono::milliseconds(delayMs)); })); auto payloadLength = 25; auto iobuf = folly::IOBuf::copyBuffer(std::string(payloadLength, 'x')); auto stream = testClient.sync_echoIOBufAsByteStream(*iobuf, 5); auto returnPayload = 0; std::move(stream).subscribeInline([&](auto&& val) { if (val.hasValue()) { returnPayload++; EXPECT_EQ(*val, 'x'); } }); EXPECT_EQ(returnPayload, payloadLength); } TEST_F(GuardedRequestChannelTest, streamErrorFromServer) { EXPECT_CALL(*handler, semifuture_echoIOBufAsByteStream(_, _)) .WillOnce( Invoke([&](std::unique_ptr /*buf*/, int32_t delayMs) { auto [stream, publisher] = ServerStream::createPublisher(); auto ew = folly::exception_wrapper{ std::runtime_error("end stream immediately")}; std::move(publisher).complete(ew); return folly::makeSemiFuture(std::move(stream)) .delayed(std::chrono::milliseconds(delayMs)); })); auto iobuf = folly::IOBuf::copyBuffer(std::string(1, 'x')); auto stream = testClient.sync_echoIOBufAsByteStream(*iobuf, 0); std::move(stream).subscribeInline([&](auto&& val) { if (val.hasValue()) { FAIL() << "No real payload should be received"; } EXPECT_TRUE(val.hasException()); auto exception = val.tryGetExceptionObject(); EXPECT_THAT(exception->what(), HasSubstr("end stream immediately")); EXPECT_THAT(exception->what(), HasSubstr("std::runtime_error")); }); } TEST_F(GuardedRequestChannelTest, createInteractionTest) { ScopedServerInterfaceThread runner{std::make_shared()}; apache::thrift::Client client( GuardedRequestChannel::newChannel( PooledRequestChannel::newChannel([&](folly::EventBase& evb) { return RocketClientChannel::newChannel( folly::AsyncSocket::UniquePtr( new folly::AsyncSocket(&evb, runner.getAddress()))); }))); auto adder = client.createAddition(); } TEST_F(GuardedRequestChannelTest, basicRequestResponseMethodCallInteraction) { ScopedServerInterfaceThread runner{std::make_shared()}; apache::thrift::Client client( GuardedRequestChannel::newChannel( PooledRequestChannel::newChannel([&](folly::EventBase& evb) { return RocketClientChannel::newChannel( folly::AsyncSocket::UniquePtr( new folly::AsyncSocket(&evb, runner.getAddress()))); }))); auto adder = client.createAddition(); #if FOLLY_HAS_COROUTINES folly::coro::blockingWait([&]() -> folly::coro::Task { co_await adder.co_accumulatePrimitive(1); co_await adder.semifuture_accumulatePrimitive(2); auto acc = co_await adder.co_getPrimitive(); EXPECT_EQ(acc, 3); auto sum = co_await client.co_addPrimitive(20, 22); EXPECT_EQ(sum, 42); Point p; p.x_ref() = 1; co_await adder.co_accumulatePoint(p); p.y_ref() = 2; co_await adder.co_accumulatePoint(p); auto pacc = co_await adder.co_getPoint(); EXPECT_EQ(*pacc.x_ref(), 2); EXPECT_EQ(*pacc.y_ref(), 2); }()); #endif } TEST_F(GuardedRequestChannelTest, interactionConstructorError) { struct BrokenCalculatorHandler : CalculatorHandler { std::unique_ptr createAddition() override { throw std::runtime_error("Plus key is broken"); } }; ScopedServerInterfaceThread runner{ std::make_shared()}; apache::thrift::Client client( GuardedRequestChannel::newChannel( PooledRequestChannel::newChannel([&](folly::EventBase& evb) { return RocketClientChannel::newChannel( folly::AsyncSocket::UniquePtr( new folly::AsyncSocket(&evb, runner.getAddress()))); }))); const char* kExpectedErr = "apache::thrift::TApplicationException:" " Interaction constructor failed with std::runtime_error: Plus key is broken"; auto adder = client.createAddition(); #if FOLLY_HAS_COROUTINES folly::coro::blockingWait([&]() -> folly::coro::Task { auto t = co_await folly::coro::co_awaitTry(adder.co_accumulatePrimitive(1)); EXPECT_STREQ(t.exception().what().c_str(), kExpectedErr); auto t2 = co_await folly::coro::co_awaitTry(adder.co_getPrimitive()); EXPECT_STREQ(t.exception().what().c_str(), kExpectedErr); auto sum = co_await client.co_addPrimitive(20, 22); EXPECT_EQ(sum, 42); }()); #endif } TEST_F(GuardedRequestChannelTest, interactionMethodException) { struct ExceptionCalculatorHandler : apache::thrift::ServiceHandler { struct AdditionHandler : apache::thrift::ServiceHandler::AdditionIf { int acc_{0}; #if FOLLY_HAS_COROUTINES folly::coro::Task co_accumulatePrimitive(int32_t a) override { acc_ += a; co_yield folly::coro::co_error( std::runtime_error("Not Implemented Yet")); } #endif }; std::unique_ptr createAddition() override { return std::make_unique(); } }; ScopedServerInterfaceThread runner{ std::make_shared()}; apache::thrift::Client client( GuardedRequestChannel::newChannel( PooledRequestChannel::newChannel([&](folly::EventBase& evb) { return RocketClientChannel::newChannel( folly::AsyncSocket::UniquePtr( new folly::AsyncSocket(&evb, runner.getAddress()))); }))); const char* kExpectedErr = "apache::thrift::TApplicationException: std::runtime_error: Not Implemented Yet"; auto adder = client.createAddition(); #if FOLLY_HAS_COROUTINES folly::coro::blockingWait([&]() -> folly::coro::Task { auto t = co_await folly::coro::co_awaitTry(adder.co_accumulatePrimitive(1)); EXPECT_STREQ(t.exception().what().c_str(), kExpectedErr); }()); #endif }