/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* NOTE(review): in this copy the include targets and the template parameter and argument lists appear to be missing; confirm exact types against the upstream header before relying on them. */ #pragma once #include #include #include #include #include #include namespace apache::thrift { namespace detail { /* FutureCallbackHelper: static helpers shared by the future callbacks in this file. It defines the PromiseResult type, unwraps it, invokes a processor, and builds the value and error alternatives. */ template class FutureCallbackHelper { public: /* PromiseResult: a folly::Expected whose value and error alternatives are both pairs (the pair element types are not visible in this copy). */ using PromiseResult = folly::Expected< std::pair, std::pair>; /* Returns the moved-out first element of the value pair when result holds a value; otherwise calls throw_exception() on the first element of the error pair. No return statement follows that call, so presumably throw_exception() does not return - TODO confirm. */ static Result extractResult(PromiseResult&& result) { if (result.hasValue()) { return std::move(std::move(result).value().first); } std::move(result).error().first.throw_exception(); } /* Processor signature: either the one-argument form taking only ClientReceiveState& or the two-argument form that also takes Result&, selected by the std::is_same_v condition. */ using CallbackProcessorType = std::conditional_t< std::is_same_v, folly::exception_wrapper(ClientReceiveState&), folly::exception_wrapper(Result&, ClientReceiveState&)>; /* Calls processor and stores its return value in ew; the if constexpr picks processor(state) or processor(result, state). */ static void invokeCallbackProcessor( CallbackProcessorType& processor, folly::exception_wrapper& ew, Result& result, ClientReceiveState& state) { if constexpr (std::is_same_v) { ew = processor(state); } else { ew = processor(result, state); } } /* Builds the value alternative from a pair of the moved result and the moved state. */ static PromiseResult makeResult(Result&& result, ClientReceiveState&& state) { return PromiseResult(std::pair{std::move(result), std::move(state)}); } /* Builds the error alternative with folly::makeUnexpected from a pair of the moved exception wrapper and the moved state. */ static folly::Unexpected> makeError( folly::exception_wrapper&& ew, ClientReceiveState&& state) { return folly::makeUnexpected(std::pair{std::move(ew), std::move(state)}); } }; } // namespace detail /* FutureCallbackBase: RequestCallback that holds a folly::Promise and a channel shared_ptr; requestSent is a no-op and requestError fulfils the promise with the error alternative. */ template class FutureCallbackBase : public RequestCallback { public: using Helper = detail::FutureCallbackHelper; explicit FutureCallbackBase( folly::Promise&& promise, std::shared_ptr channel = nullptr) : 
promise_(std::move(promise)), channel_(std::move(channel)) {} void requestSent() override {} /* CHECKs state.isException(), then sets the promise value to Helper::makeError built from the exception taken out of state (ew is a reference into state) and from state itself. */ void requestError(ClientReceiveState&& state) override { CHECK(state.isException()); folly::exception_wrapper& ew = state.exception(); promise_.setValue(Helper::makeError(std::move(ew), std::move(state))); } protected: folly::Promise promise_; std::shared_ptr channel_; }; /* FutureCallback: on reply, runs a processor over the received state and fulfils the inherited promise with either the result or the error. */ template class FutureCallback : public FutureCallbackBase { private: using Helper = typename FutureCallbackBase::Helper; using Processor = typename Helper::CallbackProcessorType; public: FutureCallback( folly::Promise&& promise, Processor& processor, std::shared_ptr channel = nullptr) : FutureCallbackBase(std::move(promise), std::move(channel)), processor_(processor) {} /* CHECKs that state is not an exception and has a response buffer, default-constructs Result, invokes the processor, then sets the promise to makeError when ew is set and to makeResult otherwise. */ void replyReceived(ClientReceiveState&& state) override { CHECK(!state.isException()); CHECK(state.hasResponseBuffer()); Result result; folly::exception_wrapper ew; Helper::invokeCallbackProcessor(processor_, ew, result, state); if (ew) { this->promise_.setValue( Helper::makeError(std::move(ew), std::move(state))); } else { this->promise_.setValue( Helper::makeResult(std::move(result), std::move(state))); } } private: /* Held by reference, not copied; the referenced processor has to stay valid for as long as replyReceived may still be called. */ Processor& processor_; }; /* HeaderFutureCallback: same flow as FutureCallback, but the success value is a pair of the result and the header obtained from state.extractHeader(). InnerHelper drives the processor; Helper builds the promise value. */ template class HeaderFutureCallback : public FutureCallbackBase>> { private: using HeaderResult = std::pair>; using Helper = detail::FutureCallbackHelper; using InnerHelper = detail::FutureCallbackHelper; using Processor = typename InnerHelper::CallbackProcessorType; Processor& processor_; public: HeaderFutureCallback( folly::Promise&& promise, Processor& processor, std::shared_ptr channel = nullptr) : FutureCallbackBase( std::move(promise), std::move(channel)), processor_(processor) {} /* Performs the two CHECKs on state and the processor call; on success the header is extracted from state before state is moved into the promise value. */ void replyReceived(ClientReceiveState&& state) override { CHECK(!state.isException()); CHECK(state.hasResponseBuffer()); Result result; folly::exception_wrapper ew; InnerHelper::invokeCallbackProcessor(processor_, ew, result, state); if (ew) { this->promise_.setValue( Helper::makeError(std::move(ew), 
std::move(state))); } else { auto header = state.extractHeader(); this->promise_.setValue(Helper::makeResult( std::pair{std::move(result), std::move(header)}, std::move(state))); } } }; /* OneWayFutureCallback: fulfils the promise as soon as requestSent() fires, using folly::Unit and a default-constructed ClientReceiveState; replyReceived hits CHECK(false). */ class OneWayFutureCallback : public FutureCallbackBase { private: using Helper = detail::FutureCallbackHelper; public: explicit OneWayFutureCallback( folly::Promise&& promise, std::shared_ptr channel = nullptr) : FutureCallbackBase( std::move(promise), std::move(channel)) {} void requestSent() override { promise_.setValue(Helper::makeResult(folly::Unit(), ClientReceiveState())); } void replyReceived(ClientReceiveState&& /*state*/) override { CHECK(false); } }; /* SemiFutureCallback: hands the received ClientReceiveState to the promise on reply and the state exception to the promise on error; isInlineSafe() returns true. Processor and ProcessorVoid are function-pointer aliases taken as parameters by the make*SemiFutureCallback functions. */ class SemiFutureCallback : public RequestCallback { public: template using Processor = folly::exception_wrapper (*)(Result&, ClientReceiveState&); using ProcessorVoid = folly::exception_wrapper (*)(ClientReceiveState&); explicit SemiFutureCallback( folly::Promise&& promise, std::shared_ptr channel) : promise_(std::move(promise)), channel_(std::move(channel)) {} void requestSent() override {} void replyReceived(ClientReceiveState&& state) override { promise_.setValue(std::move(state)); } void requestError(ClientReceiveState&& state) override { promise_.setException(std::move(state.exception())); } bool isInlineSafe() const override { return true; } protected: folly::Promise promise_; std::shared_ptr channel_; }; /* OneWaySemiFutureCallback: requestSent() calls promise_.setValue() with no argument, requestError() forwards the state exception to the promise, and replyReceived hits CHECK(false). */ class OneWaySemiFutureCallback : public RequestCallback { public: OneWaySemiFutureCallback( folly::Promise&& promise, std::shared_ptr channel) : promise_(std::move(promise)), channel_(std::move(channel)) {} void requestSent() override { promise_.setValue(); } void replyReceived(ClientReceiveState&&) override { CHECK(false); } void requestError(ClientReceiveState&& state) override { promise_.setException(std::move(state.exception())); } bool isInlineSafe() const override { return true; } protected: folly::Promise promise_; std::shared_ptr channel_; }; /* makeSemiFutureCallback (overload taking a Processor): creates a promise, takes its SemiFuture, and returns the make_unique-built callback together with that future extended by a deferValue continuation. The continuation CHECKs the state, runs processor(result, state), calls ew.throw_exception() when ew is set, and otherwise returns result. */ template std::pair, folly::SemiFuture> makeSemiFutureCallback( 
SemiFutureCallback::Processor processor, std::shared_ptr channel) { folly::Promise promise; auto future = promise.getSemiFuture(); return { std::make_unique( std::move(promise), std::move(channel)), std::move(future).deferValue([processor](ClientReceiveState&& state) { CHECK(!state.isException()); CHECK(state.hasResponseBuffer()); Result result; auto ew = processor(result, state); if (ew) { ew.throw_exception(); } return result; })}; } /* makeSemiFutureCallback (overload taking a ProcessorVoid): same shape, but the continuation calls processor(state) and has no return value. */ inline std:: pair, folly::SemiFuture> makeSemiFutureCallback( SemiFutureCallback::ProcessorVoid processor, std::shared_ptr channel) { folly::Promise promise; auto future = promise.getSemiFuture(); return { std::make_unique( std::move(promise), std::move(channel)), std::move(future).deferValue([processor](ClientReceiveState&& state) { CHECK(!state.isException()); CHECK(state.hasResponseBuffer()); auto ew = processor(state); if (ew) { ew.throw_exception(); } })}; } /* makeHeaderSemiFutureCallback (overload taking a Processor): builds the promise, callback and deferValue continuation as above, except the continuation returns std::make_pair of the moved result and state.extractHeader(). */ template std::pair< std::unique_ptr, folly::SemiFuture< std::pair>>> makeHeaderSemiFutureCallback( SemiFutureCallback::Processor processor, std::shared_ptr channel) { folly::Promise promise; auto future = promise.getSemiFuture(); return { std::make_unique( std::move(promise), std::move(channel)), std::move(future).deferValue([processor](ClientReceiveState&& state) { CHECK(!state.isException()); CHECK(state.hasResponseBuffer()); Result result; auto ew = processor(result, state); if (ew) { ew.throw_exception(); } return std::make_pair(std::move(result), state.extractHeader()); })}; } /* makeHeaderSemiFutureCallback (overload taking a ProcessorVoid): the continuation calls processor(state) and returns std::make_pair(folly::unit, state.extractHeader()). */ inline std::pair< std::unique_ptr, folly::SemiFuture>>> makeHeaderSemiFutureCallback( SemiFutureCallback::ProcessorVoid processor, std::shared_ptr channel) { folly::Promise promise; auto future = promise.getSemiFuture(); return { std::make_unique( std::move(promise), std::move(channel)), std::move(future).deferValue([processor](ClientReceiveState&& state) { CHECK(!state.isException()); CHECK(state.hasResponseBuffer()); auto ew = processor(state); if (ew) { ew.throw_exception(); } return 
std::make_pair(folly::unit, state.extractHeader()); })}; } /* makeOneWaySemiFutureCallback: returns the make_unique-built callback and the promise SemiFuture with no continuation attached. */ inline std::pair< std::unique_ptr, folly::SemiFuture> makeOneWaySemiFutureCallback( std::shared_ptr channel) { folly::Promise promise; auto future = promise.getSemiFuture(); return { std::make_unique( std::move(promise), std::move(channel)), std::move(future)}; } /* CancellableRequestClientCallback: holds the wrapped RequestClientCallback pointer in an atomic slot plus a channel shared_ptr. The constructor is private (it is declared before public:) and DCHECKs wrapped->isInlineSafe(); instances are produced by create(). */ template class CancellableRequestClientCallback : public RequestClientCallback { CancellableRequestClientCallback( RequestClientCallback* wrapped, std::shared_ptr channel) : callback_(wrapped), channel_(std::move(channel)) { DCHECK(wrapped->isInlineSafe()); } public: static std::unique_ptr create( RequestClientCallback* wrapped, std::shared_ptr channel) { return std::unique_ptr( new CancellableRequestClientCallback(wrapped, std::move(channel))); } /* Releases the pointer from cb and calls onResponseError on it with a freshly made exception_wrapper (the exception type is not visible in this copy). */ static void cancel(std::unique_ptr cb) { cb.release()->onResponseError( folly::make_exception_wrapper()); } /* Atomically swaps callback_ with nullptr. If a wrapped callback was still present it receives the state; if the slot was already empty this object deletes itself. */ void onResponse(ClientReceiveState&& state) noexcept override { if (auto callback = callback_.exchange(nullptr, std::memory_order_acq_rel)) { callback->onResponse(std::move(state)); } else { delete this; } } /* Error-path counterpart of onResponse: exchanges callback_ with nullptr, forwards ew to the wrapped callback if one was present, otherwise deletes this object. */ void onResponseError(folly::exception_wrapper ew) noexcept override { if (auto callback = callback_.exchange(nullptr, std::memory_order_acq_rel)) { callback->onResponseError(std::move(ew)); } else { delete this; } } bool isInlineSafe() const override { return true; } private: std::atomic callback_; std::shared_ptr channel_; }; } // namespace apache::thrift