/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

// NOTE(review): this copy of the file has lost the targets of all ten #include
// directives below (the text between the angle brackets/quotes is gone), so
// the header does not compile as-is. TODO restore them from version control.
// Judging only from the names used in this header, it needs at least
// <chrono>, <cstdint>, <memory>, <optional>, <string> and <utility>. It also
// needs the headers that declare StreamServerCallback, SinkServerCallback,
// StreamClientCallback, SinkClientCallback, rocket::StreamId, HeadersPayload,
// FirstResponsePayload, StreamPayload, folly::EventBase and
// folly::exception_wrapper.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

namespace apache {
namespace thrift {
namespace rocket {

// Forward declaration only; the callbacks below hold a reference to the client
// and never need its complete type in this header.
class RocketClient;

// Status value returned (wrapped in StreamChannelStatusResponse) by the
// event-handling on*() methods of the callbacks below.
// NOTE(review): the meanings are inferred from the names only. Alive means the
// channel stays open. Complete means it has finished. ContractViolation is
// presumably a protocol misuse by the peer. Confirm against the .cpp.
enum class StreamChannelStatus { Alive, Complete, ContractViolation };

// A StreamChannelStatus together with an optional error message.
class StreamChannelStatusResponse {
 public:
  // Deliberately non-explicit, so a handler returning
  // StreamChannelStatusResponse can simply `return StreamChannelStatus::X;`.
  /* implicit */ StreamChannelStatusResponse(StreamChannelStatus status)
      : status_(status) {}
  // Status plus a caller-supplied error message (taken by value and moved in).
  StreamChannelStatusResponse(StreamChannelStatus status, std::string errorMsg)
      : status_(status), errorMsg_(std::move(errorMsg)) {}

  StreamChannelStatus getStatus() const { return status_; }

  // Moves the stored message out, or returns "Unknown error" if none was
  // supplied. The member is rvalue-qualified, so it can only be called on an
  // expiring object, e.g. std::move(resp).getErrorMsg().
  std::string getErrorMsg() && {
    return std::move(errorMsg_).value_or("Unknown error");
  }

 private:
  StreamChannelStatus status_;
  // NOTE(review): the template argument is missing in this copy. Given the
  // std::string constructor parameter and getErrorMsg()'s return type, this
  // should read std::optional<std::string>.
  std::optional errorMsg_;
};

// Rocket transport's StreamServerCallback implementation for a single stream,
// identified by streamId().
// - The `override` members implement the StreamServerCallback interface:
//   request-N, cancel, sink headers and callback reset.
// - The remaining non-virtual on*() members accept events for this stream.
//   They are presumably invoked by RocketClient as frames arrive and forwarded
//   to clientCallback_ (confirm in the .cpp).
// Stores a reference to `client` and a non-owning pointer to `clientCallback`.
// Both must outlive this object; the callback may instead be swapped via
// resetClientCallback().
class RocketStreamServerCallback : public StreamServerCallback {
 public:
  RocketStreamServerCallback(
      rocket::StreamId streamId,
      rocket::RocketClient& client,
      StreamClientCallback& clientCallback)
      : client_(client),
        clientCallback_(&clientCallback),
        streamId_(streamId) {}

  // StreamServerCallback interface.
  bool onStreamRequestN(uint64_t tokens) override;
  void onStreamCancel() override;
  bool onSinkHeaders(HeadersPayload&& payload) override;
  // Repoints this callback at a different client-side consumer.
  void resetClientCallback(StreamClientCallback& clientCallback) override {
    clientCallback_ = &clientCallback;
  }

  // Incoming events for this stream (non-virtual in this header).
  bool onInitialPayload(FirstResponsePayload&&, folly::EventBase*);
  void onInitialError(folly::exception_wrapper ew);
  StreamChannelStatusResponse onStreamPayload(StreamPayload&&);
  StreamChannelStatusResponse onStreamFinalPayload(StreamPayload&&);
  StreamChannelStatusResponse onStreamComplete();
  StreamChannelStatusResponse onStreamError(folly::exception_wrapper);
  void onStreamHeaders(HeadersPayload&&);

  rocket::StreamId streamId() const noexcept { return streamId_; }

 protected:
  // Protected (not private) so RocketStreamServerCallbackWithChunkTimeout can
  // reach the client.
  rocket::RocketClient& client_;

 private:
  // Non-owning; see resetClientCallback().
  StreamClientCallback* clientCallback_;
  rocket::StreamId streamId_;
};

// RocketStreamServerCallback that additionally tracks a per-chunk timeout
// (chunkTimeout) and a credit count seeded from initialCredits. The credit
// accounting and timer logic live in the .cpp (not visible here).
// NOTE(review): onInitialPayload() and onStreamPayload() redeclare base-class
// members of the same name without `virtual`/`override` in this header.
// Unless StreamServerCallback itself declares them virtual (not visible here),
// they hide rather than override the base versions. A call made through a
// RocketStreamServerCallback& would then run the base implementation, so
// confirm the caller dispatches on the concrete type.
class RocketStreamServerCallbackWithChunkTimeout
    : public RocketStreamServerCallback {
 public:
  RocketStreamServerCallbackWithChunkTimeout(
      rocket::StreamId streamId,
      rocket::RocketClient& client,
      StreamClientCallback& clientCallback,
      std::chrono::milliseconds chunkTimeout,
      uint64_t initialCredits)
      : RocketStreamServerCallback(streamId, client, clientCallback),
        chunkTimeout_(chunkTimeout),
        credits_(initialCredits) {}

  bool onStreamRequestN(uint64_t tokens) override;

  bool onInitialPayload(FirstResponsePayload&&, folly::EventBase*);
  StreamChannelStatusResponse onStreamPayload(StreamPayload&&);

  // Handler for an expired chunk timeout; declared noexcept.
  void timeoutExpired() noexcept;

 private:
  void scheduleTimeout();
  void cancelTimeout();

  const std::chrono::milliseconds chunkTimeout_;
  // NOTE(review): the {0} default is always overridden by the only
  // constructor's mem-initializer (initialCredits).
  uint64_t credits_{0};
  // NOTE(review): the template argument is missing in this copy (likely the
  // timer-callback type used by scheduleTimeout()/cancelTimeout()). TODO
  // restore.
  std::unique_ptr timeout_;
};

// Rocket transport's SinkServerCallback implementation for a single sink,
// identified by streamId().
// - The `override` members implement the SinkServerCallback interface:
//   next, error, complete and callback reset.
// - The non-virtual on*() members accept events for this sink. They are
//   presumably invoked by RocketClient (confirm in the .cpp).
// Stores a reference to `client` and a non-owning pointer to `clientCallback`.
// Both must outlive this object, unless the callback is replaced via
// resetClientCallback().
class RocketSinkServerCallback : public SinkServerCallback {
 public:
  // NOTE(review): the template argument of `compressionConfig` (and of
  // compressionConfig_ below) is missing in this copy. TODO restore. The value
  // is moved into compressionConfig_, so this object takes ownership of it.
  RocketSinkServerCallback(
      rocket::StreamId streamId,
      rocket::RocketClient& client,
      SinkClientCallback& clientCallback,
      std::unique_ptr compressionConfig)
      : client_(client),
        clientCallback_(&clientCallback),
        streamId_(streamId),
        compressionConfig_(std::move(compressionConfig)) {}

  // SinkServerCallback interface.
  bool onSinkNext(StreamPayload&&) override;
  void onSinkError(folly::exception_wrapper) override;
  bool onSinkComplete() override;
  // Repoints this callback at a different client-side consumer.
  void resetClientCallback(SinkClientCallback& clientCallback) override {
    clientCallback_ = &clientCallback;
  }

  // Incoming events for this sink (non-virtual in this header).
  bool onInitialPayload(FirstResponsePayload&&, folly::EventBase*);
  void onInitialError(folly::exception_wrapper);
  StreamChannelStatusResponse onFinalResponse(StreamPayload&&);
  StreamChannelStatusResponse onFinalResponseError(folly::exception_wrapper);
  StreamChannelStatusResponse onSinkRequestN(uint64_t tokens);

  rocket::StreamId streamId() const noexcept { return streamId_; }

 private:
  rocket::RocketClient& client_;
  // Non-owning; see resetClientCallback().
  SinkClientCallback* clientCallback_;
  rocket::StreamId streamId_;
  // Starts at BothOpen.
  // NOTE(review): the names suggest two phases. In BothOpen, both the sink and
  // the response side are open; in StreamOpen, only the response side remains.
  // The transitions are in the .cpp, so confirm there.
  enum class State { BothOpen, StreamOpen };
  State state_{State::BothOpen};
  std::unique_ptr compressionConfig_;
};

} // namespace rocket
} // namespace thrift
} // namespace apache