/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #ifndef _WIN32 #include #endif #include #include #include #include #include #include #include #include #include #include #include THRIFT_FLAG_DEFINE_bool(rocket_set_idle_connection_timeout, true); namespace apache { namespace thrift { namespace detail { #define THRIFT_DETAIL_REGISTER_SERVER_EXTENSION_DEFAULT(FUNC) \ THRIFT_PLUGGABLE_FUNC_REGISTER( \ std::unique_ptr, \ FUNC, \ apache::thrift::ThriftServer&) { \ return {}; \ } THRIFT_DETAIL_REGISTER_SERVER_EXTENSION_DEFAULT( createRocketDebugSetupFrameHandler) THRIFT_DETAIL_REGISTER_SERVER_EXTENSION_DEFAULT( createRocketMonitoringSetupFrameHandler) THRIFT_DETAIL_REGISTER_SERVER_EXTENSION_DEFAULT( createRocketProfilingSetupFrameHandler) #undef THRIFT_DETAIL_REGISTER_SERVER_EXTENSION_DEFAULT THRIFT_PLUGGABLE_FUNC_REGISTER( std::unique_ptr, createSecuritySetupFrameInterceptor, apache::thrift::ThriftServer&) { return {}; } } // namespace detail RocketRoutingHandler::RocketRoutingHandler(ThriftServer& server) : streamMetricCallback_( detail::ThriftServerInternals(server).getStreamMetricCallback()) { auto addSetupFramehandler = [&](auto&& handlerFactory) { if (auto handler = handlerFactory(server)) { setupFrameHandlers_.push_back(std::move(handler)); } }; addSetupFramehandler(detail::createRocketDebugSetupFrameHandler); addSetupFramehandler(detail::createRocketMonitoringSetupFrameHandler); addSetupFramehandler(detail::createRocketProfilingSetupFrameHandler); auto addSetupFrameInterceptor = [&](auto&& handlerFactory) { if (auto handler = handlerFactory(server)) { setupFrameInterceptors_.push_back(std::move(handler)); } }; addSetupFrameInterceptor(detail::createSecuritySetupFrameInterceptor); } RocketRoutingHandler::~RocketRoutingHandler() { stopListening(); } void RocketRoutingHandler::stopListening() { listening_ = false; } bool RocketRoutingHandler::canAcceptConnection( const std::vector& bytes, const wangle::TransportInfo&) { class FrameHeader { public: /* * Sample start of an Rsocket frame (version 1.0) in Octal: * 0x0000 2800 0000 0004 0000 0100 00.... * Rsocket frame length - 24 bits * StreamId - 32 bits * Frame type - 6 bits * Flags - 10 bits * Major version - 16 bits * Minor version - 16 bits */ static uint16_t getMajorVersion(const std::vector& bytes) { return bytes[9] << 8 | bytes[10]; } static uint16_t getMinorVersion(const std::vector& bytes) { return bytes[11] << 8 | bytes[12]; } static rocket::FrameType getType(const std::vector& bytes) { return rocket::FrameType(bytes[7] >> 2); } }; return listening_ && // This only supports Rsocket protocol version 1.0 FrameHeader::getMajorVersion(bytes) == 1 && FrameHeader::getMinorVersion(bytes) == 0 && FrameHeader::getType(bytes) == rocket::FrameType::SETUP; } bool RocketRoutingHandler::canAcceptEncryptedConnection( const std::string& protocolName) { return listening_ && protocolName == "rs"; } void RocketRoutingHandler::handleConnection( wangle::ConnectionManager* connectionManager, folly::AsyncTransport::UniquePtr sock, const folly::SocketAddress* address, const wangle::TransportInfo& tinfo, std::shared_ptr worker) { if (!listening_) { return; } auto* const server = worker->getServer(); rocket::RocketServerConnection::Config cfg; cfg.socketWriteTimeout = server->getSocketWriteTimeout(); cfg.streamStarvationTimeout = server->getStreamExpireTime(); cfg.writeBatchingInterval = server->getWriteBatchingInterval(); cfg.writeBatchingSize = server->getWriteBatchingSize(); cfg.writeBatchingByteSize = server->getWriteBatchingByteSize(); cfg.egressBufferBackpressureThreshold = server->getEgressBufferBackpressureThreshold(); cfg.egressBufferBackpressureRecoveryFactor = server->getEgressBufferRecoveryFactor(); cfg.socketOptions = &server->getPerConnectionSocketOptions(); cfg.parserAllocator = server->getCustomAllocatorForParser(); const std::string& securotyProtocol = sock->getSecurityProtocol(); auto* const sockPtr = sock.get(); auto* const connection = new rocket::RocketServerConnection( std::move(sock), std::make_unique( worker, *address, sockPtr, setupFrameHandlers_, setupFrameInterceptors_), worker->getIngressMemoryTracker(), worker->getEgressMemoryTracker(), streamMetricCallback_, cfg); onConnection(*connection); connectionManager->addConnection( connection, THRIFT_FLAG(rocket_set_idle_connection_timeout), /* connectionAgeTimeout */ true); if (auto* observer = server->getObserver()) { observer->connAccepted( tinfo, server::TServerObserver::ConnectionInfo( reinterpret_cast(sockPtr), securotyProtocol)); observer->activeConnections( connectionManager->getNumConnections() * server->getNumIOWorkerThreads()); } } } // namespace thrift } // namespace apache