/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Client-side request channel abstraction: the RequestChannel interface, a
// baton-based completion callback (ClientBatonCallback), and the helpers that
// serialize (preprocessSendT) and dispatch (clientSendT /
// RequestChannel::sendRequestAsync) a request.
//
// NOTE(review): in this copy every `<...>` span appears to have been stripped:
// the #include lines below have no header names, `template` declarations have
// no parameter lists, and uses such as std::shared_ptr, std::unique_ptr,
// std::forward, folly::FunctionRef and getChannelSendFunc have lost their
// template arguments. It cannot compile as-is; restore those spans from the
// upstream header before building. Comments below that depend on the missing
// arguments are marked as assumptions.

#ifndef THRIFT_ASYNC_REQUESTCHANNEL_H_
#define THRIFT_ASYNC_REQUESTCHANNEL_H_ 1

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

namespace folly {
class IOBuf;
}

namespace apache {
namespace thrift {

class StreamClientCallback;
class SinkClientCallback;
class RequestChannel;

namespace detail {

// Maps a request kind to the callback pointer type that the matching
// RequestChannel send virtual accepts. The primary template defines no `Ptr`;
// only the four specializations below do. The template parameter and
// specialization arguments are stripped in this copy. They are presumably
// RpcKind values in the order response / no-response / stream / sink, which
// is the order getChannelSendFunc uses below; confirm upstream.
template
struct RequestClientCallbackType {};
template <>
struct RequestClientCallbackType {
  using Ptr = RequestClientCallback::Ptr;
};
template <>
struct RequestClientCallbackType {
  using Ptr = RequestClientCallback::Ptr;
};
template <>
struct RequestClientCallbackType {
  using Ptr = StreamClientCallback*;
};
template <>
struct RequestClientCallbackType {
  using Ptr = SinkClientCallback*;
};

// Type transform used by ChannelSendFunc. The primary template leaves T
// unchanged, and the partial specialization yields `const T&`. Its pattern is
// stripped in this copy and is presumably `T&`. Assuming that, `type&&` is
// `RpcOptions&&` for an rvalue RpcOptions and `const RpcOptions&` for an
// lvalue one. That matches the two overload sets of RequestChannel's send
// virtuals.
template
struct const_if_lvalue_ref {
  using type = T;
};
template
struct const_if_lvalue_ref {
  using type = const T&;
};

// Pointer-to-member type for one of RequestChannel's send virtuals:
// (options, metadata, serialized request, header, callback).
template
using ChannelSendFunc = void (RequestChannel::*)(
    typename const_if_lvalue_ref::type&&,
    MethodMetadata&&,
    SerializedRequest&&,
    std::shared_ptr,
    typename RequestClientCallbackType::Ptr);

} // namespace detail

/**
 * RequestChannel defines an asynchronous API for request-based I/O.
 *
 * Subclasses must implement the pure virtuals (setCloseCallback,
 * getEventBase, getProtocolId) and may override the send* virtuals.
 * The destructor is protected, so instances cannot be `delete`d directly.
 * Lifetime presumably goes through the virtual folly::DelayedDestruction
 * base; `Ptr`'s deleter argument is stripped in this copy.
 */
class RequestChannel : virtual public folly::DelayedDestruction {
 protected:
  ~RequestChannel() override {}

 public:
  /**
   * Sends a request through the send* virtual selected by
   * detail::getChannelSendFunc. If getEventBase() is null, or the caller is
   * already on that EventBase's thread, the send happens inline. Otherwise
   * all arguments are moved into a task posted with runInEventBaseThread.
   *
   * ReplyCallback will be invoked when the reply to this request is
   * received. RequestChannel is responsible for associating requests with
   * responses, and invoking the correct ReplyCallback when a response
   * message is received.
   *
   * cb must not be null.
   */
  template
  void sendRequestAsync(
      RpcOptions&&,
      MethodMetadata&&,
      SerializedRequest&&,
      std::shared_ptr&&,
      typename apache::thrift::detail::RequestClientCallbackType::Ptr);

  /**
   * ReplyCallback will be invoked when the reply to this request is
   * received. RequestChannel is responsible for associating requests with
   * responses, and invoking the correct ReplyCallback when a response
   * message is received.
   *
   * cb must not be null.
   */
  virtual void sendRequestResponse(
      const RpcOptions&,
      MethodMetadata&&,
      SerializedRequest&&,
      std::shared_ptr,
      RequestClientCallback::Ptr);

  /* Similar to sendRequestResponse, although no reply is expected
   * (replyReceived will never be called).
   *
   * Null RequestCallback is allowed for oneway requests.
   */
  virtual void sendRequestNoResponse(
      const RpcOptions&,
      MethodMetadata&&,
      SerializedRequest&&,
      std::shared_ptr,
      RequestClientCallback::Ptr);

  /**
   * ReplyCallback will be invoked when the reply to this request is
   * received. RequestChannel is responsible for associating requests with
   * responses, and invoking the correct ReplyCallback when a response
   * message is received. A response to this request may contain a stream.
   *
   * cb must not be null.
   */
  virtual void sendRequestStream(
      const RpcOptions& rpcOptions,
      MethodMetadata&&,
      SerializedRequest&&,
      std::shared_ptr header,
      StreamClientCallback* clientCallback);

  // Sink counterpart of sendRequestStream(): same leading arguments, but the
  // callback is a SinkClientCallback.
  virtual void sendRequestSink(
      const RpcOptions& rpcOptions,
      MethodMetadata&&,
      SerializedRequest&&,
      std::shared_ptr header,
      SinkClientCallback* clientCallback);

  // Some channels can make use of rvalue RpcOptions as an optimization.
  // sendRequestAsync chooses between these overloads and the const-ref ones
  // above through detail::getChannelSendFunc.
  virtual void sendRequestResponse(
      RpcOptions&&,
      MethodMetadata&&,
      SerializedRequest&&,
      std::shared_ptr,
      RequestClientCallback::Ptr);
  virtual void sendRequestNoResponse(
      RpcOptions&&,
      MethodMetadata&&,
      SerializedRequest&&,
      std::shared_ptr,
      RequestClientCallback::Ptr);
  virtual void sendRequestStream(
      RpcOptions&&,
      MethodMetadata&&,
      SerializedRequest&&,
      std::shared_ptr,
      StreamClientCallback*);
  virtual void sendRequestSink(
      RpcOptions&&,
      MethodMetadata&&,
      SerializedRequest&&,
      std::shared_ptr,
      SinkClientCallback*);

  virtual void setCloseCallback(CloseCallback*) = 0;

  // EventBase used to dispatch sends. sendRequestAsync tolerates a null
  // result by sending inline on the calling thread.
  virtual folly::EventBase* getEventBase() const = 0;

  virtual uint16_t getProtocolId() = 0;

  virtual void terminateInteraction(InteractionId id);

  // registers a new interaction with the channel
  // returns id of created interaction (always nonzero)
  virtual InteractionId createInteraction(ManagedStringView&& name);

  // registers an interaction with a nested channel
  // only some channels can be nested; the rest call terminate here
  virtual InteractionId registerInteraction(
      ManagedStringView&& name, int64_t id);

  // Owning pointer type for channels (template arguments stripped in this
  // copy).
  using Ptr = std::unique_ptr;

  // Checksum sampling rate that clientSendT passes to preprocessSendT.
  // 0 disables sampling; N > 0 adds a CRC32C to roughly one in N requests.
  // Presumably this returns checksumSamplingRate_, which defaults to 0.
  uint64_t getChecksumSamplingRate() const;

 protected:
  // Helpers for subclasses to create and release InteractionId values. They
  // are not defined in this header.
  static InteractionId createInteractionId(int64_t id);
  static void releaseInteractionId(InteractionId&& id);

  // Presumably stores the value later returned by getChecksumSamplingRate().
  void setChecksumSamplingRate(uint64_t samplingRate);

 private:
  uint64_t checksumSamplingRate_{0};
};

// RequestClientCallback that records the outcome in a caller-owned
// ClientReceiveState and posts a folly::fibers::Baton. The caller can then
// wait for completion with waitUntilDone, or wait on the baton returned by
// co_waitUntilDone. The template parameters are stripped in this copy. The
// body uses `oneWay` (skip storing the reply) and `sync` (returned by
// isSync()).
template
class ClientBatonCallback : public RequestClientCallback {
 public:
  // `rs` must be non-null and must remain valid until completion: the error
  // path always writes through it, and so does the success path unless
  // oneWay. getExecutor() returns `executor` unchanged; it defaults to null.
  explicit ClientBatonCallback(
      ClientReceiveState* rs, folly::Executor* executor = {})
      : rs_(rs), executor_(executor) {}

  // Calls sendF(), then waits until onResponse/onResponseError posts the
  // baton. One case applies when evb is non-null and either we are not in
  // evb's running thread or we are not on a fiber. Then sendF() runs in the
  // fiber main context, and evb->drive() is called until the baton is
  // ready. Otherwise the baton is waited on directly.
  template
  void waitUntilDone(folly::EventBase* evb, F&& sendF) {
    if (evb &&
        (!evb->inRunningEventBaseThread() || !folly::fibers::onFiber())) {
      folly::fibers::runInMainContext([&] {
        sendF();
        while (!doneBaton_.ready()) {
          evb->drive();
        }
      });
    } else {
      sendF();
      // Check if it's ready to avoid unnecessarily preempting a fiber.
      if (!doneBaton_.ready()) {
        doneBaton_.wait();
      }
    }
  }

  // Same wait logic, with nothing to send (no-op sendF).
  void waitUntilDone(folly::EventBase* evb) {
    waitUntilDone(evb, [] {});
  }

  // This approach avoids an inner coroutine frame: the caller waits on the
  // returned baton directly.
  folly::fibers::Baton& co_waitUntilDone() {
    return doneBaton_;
  }

  // For oneWay instantiations the reply is not stored; only completion is
  // signaled.
  void onResponse(ClientReceiveState&& rs) noexcept override {
    if (!oneWay) {
      assert(rs.hasResponseBuffer());
      *rs_ = std::move(rs);
    }
    doneBaton_.post();
  }

  // Errors are stored into *rs_ for every instantiation, oneWay or not.
  void onResponseError(folly::exception_wrapper ex) noexcept override {
    *rs_ = ClientReceiveState(std::move(ex), nullptr);
    doneBaton_.post();
  }

  bool isInlineSafe() const override {
    return true;
  }

  bool isSync() const override {
    return sync;
  }

  folly::Executor::KeepAlive<> getExecutor() const override {
    return executor_;
  }

 private:
  ClientReceiveState* rs_;
  folly::fibers::Baton doneBaton_;
  folly::Executor* executor_{};
};

// Aliases over ClientBatonCallback. The template arguments are stripped in
// this copy; presumably they fix `sync` for the synchronous and coroutine
// client paths.
template
using ClientSyncCallback = ClientBatonCallback;
template
using ClientCoroCallback = ClientBatonCallback;

// Presumably adapt a RequestClientCallback into stream/sink client callbacks
// (defined elsewhere).
StreamClientCallback* createStreamClientCallback(
    RequestClientCallback::Ptr requestCallback,
    const BufferOptions& bufferOptions);

SinkClientCallback* createSinkClientCallback(
    RequestClientCallback::Ptr requestCallback);

/**
 * Serializes one request with `prot` into a fresh IOBuf chain and returns it
 * as a SerializedRequest. The work runs inside
 * folly::fibers::runInMainContext.
 *
 * @param prot       protocol writer. Its output points at a local queue for
 *                   the duration of the call and is reset to nullptr
 *                   afterwards.
 * @param rpcOptions if getEnableChecksum() is true, a CRC32C is always set on
 *                   `header`.
 * @param ctx        optional hooks (may be null): preWrite, onWriteData,
 *                   postWrite and resetClientRequestContextHeader on success,
 *                   and handlerErrorWrapped on a TException.
 * @param header     receives the CRC32C when checksumming is selected.
 * @param methodName stored as smsg.methodName for ctx->onWriteData.
 * @param writefunc  writes the request through `prot`.
 * @param sizefunc   size estimate used for the first buffer and passed to
 *                   setOutput.
 * @param checksumSamplingRate 0 disables sampling. Otherwise a checksum is
 *                   added when folly::Random::rand64(rate) == 0.
 * @throws Any apache::thrift::TException thrown while writing is reported to
 *         ctx and then rethrown. Other exception types propagate without
 *         being reported.
 */
template
SerializedRequest preprocessSendT(
    Protocol* prot,
    const apache::thrift::RpcOptions& rpcOptions,
    apache::thrift::ContextStack* ctx,
    apache::thrift::transport::THeader& header,
    folly::StringPiece methodName,
    folly::FunctionRef writefunc,
    folly::FunctionRef sizefunc,
    uint64_t checksumSamplingRate) {
  return folly::fibers::runInMainContext([&] {
    size_t bufSize = sizefunc(prot);
    // queue.chainLength() is used below for postWrite, which presumably
    // needs the cacheChainLength option.
    folly::IOBufQueue queue(folly::IOBufQueue::cacheChainLength());
    // Preallocate small buffer headroom for transports metadata & framing.
    constexpr size_t kHeadroomBytes = 128;
    auto buf = folly::IOBuf::create(kHeadroomBytes + bufSize);
    buf->advance(kHeadroomBytes);
    queue.append(std::move(buf));
    prot->setOutput(&queue, bufSize);
    // `queue` is local, so always detach prot from it, even on throw.
    auto guard = folly::makeGuard([&] { prot->setOutput(nullptr); });
    try {
      if (ctx) {
        ctx->preWrite();
      }
      writefunc(prot);
      // Describes the serialized bytes for the onWriteData hook. smsg.buffer
      // points at the head of `queue` and is only valid within this scope.
      ::apache::thrift::SerializedMessage smsg;
      smsg.protocolType = prot->protocolType();
      smsg.buffer = queue.front();
      smsg.methodName = methodName;
      if (ctx) {
        ctx->onWriteData(smsg);
        ctx->postWrite(folly::to_narrow(queue.chainLength()));
        ctx->resetClientRequestContextHeader();
      }
    } catch (const apache::thrift::TException&) {
      if (ctx) {
        ctx->handlerErrorWrapped(
            folly::exception_wrapper(std::current_exception()));
      }
      throw;
    }
    // Checksum when explicitly requested, or sampled roughly once per
    // checksumSamplingRate requests.
    if (rpcOptions.getEnableChecksum() ||
        (checksumSamplingRate &&
         folly::Random::rand64(checksumSamplingRate) == 0)) {
      header.setCrc32c(apache::thrift::checksum::crc32c(*queue.front()));
    }
    return SerializedRequest(queue.move());
  });
}

namespace detail {
// Returns a pointer to the RequestChannel send virtual for Kind. The
// ChannelSendFunc return type selects the const-ref or rvalue RpcOptions
// overload. Any Kind other than the first three must be RpcKind::SINK, which
// the static_assert enforces.
template
constexpr ChannelSendFunc getChannelSendFunc() {
  if constexpr (Kind == RpcKind::SINGLE_REQUEST_SINGLE_RESPONSE) {
    return &RequestChannel::sendRequestResponse;
  } else if constexpr (Kind == RpcKind::SINGLE_REQUEST_NO_RESPONSE) {
    return &RequestChannel::sendRequestNoResponse;
  } else if constexpr (Kind == RpcKind::SINGLE_REQUEST_STREAMING_RESPONSE) {
    return &RequestChannel::sendRequestStream;
  } else {
    static_assert(Kind == RpcKind::SINK);
    return &RequestChannel::sendRequestSink;
  }
}
} // namespace detail

// Sends inline when there is no EventBase or we are already on its thread.
// Otherwise it hops to the EventBase thread first.
template
void RequestChannel::sendRequestAsync(
    RpcOptions&& rpcOptions,
    MethodMetadata&& methodMetadata,
    SerializedRequest&& request,
    std::shared_ptr&& header,
    typename apache::thrift::detail::RequestClientCallbackType::Ptr
        callback) {
  auto* eb = getEventBase();
  if (!eb || eb->isInEventBaseThread()) {
    auto send = apache::thrift::detail::getChannelSendFunc();
    (this->*send)(
        std::forward(rpcOptions),
        std::move(methodMetadata),
        std::move(request),
        std::move(header),
        std::move(callback));
  } else {
    // The init-captures give the task its own copy or moved value of every
    // argument.
    // NOTE(review): `this` is captured raw. The channel must stay alive until
    // the EventBase runs this task, and nothing here extends its lifetime.
    eb->runInEventBaseThread([this,
                              rpcOptions = std::forward(rpcOptions),
                              methodMetadata = std::move(methodMetadata),
                              request = std::move(request),
                              header = std::move(header),
                              callback = std::move(callback)]() mutable {
      auto send = apache::thrift::detail::getChannelSendFunc();
      (this->*send)(
          std::forward(rpcOptions),
          std::move(methodMetadata),
          std::move(request),
          std::move(header),
          std::move(callback));
    });
  }
}

// Serializes a request with preprocessSendT, using the channel's checksum
// sampling rate, then hands it to channel->sendRequestAsync. rpcOptions is
// read during serialization and only forwarded afterwards. `header` must be
// non-null because it is dereferenced.
template
void clientSendT(
    Protocol* prot,
    RpcOptions&& rpcOptions,
    typename apache::thrift::detail::RequestClientCallbackType::Ptr callback,
    apache::thrift::ContextStack* ctx,
    std::shared_ptr&& header,
    RequestChannel* channel,
    apache::thrift::MethodMetadata&& methodMetadata,
    folly::FunctionRef writefunc,
    folly::FunctionRef sizefunc) {
  auto request = preprocessSendT(
      prot,
      rpcOptions,
      ctx,
      *header,
      methodMetadata.name_view(),
      writefunc,
      sizefunc,
      channel->getChecksumSamplingRate());
  channel->sendRequestAsync(
      std::forward(rpcOptions),
      std::move(methodMetadata),
      std::move(request),
      std::move(header),
      std::move(callback));
}

} // namespace thrift
} // namespace apache

#endif // #ifndef THRIFT_ASYNC_REQUESTCHANNEL_H_