/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace apache { namespace thrift { namespace rocket { namespace test { namespace { constexpr int32_t kClientVersion = 7; constexpr int32_t kServerVersion = 10; std::pair, std::unique_ptr> makeTestResponse( std::unique_ptr requestMetadata, std::unique_ptr requestData) { std::pair, std::unique_ptr> response; folly::StringPiece data(requestData->coalesce()); constexpr folly::StringPiece kMetadataEchoPrefix{"metadata_echo:"}; constexpr folly::StringPiece kDataEchoPrefix{"data_echo:"}; if (data.removePrefix("sleep_ms:")) { // Sleep, then echo back request. std::chrono::milliseconds sleepFor(folly::to(data)); std::this_thread::sleep_for(sleepFor); // sleep override } else if (data.removePrefix("error:")) { // Reply with a specific kind of error. } else if (data.startsWith(kMetadataEchoPrefix)) { // Reply with echoed metadata in the response payload. auto responseMetadata = requestData->clone(); responseMetadata->trimStart(kMetadataEchoPrefix.size()); response = std::make_pair(std::move(responseMetadata), std::move(requestData)); } else if (data.startsWith(kDataEchoPrefix)) { // Reply with echoed data in the response payload. auto responseData = requestData->clone(); responseData->trimStart(kDataEchoPrefix.size()); response = std::make_pair(std::move(requestMetadata), std::move(responseData)); } // If response payload is not set at this point, simply echo back what client // sent. if (!response.first && !response.second) { response = std::make_pair(std::move(requestMetadata), std::move(requestData)); } return response; } } // namespace rocket::SetupFrame RocketTestClient::makeTestSetupFrame( MetadataOpaqueMap md) { RequestSetupMetadata meta; meta.opaque_ref() = {}; *meta.opaque_ref() = std::move(md); meta.maxVersion_ref() = kClientVersion; CompactProtocolWriter compactProtocolWriter; folly::IOBufQueue paramQueue; compactProtocolWriter.setOutput(¶mQueue); meta.write(&compactProtocolWriter); // Serialize RocketClient's major/minor version (which is separate from the // rsocket protocol major/minor version) into setup metadata. auto buf = folly::IOBuf::createCombined( sizeof(int32_t) + meta.serializedSize(&compactProtocolWriter)); folly::IOBufQueue queue; queue.append(std::move(buf)); folly::io::QueueAppender appender(&queue, /* do not grow */ 0); // Serialize RocketClient's major/minor version (which is separate from the // rsocket protocol major/minor version) into setup metadata. appender.writeBE(0); // Thrift RocketClient major version appender.writeBE(1); // Thrift RocketClient minor version // Append serialized setup parameters to setup frame metadata appender.insert(paramQueue.move()); return rocket::SetupFrame( rocket::Payload::makeFromMetadataAndData(queue.move(), {}), false); } RocketTestClient::RocketTestClient(const folly::SocketAddress& serverAddr) : evb_(*evbThread_.getEventBase()), fm_(folly::fibers::getFiberManager(evb_)), serverAddr_(serverAddr) { connect(); } RocketTestClient::~RocketTestClient() { disconnect(); } folly::Try RocketTestClient::sendRequestResponseSync( Payload request, std::chrono::milliseconds timeout, RocketClient::WriteSuccessCallback* writeSuccessCallback) { folly::Try response; folly::fibers::Baton baton; evb_.runInEventBaseThread([&] { fm_.addTaskFinally( [&] { return client_->sendRequestResponseSync( std::move(request), timeout, writeSuccessCallback); }, [&](folly::Try>&& r) { response = collapseTry(std::move(r)); baton.post(); }); }); baton.wait(); return response; } folly::Try RocketTestClient::sendRequestFnfSync(Payload request) { folly::Try response; folly::fibers::Baton baton; evb_.runInEventBaseThread([&] { fm_.addTaskFinally( [&] { return client_->sendRequestFnfSync(std::move(request)); }, [&](folly::Try>&& r) { response = collapseTry(std::move(r)); baton.post(); }); }); baton.wait(); return response; } folly::Try> RocketTestClient::sendRequestStreamSync(Payload request) { constexpr std::chrono::milliseconds kFirstResponseTimeout{500}; constexpr std::chrono::milliseconds kChunkTimeout{500}; class TestStreamCallback final : public ::apache::thrift::detail::ClientStreamBridge:: FirstResponseCallback { public: TestStreamCallback( std::chrono::milliseconds chunkTimeout, folly::Promise> p) : chunkTimeout_(chunkTimeout), p_(std::move(p)) {} // ClientCallback interface void onFirstResponse( FirstResponsePayload&& firstPayload, ::apache::thrift::detail::ClientStreamBridge::ClientPtr clientStreamBridge) override { if (getRange(*firstPayload.payload) == "error:application") { p_.setException( folly::make_exception_wrapper( std::move(firstPayload.payload))); } else { p_.setValue(ClientBufferedStream( std::move(clientStreamBridge), [](folly::Try&& v) { if (v.hasValue()) { return folly::Try( Payload::makeFromData(std::move(v->payload))); } else if (v.hasException()) { return folly::Try(std::move(v.exception())); } else { return folly::Try(); } }, {100, 0})); } delete this; } void onFirstResponseError(folly::exception_wrapper ew) override { p_.setException(std::move(ew)); delete this; } private: std::chrono::milliseconds chunkTimeout_; folly::Promise> p_; }; folly::Promise> p; auto sf = p.getSemiFuture(); auto clientCallback = new TestStreamCallback(kChunkTimeout, std::move(p)); evb_.runInEventBaseThread([&] { fm_.addTask([&] { client_->sendRequestStream( std::move(request), kFirstResponseTimeout, kChunkTimeout, 0, ::apache::thrift::detail::ClientStreamBridge::create(clientCallback)); }); }); return folly::makeTryWith([&] { return std::move(sf).via(&folly::InlineExecutor::instance()).get(); }); } void RocketTestClient::sendRequestSink( SinkClientCallback* callback, Payload request) { evb_.runInEventBaseThread( [this, request = std::move(request), callback]() mutable { fm_.addTask([this, request = std::move(request), callback]() mutable { constexpr std::chrono::milliseconds kFirstResponseTimeout{500}; client_->sendRequestSink( std::move(request), kFirstResponseTimeout, callback); }); }); } void RocketTestClient::reconnect() { disconnect(); connect(); } void RocketTestClient::connect() { evb_.runInEventBaseThreadAndWait([this] { folly::AsyncSocket::UniquePtr socket( new folly::AsyncSocket(&evb_, serverAddr_)); client_ = RocketClient::create( evb_, std::move(socket), std::make_unique(makeTestSetupFrame())); }); } void RocketTestClient::disconnect() { evb_.runInEventBaseThread([client = std::move(client_)] {}); } void RocketTestClient::verifyVersion() { if (client_ && client_->getServerVersion() != -1) { EXPECT_EQ( std::min(kClientVersion, kServerVersion), client_->getServerVersion()); } } namespace { class RocketTestServerAcceptor final : public wangle::Acceptor { public: RocketTestServerAcceptor( folly::Function()> frameHandlerFactory, std::promise shutdownPromise) : Acceptor(wangle::ServerSocketConfig{}), frameHandlerFactory_(std::move(frameHandlerFactory)), shutdownPromise_(std::move(shutdownPromise)) {} ~RocketTestServerAcceptor() override { EXPECT_EQ(0, connections_); } void onNewConnection( folly::AsyncTransport::UniquePtr socket, const folly::SocketAddress*, const std::string&, wangle::SecureTransportType, const wangle::TransportInfo&) override { auto* connection = new RocketServerConnection( std::move(socket), frameHandlerFactory_(), memoryTracker_, // (ingress) memoryTracker_ // (egress) ); getConnectionManager()->addConnection(connection); } void onConnectionsDrained() override { shutdownPromise_.set_value(); } void onConnectionAdded(const wangle::ManagedConnection*) override { ++connections_; } void onConnectionRemoved(const wangle::ManagedConnection* conn) override { if (expectedRemainingStreams_ != folly::none) { if (auto rconn = dynamic_cast(conn)) { EXPECT_EQ(expectedRemainingStreams_, rconn->getNumStreams()); } } --connections_; } void setExpectedRemainingStreams(size_t size) { expectedRemainingStreams_ = size; } private: folly::Function()> frameHandlerFactory_; std::promise shutdownPromise_; size_t connections_{0}; folly::Optional expectedRemainingStreams_ = folly::none; MemoryTracker memoryTracker_; }; } // namespace class RocketTestServer::RocketTestServerHandler : public RocketServerHandler { public: explicit RocketTestServerHandler( folly::EventBase& ioEvb, const MetadataOpaqueMap& expectedSetupMetadata) : ioEvb_(ioEvb), expectedSetupMetadata_(expectedSetupMetadata) {} void handleSetupFrame( SetupFrame&& frame, RocketServerConnection& connection) final { folly::io::Cursor cursor(frame.payload().buffer()); // Validate Rocket protocol key uint32_t protocolKey; const bool success = cursor.tryReadBE(protocolKey); EXPECT_TRUE(success); EXPECT_TRUE( protocolKey == 1 || protocolKey == RpcMetadata_constants::kRocketProtocolKey() || frame.rocketMimeTypes()); if (protocolKey != 1 && protocolKey != RpcMetadata_constants::kRocketProtocolKey()) { cursor.retreat(4); } // Validate RequestSetupMetadata CompactProtocolReader reader; reader.setInput(cursor); RequestSetupMetadata meta; meta.read(&reader); EXPECT_EQ(reader.getCursorPosition(), frame.payload().metadataSize()); EXPECT_EQ(expectedSetupMetadata_, meta.opaque_ref().value_or({})); version_ = std::min(kServerVersion, meta.maxVersion_ref().value_or(0)); ServerPushMetadata serverMeta; serverMeta.set_setupResponse(); serverMeta.setupResponse_ref()->version_ref() = version_; CompactProtocolWriter compactProtocolWriter; folly::IOBufQueue queue; compactProtocolWriter.setOutput(&queue); serverMeta.write(&compactProtocolWriter); connection.sendMetadataPush(std::move(queue).move()); } void handleRequestResponseFrame( RequestResponseFrame&& frame, RocketServerFrameContext&& context) final { auto dam = splitMetadataAndData(frame.payload()); auto payload = std::move(frame.payload()); auto dataPiece = getRange(*dam.second); if (dataPiece.removePrefix("error:application")) { return context.sendError( RocketException( ErrorCode::APPLICATION_ERROR, "Application error occurred"), nullptr); } auto response = makeTestResponse(std::move(dam.first), std::move(dam.second)); auto responsePayload = Payload::makeFromMetadataAndData( std::move(response.first), std::move(response.second)); return context.sendPayload( std::move(responsePayload), Flags().next(true).complete(true), nullptr); } void handleRequestFnfFrame( RequestFnfFrame&&, RocketServerFrameContext&&) final {} void handleRequestStreamFrame( RequestStreamFrame&& frame, RocketServerFrameContext&&, RocketStreamClientCallback* clientCallback) final { class TestRocketStreamServerCallback final : public StreamServerCallback { public: TestRocketStreamServerCallback( StreamClientCallback* clientCallback, size_t n, size_t nEchoHeaders) : clientCallback_(clientCallback), n_(n), nEchoHeaders_(nEchoHeaders) {} bool onStreamRequestN(uint64_t tokens) override { while (tokens-- && i_++ < n_) { auto alive = clientCallback_->onStreamNext(StreamPayload{ folly::IOBuf::copyBuffer(folly::to(i_)), {}}); DCHECK(alive); } if (i_ == n_ && iEchoHeaders_ == nEchoHeaders_) { clientCallback_->onStreamComplete(); delete this; return false; } return true; } void onStreamCancel() override { delete this; } bool onSinkHeaders(HeadersPayload&& payload) override { auto metadata_ref = payload.payload.otherMetadata_ref(); EXPECT_TRUE(metadata_ref); if (metadata_ref) { EXPECT_EQ( folly::to(++iEchoHeaders_), (*metadata_ref)["expected_header"]); } auto alive = clientCallback_->onStreamHeaders(std::move(payload)); DCHECK(alive); if (i_ == n_ && iEchoHeaders_ == nEchoHeaders_) { clientCallback_->onStreamComplete(); delete this; return false; } return true; } void resetClientCallback(StreamClientCallback& clientCallback) override { clientCallback_ = &clientCallback; } private: StreamClientCallback* clientCallback_; size_t i_{0}; size_t iEchoHeaders_{0}; const size_t n_; const size_t nEchoHeaders_; }; std::unique_ptr buffer = std::move(frame.payload()).data()->clone(); folly::StringPiece data(buffer->coalesce()); if (data.removePrefix("error:application")) { clientCallback->onFirstResponseError( folly::make_exception_wrapper< thrift::detail::EncodedFirstResponseError>(FirstResponsePayload( folly::IOBuf::copyBuffer("error:application"), {}))); return; } const size_t nHeaders = data.removePrefix("generateheaders:") ? folly::to(data) : 0; const size_t nEchoHeaders = data.removePrefix("echoheaders:") ? folly::to(data) : 0; const size_t n = nHeaders || nEchoHeaders ? 0 : (data.removePrefix("generate:") ? folly::to(data) : 500); auto* serverCallback = new TestRocketStreamServerCallback(clientCallback, n, nEchoHeaders); { auto alive = clientCallback->onFirstResponse( FirstResponsePayload{ folly::IOBuf::copyBuffer(folly::to(0)), {}}, nullptr /* evb */, serverCallback); DCHECK(alive); } for (size_t i = 1; i <= nHeaders; ++i) { HeadersPayloadContent header; header.otherMetadata_ref() = { {"expected_header", folly::to(i)}}; auto alive = clientCallback->onStreamHeaders({std::move(header), {}}); DCHECK(alive); } if (n == 0 && nEchoHeaders == 0) { std::ignore = serverCallback->onStreamRequestN(0); } } void handleRequestChannelFrame( RequestChannelFrame&&, RocketServerFrameContext&&, RocketSinkClientCallback* clientCallback) final { apache::thrift::detail::SinkConsumerImpl impl{ [](folly::coro::AsyncGenerator&&> asyncGen) -> folly::coro::Task> { int current = 0; while (auto item = co_await asyncGen.next()) { auto payload = (*item).value(); auto data = folly::to( folly::StringPiece(payload.payload->coalesce())); EXPECT_EQ(current++, data); } co_return folly::Try(StreamPayload( folly::IOBuf::copyBuffer(folly::to(current)), {})); }, 10, std::chrono::milliseconds::zero(), {}}; auto serverCallback = apache::thrift::detail::ServerSinkBridge::create( std::move(impl), ioEvb_, clientCallback); clientCallback->onFirstResponse( FirstResponsePayload{ folly::IOBuf::copyBuffer(folly::to(0)), {}}, nullptr /* evb */, serverCallback.get()); folly::coro::co_invoke( &apache::thrift::detail::ServerSinkBridge::start, std::move(serverCallback)) .scheduleOn(threadManagerThread_.getEventBase()) .start(); } void connectionClosing() final {} int32_t getVersion() const final { return version_; } private: folly::EventBase& ioEvb_; const MetadataOpaqueMap& expectedSetupMetadata_; folly::ScopedEventBaseThread threadManagerThread_; int32_t version_{0}; }; RocketTestServer::RocketTestServer() : evb_(*ioThread_.getEventBase()), listeningSocket_(new folly::AsyncServerSocket(&evb_)) { std::promise shutdownPromise; shutdownFuture_ = shutdownPromise.get_future(); acceptor_ = std::make_unique( [this] { return std::make_unique( evb_, expectedSetupMetadata_); }, std::move(shutdownPromise)); start(); } RocketTestServer::~RocketTestServer() { stop(); } void RocketTestServer::start() { folly::via(&evb_, [this] { acceptor_->init(listeningSocket_.get(), &evb_); listeningSocket_->bind(0 /* bind to any port */); listeningSocket_->listen(128 /* tcpBacklog */); listeningSocket_->startAccepting(); }).wait(); } void RocketTestServer::stop() { // Ensure socket and acceptor are destroyed in EventBase thread folly::via(&evb_, [listeningSocket = std::move(listeningSocket_)] {}); // Wait for server to drain connections as gracefully as possible. shutdownFuture_.wait(); folly::via(&evb_, [acceptor = std::move(acceptor_)] {}); } uint16_t RocketTestServer::getListeningPort() const { return listeningSocket_->getAddress().getPort(); } wangle::ConnectionManager* RocketTestServer::getConnectionManager() const { return acceptor_->getConnectionManager(); } void RocketTestServer::setExpectedRemainingStreams(size_t n) { if (auto acceptor = dynamic_cast(acceptor_.get())) { acceptor->setExpectedRemainingStreams(n); } } void RocketTestServer::setExpectedSetupMetadata( MetadataOpaqueMap md) { expectedSetupMetadata_ = std::move(md); } } // namespace test } // namespace rocket } // namespace thrift } // namespace apache