/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include using namespace testing; using namespace apache::thrift; using namespace apache::thrift::test; using apache::thrift::transport::TTransportException; using folly::AsyncSocket; class TestServiceServerMock : public apache::thrift::ServiceHandler { public: MOCK_METHOD(int32_t, echoInt, (int32_t), (override)); MOCK_METHOD(void, noResponse, (int64_t), (override)); MOCK_METHOD( apache::thrift::ServerStream, range, (int32_t, int32_t), (override)); #if FOLLY_HAS_COROUTINES folly::coro::Task> co_sumSink() override { SinkConsumer sink; sink.consumer = [](folly::coro::AsyncGenerator gen) -> folly::coro::Task { int32_t res = 0; while (auto val = co_await gen.next()) { res += *val; } co_return res; }; sink.bufferSize = 10; co_return sink; } #endif }; class ReconnectingRequestChannelTest : public Test { public: folly::EventBase* eb{folly::EventBaseManager::get()->getEventBase()}; folly::ScopedBoundPort bound; std::shared_ptr handler{ std::make_shared()}; std::unique_ptr runner{ std::make_unique(handler)}; folly::SocketAddress up_addr{runner->getAddress()}; folly::SocketAddress dn_addr{bound.getAddress()}; uint32_t connection_count_ = 0; void runReconnect(TestServiceAsyncClient& client, bool testStreaming); #if FOLLY_HAS_COROUTINES folly::coro::Task> co_getClient() { auto executor = co_await folly::coro::co_current_executor; auto channel = PooledRequestChannel::newChannel( executor, ioThread_, [this](folly::EventBase& evb) { return ReconnectingRequestChannel::newChannel( evb, [this](folly::EventBase& evb) { auto socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb, up_addr)); return RocketClientChannel::newChannel(std::move(socket)); }); }); co_return std::make_shared(std::move(channel)); } std::shared_ptr ioThread_{ std::make_shared()}; #endif }; TEST_F(ReconnectingRequestChannelTest, ReconnectHeader) { auto channel = ReconnectingRequestChannel::newChannel( *eb, [this](folly::EventBase& eb) mutable { connection_count_++; return HeaderClientChannel::newChannel( AsyncSocket::newSocket(&eb, up_addr)); }); TestServiceAsyncClient client(std::move(channel)); runReconnect(client, false); } TEST_F(ReconnectingRequestChannelTest, ReconnectRocket) { auto channel = ReconnectingRequestChannel::newChannel( *eb, [this](folly::EventBase& eb) mutable { connection_count_++; return RocketClientChannel::newChannel(folly::AsyncSocket::UniquePtr( new folly::AsyncSocket(&eb, up_addr))); }); TestServiceAsyncClient client(std::move(channel)); runReconnect(client, true); } void ReconnectingRequestChannelTest::runReconnect( TestServiceAsyncClient& client, bool testStreaming) { EXPECT_CALL(*handler, echoInt(_)) .WillOnce(Return(1)) .WillOnce(Return(3)) .WillOnce(Return(4)); EXPECT_EQ(client.sync_echoInt(1), 1); EXPECT_EQ(connection_count_, 1); EXPECT_CALL(*handler, noResponse(_)); client.sync_noResponse(0); EXPECT_EQ(connection_count_, 1); auto checkStream = [](auto&& stream, int from, int to) { std::move(stream).subscribeInline([idx = from, to](auto nextTry) mutable { DCHECK(!nextTry.hasException()); if (!nextTry.hasValue()) { EXPECT_EQ(to, idx); } else { EXPECT_EQ(idx++, *nextTry); } }); }; if (testStreaming) { EXPECT_CALL(*handler, range(_, _)) .WillRepeatedly(Invoke( [](int32_t from, int32_t to) -> apache::thrift::ServerStream { auto [serverStream, publisher] = apache::thrift::ServerStream::createPublisher(); for (auto idx = from; idx < to; ++idx) { publisher.next(idx); } std::move(publisher).complete(); return std::move(serverStream); })); checkStream(client.sync_range(0, 1), 0, 1); EXPECT_EQ(connection_count_, 1); } // bounce the server runner = std::make_unique(handler); up_addr = runner->getAddress(); EXPECT_THROW(client.sync_echoInt(2), TTransportException); EXPECT_EQ(client.sync_echoInt(3), 3); EXPECT_EQ(connection_count_, 2); EXPECT_EQ(client.sync_echoInt(4), 4); EXPECT_EQ(connection_count_, 2); if (testStreaming) { checkStream(client.sync_range(0, 2), 0, 2); EXPECT_EQ(connection_count_, 2); checkStream(client.sync_range(4, 42), 4, 42); EXPECT_EQ(connection_count_, 2); } } #if FOLLY_HAS_COROUTINES TEST_F(ReconnectingRequestChannelTest, sinkReconnect) { folly::coro::blockingWait([&]() -> folly::coro::Task { auto client = co_await co_getClient(); auto consumer = co_await client->co_sumSink(); auto res = co_await consumer.sink([]() -> folly::coro::AsyncGenerator { for (int32_t i = 1; i <= 5; ++i) { co_yield std::move(i); } }()); EXPECT_EQ(res, 15); // bounce runner = std::make_unique(handler); up_addr = runner->getAddress(); // no exception here - the underlying impl is marked !good so the next // request ends up triggering the reconnect consumer = co_await client->co_sumSink(); res = co_await consumer.sink([]() -> folly::coro::AsyncGenerator { for (int32_t i = 1; i <= 5; ++i) { co_yield std::move(i); } }()); EXPECT_EQ(res, 15); }()); } #endif