/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include #include #include namespace apache { namespace thrift { template class FutureCallbackBase : public RequestCallback { public: explicit FutureCallbackBase( folly::Promise&& promise, std::shared_ptr channel = nullptr) : promise_(std::move(promise)), channel_(std::move(channel)) {} void requestSent() override {} void requestError(ClientReceiveState&& state) override { CHECK(state.isException()); promise_.setException(std::move(state.exception())); } protected: folly::Promise promise_; std::shared_ptr channel_; }; template class FutureCallback : public FutureCallbackBase { private: typedef folly::exception_wrapper (*Processor)(Result&, ClientReceiveState&); public: FutureCallback( folly::Promise&& promise, Processor processor, std::shared_ptr channel = nullptr) : FutureCallbackBase(std::move(promise), std::move(channel)), processor_(processor) {} void replyReceived(ClientReceiveState&& state) override { CHECK(!state.isException()); CHECK(state.hasResponseBuffer()); Result result; auto ew = processor_(result, state); if (ew) { this->promise_.setException(ew); } else { this->promise_.setValue(std::move(result)); } } private: Processor processor_; }; template class HeaderFutureCallback : public FutureCallbackBase>> { private: using HeaderResult = std::pair>; typedef folly::exception_wrapper (*Processor)(Result&, ClientReceiveState&); Processor processor_; public: HeaderFutureCallback( folly::Promise&& promise, Processor processor, std::shared_ptr channel = nullptr) : FutureCallbackBase( std::move(promise), std::move(channel)), processor_(processor) {} void replyReceived(ClientReceiveState&& state) override { CHECK(!state.isException()); CHECK(state.hasResponseBuffer()); Result result; auto ew = processor_(result, state); if (ew) { this->promise_.setException(ew); } else { this->promise_.setValue( std::make_pair(std::move(result), state.extractHeader())); } } }; template <> class HeaderFutureCallback : public FutureCallbackBase>> { private: using HeaderResult = std:: pair>; typedef folly::exception_wrapper (*Processor)(ClientReceiveState&); Processor processor_; public: HeaderFutureCallback( folly::Promise&& promise, Processor processor, std::shared_ptr channel = nullptr) : FutureCallbackBase( std::move(promise), std::move(channel)), processor_(processor) {} void replyReceived(ClientReceiveState&& state) override { CHECK(!state.isException()); CHECK(state.hasResponseBuffer()); auto ew = processor_(state); if (ew) { promise_.setException(ew); } else { promise_.setValue(std::make_pair(folly::Unit(), state.extractHeader())); } } }; class OneWayFutureCallback : public FutureCallbackBase { public: explicit OneWayFutureCallback( folly::Promise&& promise, std::shared_ptr channel = nullptr) : FutureCallbackBase( std::move(promise), std::move(channel)) {} void requestSent() override { promise_.setValue(); } void replyReceived(ClientReceiveState&& /*state*/) override { CHECK(false); } }; template <> class FutureCallback : public FutureCallbackBase { private: typedef folly::exception_wrapper (*Processor)(ClientReceiveState&); public: FutureCallback( folly::Promise&& promise, Processor processor, std::shared_ptr channel = nullptr) : FutureCallbackBase(std::move(promise), std::move(channel)), processor_(processor) {} void replyReceived(ClientReceiveState&& state) override { CHECK(!state.isException()); CHECK(state.hasResponseBuffer()); auto ew = processor_(state); if (ew) { promise_.setException(ew); } else { promise_.setValue(); } } private: Processor processor_; }; class SemiFutureCallback : public RequestCallback { public: template using Processor = folly::exception_wrapper (*)(Result&, ClientReceiveState&); using ProcessorVoid = folly::exception_wrapper (*)(ClientReceiveState&); explicit SemiFutureCallback( folly::Promise&& promise, std::shared_ptr channel) : promise_(std::move(promise)), channel_(std::move(channel)) {} void requestSent() override {} void replyReceived(ClientReceiveState&& state) override { promise_.setValue(std::move(state)); } void requestError(ClientReceiveState&& state) override { promise_.setException(std::move(state.exception())); } bool isInlineSafe() const override { return true; } protected: folly::Promise promise_; std::shared_ptr channel_; }; class OneWaySemiFutureCallback : public RequestCallback { public: OneWaySemiFutureCallback( folly::Promise&& promise, std::shared_ptr channel) : promise_(std::move(promise)), channel_(std::move(channel)) {} void requestSent() override { promise_.setValue(); } void replyReceived(ClientReceiveState&&) override { CHECK(false); } void requestError(ClientReceiveState&& state) override { promise_.setException(std::move(state.exception())); } bool isInlineSafe() const override { return true; } protected: folly::Promise promise_; std::shared_ptr channel_; }; template std::pair, folly::SemiFuture> makeSemiFutureCallback( SemiFutureCallback::Processor processor, std::shared_ptr channel) { folly::Promise promise; auto future = promise.getSemiFuture(); return { std::make_unique( std::move(promise), std::move(channel)), std::move(future).deferValue([processor](ClientReceiveState&& state) { CHECK(!state.isException()); CHECK(state.hasResponseBuffer()); Result result; auto ew = processor(result, state); if (ew) { ew.throw_exception(); } return result; })}; } inline std:: pair, folly::SemiFuture> makeSemiFutureCallback( SemiFutureCallback::ProcessorVoid processor, std::shared_ptr channel) { folly::Promise promise; auto future = promise.getSemiFuture(); return { std::make_unique( std::move(promise), std::move(channel)), std::move(future).deferValue([processor](ClientReceiveState&& state) { CHECK(!state.isException()); CHECK(state.hasResponseBuffer()); auto ew = processor(state); if (ew) { ew.throw_exception(); } })}; } template std::pair< std::unique_ptr, folly::SemiFuture< std::pair>>> makeHeaderSemiFutureCallback( SemiFutureCallback::Processor processor, std::shared_ptr channel) { folly::Promise promise; auto future = promise.getSemiFuture(); return { std::make_unique( std::move(promise), std::move(channel)), std::move(future).deferValue([processor](ClientReceiveState&& state) { CHECK(!state.isException()); CHECK(state.hasResponseBuffer()); Result result; auto ew = processor(result, state); if (ew) { ew.throw_exception(); } return std::make_pair(std::move(result), state.extractHeader()); })}; } inline std::pair< std::unique_ptr, folly::SemiFuture>>> makeHeaderSemiFutureCallback( SemiFutureCallback::ProcessorVoid processor, std::shared_ptr channel) { folly::Promise promise; auto future = promise.getSemiFuture(); return { std::make_unique( std::move(promise), std::move(channel)), std::move(future).deferValue([processor](ClientReceiveState&& state) { CHECK(!state.isException()); CHECK(state.hasResponseBuffer()); auto ew = processor(state); if (ew) { ew.throw_exception(); } return std::make_pair(folly::unit, state.extractHeader()); })}; } inline std::pair< std::unique_ptr, folly::SemiFuture> makeOneWaySemiFutureCallback( std::shared_ptr channel) { folly::Promise promise; auto future = promise.getSemiFuture(); return { std::make_unique( std::move(promise), std::move(channel)), std::move(future)}; } template class CancellableRequestClientCallback : public RequestClientCallback { CancellableRequestClientCallback( RequestClientCallback* wrapped, std::shared_ptr channel) : callback_(wrapped), channel_(std::move(channel)) { DCHECK(wrapped->isInlineSafe()); } public: static std::unique_ptr create( RequestClientCallback* wrapped, std::shared_ptr channel) { return std::unique_ptr( new CancellableRequestClientCallback(wrapped, std::move(channel))); } static void cancel(std::unique_ptr cb) { cb.release()->onResponseError( folly::make_exception_wrapper()); } void onResponse(ClientReceiveState&& state) noexcept override { if (auto callback = callback_.exchange(nullptr, std::memory_order_acq_rel)) { callback->onResponse(std::move(state)); } else { delete this; } } void onResponseError(folly::exception_wrapper ew) noexcept override { if (auto callback = callback_.exchange(nullptr, std::memory_order_acq_rel)) { callback->onResponseError(std::move(ew)); } else { delete this; } } bool isInlineSafe() const override { return true; } private: std::atomic callback_; std::shared_ptr channel_; }; } // namespace thrift } // namespace apache