/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace apache { namespace thrift { template class AsyncTestSetup : public TestSetup { protected: void SetUp() override { handler_ = std::make_shared(); setNumIOThreads(numIOThreads_); setNumWorkerThreads(numWorkerThreads_); setQueueTimeout(std::chrono::milliseconds(0)); setIdleTimeout(std::chrono::milliseconds(0)); setTaskExpireTime(std::chrono::milliseconds(0)); setStreamExpireTime(std::chrono::milliseconds(0)); server_ = createServer( std::make_shared>(handler_), serverPort_); } void TearDown() override { if (server_) { server_->cleanUp(); server_.reset(); handler_.reset(); } } template void connectToServer( folly::Function(Client&)> callMe) { folly::coro::blockingWait([this, &callMe]() -> folly::coro::Task { CHECK_GT(serverPort_, 0) << "Check if the server has started already"; folly::Executor* executor = co_await folly::coro::co_current_executor; auto channel = PooledRequestChannel::newChannel( executor, ioThread_, [&](folly::EventBase& evb) { auto channel = apache::thrift::RocketClientChannel::newChannel( folly::AsyncSocket::UniquePtr( new SocketT(&evb, "::1", serverPort_))); channel->setTimeout(500 /* ms */); return channel; }); Client client(std::move(channel)); co_await callMe(client); }()); } protected: int numIOThreads_{1}; int numWorkerThreads_{1}; uint16_t serverPort_{0}; std::shared_ptr ioThread_{ std::make_shared()}; std::unique_ptr server_; std::shared_ptr handler_; }; class DuplicateWriteSocket : public folly::AsyncSocket { public: using folly::AsyncSocket::AsyncSocket; void writeChain( WriteCallback* callback, std::unique_ptr&& buf, folly::WriteFlags flags = folly::WriteFlags::NONE) override { // first request sends setup frame, don't duplicate this payload if (firstWrite_) { firstWrite_ = false; } else { buf->appendChain(buf->clone()); } folly::AsyncSocket::writeChain(callback, std::move(buf), flags); } private: bool firstWrite_{true}; }; } // namespace thrift } // namespace apache