/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace apache { namespace thrift { namespace rocket { class RocketClient; class RequestContextQueue; class RequestContext { private: template using payload_method_t = decltype(std::declval().payload()); public: class WriteSuccessCallback { public: virtual ~WriteSuccessCallback() = default; virtual void onWriteSuccess() noexcept = 0; }; enum class State : uint8_t { DEFERRED_INIT, /* still needs to be intialized with server version */ WRITE_NOT_SCHEDULED, WRITE_SCHEDULED, WRITE_SENDING, /* AsyncSocket::writeChain() called, but WriteCallback has not yet fired */ WRITE_SENT, /* Write to socket completed (possibly with error) */ COMPLETE, /* Terminal state. Result stored in responsePayload_ */ }; template RequestContext( Frame&& frame, RequestContextQueue& queue, SetupFrame* setupFrame = nullptr, WriteSuccessCallback* writeSuccessCallback = nullptr) : queue_(queue), streamId_(frame.streamId()), frameType_(Frame::frameType()), writeSuccessCallback_(writeSuccessCallback) { // Some `Frame`s lack a `payload()` method -- `RequestNFrame`, // `CancelFrame`, etc -- but those that do should have `.fds`. if constexpr (folly::is_detected::value) { fds = std::move(frame.payload().fds.dcheckToSendOrEmpty()); } serialize(std::forward(frame), setupFrame); } template RequestContext( InitFunc&& initFunc, int32_t serverVersion, StreamId streamId, RequestContextQueue& queue, WriteSuccessCallback* writeSuccessCallback = nullptr) : queue_(queue), streamId_(streamId), writeSuccessCallback_(writeSuccessCallback) { if (UNLIKELY(serverVersion == -1)) { deferredInit_ = std::forward(initFunc); state_ = State::DEFERRED_INIT; } else { std::tie(serializedFrame_, frameType_) = initFunc(serverVersion); } } RequestContext(const RequestContext&) = delete; RequestContext(RequestContext&&) = delete; RequestContext& operator=(const RequestContext&) = delete; RequestContext& operator=(RequestContext&&) = delete; // For REQUEST_RESPONSE contexts, where an immediate matching response is // expected FOLLY_NODISCARD folly::Try waitForResponse( std::chrono::milliseconds timeout); FOLLY_NODISCARD folly::Try getResponse() &&; // For request types for which an immediate matching response is not // necessarily expected, e.g., REQUEST_FNF and REQUEST_STREAM FOLLY_NODISCARD folly::Try waitForWriteToComplete(); void waitForWriteToCompleteSchedule(folly::fibers::Baton::Waiter* waiter); FOLLY_NODISCARD folly::Try waitForWriteToCompleteResult(); void setTimeoutInfo( folly::HHWheelTimer& timer, folly::HHWheelTimer::Callback& callback, std::chrono::milliseconds timeout) { timer_ = &timer; timeoutCallback_ = &callback; requestTimeout_ = timeout; } void scheduleTimeoutForResponse() { DCHECK(isRequestResponse()); // In some edge cases, response may arrive before write to socket finishes. if (state_ != State::COMPLETE && requestTimeout_ != std::chrono::milliseconds::zero()) { timer_->scheduleTimeout(timeoutCallback_, requestTimeout_); } } std::unique_ptr releaseSerializedChain() { DCHECK(serializedFrame_); return std::move(serializedFrame_); } size_t endOffsetInBatch() const { DCHECK_GT(endOffsetInBatch_, 0); return endOffsetInBatch_; } void setEndOffsetInBatch(ssize_t offset) { endOffsetInBatch_ = offset; } State state() const { return state_; } StreamId streamId() const { return streamId_; } bool isRequestResponse() const { return frameType_ == FrameType::REQUEST_RESPONSE; } void onPayloadFrame(PayloadFrame&& payloadFrame); void onErrorFrame(ErrorFrame&& errorFrame); void onWriteSuccess() noexcept; bool hasPartialPayload() const { return responsePayload_.hasValue(); } void initWithVersion(int32_t serverVersion) { if (!deferredInit_) { return; } DCHECK(state_ == State::DEFERRED_INIT); std::tie(serializedFrame_, frameType_) = deferredInit_(serverVersion); DCHECK(serializedFrame_ && frameType_ != FrameType::RESERVED); state_ = State::WRITE_NOT_SCHEDULED; } folly::SocketFds fds; protected: friend class RocketClient; void markLastInWriteBatch() { lastInWriteBatch_ = true; } private: RequestContextQueue& queue_; folly::SafeIntrusiveListHook queueHook_; std::unique_ptr serializedFrame_; ssize_t endOffsetInBatch_{}; StreamId streamId_; FrameType frameType_; State state_{State::WRITE_NOT_SCHEDULED}; bool lastInWriteBatch_{false}; bool isDummyEndOfBatchMarker_{false}; boost::intrusive::unordered_set_member_hook<> setHook_; folly::fibers::Baton baton_; std::chrono::milliseconds requestTimeout_{1000}; folly::HHWheelTimer* timer_{nullptr}; folly::HHWheelTimer::Callback* timeoutCallback_{nullptr}; folly::Try responsePayload_; WriteSuccessCallback* const writeSuccessCallback_{nullptr}; folly::Function, FrameType>(int32_t)> deferredInit_{nullptr}; template void serialize(Frame&& frame, SetupFrame* setupFrame) { DCHECK(!serializedFrame_); serializedFrame_ = std::move(frame).serialize(); if (UNLIKELY(setupFrame != nullptr)) { Serializer writer; std::move(*setupFrame).serialize(writer); auto setupBuffer = std::move(writer).move(); setupBuffer->prependChain(std::move(serializedFrame_)); serializedFrame_ = std::move(setupBuffer); } } explicit RequestContext(RequestContextQueue& queue) : queue_(queue), frameType_(FrameType::REQUEST_RESPONSE) {} static RequestContext& createDummyEndOfBatchMarker( RequestContextQueue& queue) { auto* rctx = new RequestContext(queue); rctx->lastInWriteBatch_ = true; rctx->isDummyEndOfBatchMarker_ = true; rctx->state_ = State::WRITE_SENDING; return *rctx; } struct Equal { bool operator()( const RequestContext& ctxa, const RequestContext& ctxb) const noexcept { return ctxa.streamId_ == ctxb.streamId_; } }; struct Hash { size_t operator()(const RequestContext& ctx) const noexcept { return std::hash()( static_cast(ctx.streamId_)); } }; public: using Queue = folly::CountedIntrusiveList; using UnorderedSet = boost::intrusive::unordered_set< RequestContext, boost::intrusive::member_hook< RequestContext, decltype(setHook_), &RequestContext::setHook_>, boost::intrusive::equal, boost::intrusive::hash>; private: friend class RequestContextQueue; }; } // namespace rocket } // namespace thrift } // namespace apache