/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include #include #include #include #include #include #include namespace apache { namespace thrift { namespace detail { namespace twowaybridge_detail { template using Queue = folly::channels::detail::Queue; template class QueueWithTailPtr : public Queue { public: QueueWithTailPtr() = default; template QueueWithTailPtr(Queue&& queue, F&& visitor) : Queue(std::move(queue)) { for (auto* node = Queue::head_; node; node = node->next) { visitor(node->value); tail_ = node; } } void append(QueueWithTailPtr&& other) { if (!other.head_) { return; } if (!Queue::head_) { Queue::head_ = std::exchange(other.head_, nullptr); } else { tail_->next = std::exchange(other.head_, nullptr); } tail_ = other.tail_; } private: // holds invalid pointer if head_ is null typename Queue::Node* tail_; }; template using AtomicQueue = folly::channels::detail::AtomicQueue; // queue with no consumers template class AtomicQueueOrPtr { public: using MessageQueue = Queue; AtomicQueueOrPtr() {} ~AtomicQueueOrPtr() { auto storage = storage_.load(std::memory_order_relaxed); auto type = static_cast(storage & kTypeMask); auto ptr = storage & kPointerMask; switch (type) { case Type::EMPTY: case Type::CLOSED: return; case Type::TAIL: MessageQueue::fromReversed( reinterpret_cast(ptr)); return; default: folly::assume_unreachable(); } } AtomicQueueOrPtr(const AtomicQueueOrPtr&) = delete; AtomicQueueOrPtr& operator=(const AtomicQueueOrPtr&) = delete; // returns closed payload and does not move from message on failure Value* pushOrGetClosedPayload(Message&& message) { auto storage = storage_.load(std::memory_order_acquire); if (static_cast(storage & kTypeMask) == Type::CLOSED) { return closedPayload_; } std::unique_ptr node( new typename MessageQueue::Node(std::move(message))); assert(!(reinterpret_cast(node.get()) & kTypeMask)); while (true) { auto type = static_cast(storage & kTypeMask); auto ptr = storage & kPointerMask; switch (type) { case Type::EMPTY: case Type::TAIL: node->next = reinterpret_cast(ptr); if (storage_.compare_exchange_weak( storage, reinterpret_cast(node.get()) | static_cast(Type::TAIL), std::memory_order_release, std::memory_order_acquire)) { node.release(); return nullptr; } break; case Type::CLOSED: message = std::move(node->value); return closedPayload_; default: folly::assume_unreachable(); } } } MessageQueue closeOrGetMessages(Value* payload) { assert(payload); // nullptr is used as a sentinel // this is only read if the compare_exchange succeeds closedPayload_ = payload; while (true) { auto storage = storage_.exchange( static_cast(Type::EMPTY), std::memory_order_acquire); auto type = static_cast(storage & kTypeMask); auto ptr = storage & kPointerMask; switch (type) { case Type::TAIL: return MessageQueue::fromReversed( reinterpret_cast(ptr)); case Type::EMPTY: if (storage_.compare_exchange_weak( storage, static_cast(Type::CLOSED), std::memory_order_release, std::memory_order_relaxed)) { return MessageQueue(); } break; case Type::CLOSED: default: folly::assume_unreachable(); } } } bool isClosed() const { return static_cast(storage_ & kTypeMask) == Type::CLOSED; } private: enum class Type : intptr_t { EMPTY = 0, TAIL = 1, CLOSED = 2 }; static constexpr intptr_t kTypeMask = 3; static constexpr intptr_t kPointerMask = ~kTypeMask; // These can be combined if the platform requires Value to be 8-byte aligned. // Most platforms don't require that for functions. // A workaround is to make that function a member of an aligned struct // and pass in the address of the struct, but that is not necessarily a win // because of the runtime indirection cost. std::atomic storage_{0}; Value* closedPayload_{nullptr}; }; } // namespace twowaybridge_detail template < typename ClientConsumer, typename ClientMessage, typename ServerConsumer, typename ServerMessage, typename Derived> class TwoWayBridge { using ClientAtomicQueue = twowaybridge_detail::AtomicQueue; using ServerAtomicQueue = twowaybridge_detail::AtomicQueue; public: using ClientQueue = twowaybridge_detail::Queue; using ServerQueue = twowaybridge_detail::Queue; using ClientQueueWithTailPtr = twowaybridge_detail::QueueWithTailPtr; struct Deleter { void operator()(Derived* ptr) { ptr->decref(); } }; using Ptr = std::unique_ptr; Ptr copy() { auto refCount = refCount_.fetch_add(1, std::memory_order_relaxed); DCHECK(refCount > 0); return Ptr(derived()); } protected: TwoWayBridge() = default; // These should only be called from the client thread void clientPush(ServerMessage&& value) { serverQueue_.push(std::move(value)); } bool clientWait(ClientConsumer* consumer) { return clientQueue_.wait(consumer); } ClientConsumer* cancelClientWait() { return clientQueue_.cancelCallback(); } ClientQueue clientGetMessages() { return clientQueue_.getMessages(); } void clientClose() { clientQueue_.close(); } bool isClientClosed() { return clientQueue_.isClosed(); } // These should only be called from the server thread void serverPush(ClientMessage&& value) { clientQueue_.push(std::move(value)); } bool serverWait(ServerConsumer* consumer) { return serverQueue_.wait(consumer); } ServerConsumer* cancelServerWait() { return serverQueue_.cancelCallback(); } ServerQueue serverGetMessages() { return serverQueue_.getMessages(); } void serverClose() { serverQueue_.close(); } bool isServerClosed() { return serverQueue_.isClosed(); } private: void decref() { if (refCount_.fetch_sub(1, std::memory_order_acq_rel) == 1) { delete derived(); } } Derived* derived() { return static_cast(this); } ClientAtomicQueue clientQueue_; ServerAtomicQueue serverQueue_; std::atomic refCount_{1}; }; } // namespace detail } // namespace thrift } // namespace apache