/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include #include #include using namespace apache::thrift; namespace { /** * In this test, we are interested in interaction events. */ class TestEventHandler : public TProcessorEventHandler { // An interaction ID must be unique per connection using UniqueInteractionId = std::pair; struct EventHandlerContext { const Cpp2RequestContext* requestContext; folly::SocketAddress peerAddress; explicit EventHandlerContext(const TConnectionContext* connContext) { CHECK(connContext); CHECK(connContext->getPeerAddress()); requestContext = dynamic_cast(connContext); peerAddress = *connContext->getPeerAddress(); } }; public: void* getServiceContext( const char* service_name, const char* fn_name, TConnectionContext* ctx) override { LOG(INFO) << fmt::format( "getServiceContext(\"{}\", \"{}\")", service_name, fn_name); return new EventHandlerContext(ctx); } void freeContext(void* _ctx, const char* fn_name) override { LOG(INFO) << fmt::format("freeContext(\"{}\")", fn_name); ASSERT_TRUE(_ctx); delete static_cast(_ctx); } void onReadData( void* _ctx, const char* fn_name, const SerializedMessage&) override { LOG(INFO) << fmt::format("onReadData(\"{}\")", fn_name); ASSERT_TRUE(_ctx); const auto* ctx = static_cast(_ctx); const auto* req = ctx->requestContext; ASSERT_TRUE(req); ASSERT_FALSE(ctx->peerAddress.empty()); ASSERT_GT(req->getInteractionId(), 0); UniqueInteractionId uniqueId{ctx->peerAddress, req->getInteractionId()}; if (req->getInteractionCreate()) { auto [_, added] = ids_.wlock()->emplace(uniqueId); ASSERT_TRUE(added); } else { ASSERT_TRUE(ids_.rlock()->count(uniqueId)); } if (std::string_view{"Calculator.Addition.noop"} == fn_name) { ASSERT_EQ(req->getRpcKind(), RpcKind::SINGLE_REQUEST_NO_RESPONSE); } else { ASSERT_EQ(req->getRpcKind(), RpcKind::SINGLE_REQUEST_SINGLE_RESPONSE); } } bool wantNonPerRequestCallbacks() const override { return wantNonPerRequestCallbacks_.load(); } void onInteractionTerminate(void* _ctx, int64_t id) override { LOG(INFO) << fmt::format("onInteractionTerminate({})", id); ASSERT_TRUE(_ctx); const auto* ctx = static_cast(_ctx); ASSERT_FALSE(ctx->requestContext); // no request context available here ASSERT_FALSE(ctx->peerAddress.empty()); ASSERT_GT(id, 0); UniqueInteractionId uniqueId{ctx->peerAddress, id}; ASSERT_EQ(1, ids_.wlock()->erase(uniqueId)); } size_t countInteractions() const { return ids_.rlock()->size(); } void setWantNonPerRequestCallbacks(bool val) { wantNonPerRequestCallbacks_.store(val); } private: folly::Synchronized> ids_; std::atomic_bool wantNonPerRequestCallbacks_{true}; }; class TestHandler : public ServiceHandler { public: class Addition : public AdditionIf { public: int32_t sync_getPrimitive() override { return acc; } void sync_accumulatePrimitive(int32_t val) override { acc += val; } private: int32_t acc{0}; }; TileAndResponse::AdditionIf, void> sync_newAddition() override { return {std::make_unique()}; } }; } // namespace TEST(TProcessorEventHandlerTest, BasicInteraction) { auto eventHandler = std::make_shared(); TProcessorBase::addProcessorEventHandler(eventHandler); { ScopedServerInterfaceThread runner(std::make_shared()); auto client = runner.newClient>(); auto add = client->sync_newAddition(); add.sync_accumulatePrimitive(7); EXPECT_EQ(add.sync_getPrimitive(), 7); add.sync_accumulatePrimitive(5); EXPECT_EQ(add.sync_getPrimitive(), 12); } EXPECT_EQ(eventHandler->countInteractions(), 0); eventHandler->setWantNonPerRequestCallbacks(false); { ScopedServerInterfaceThread runner(std::make_shared()); auto client = runner.newClient>(); client->sync_newAddition(); // destruct and trigger interaction termination } EXPECT_EQ(eventHandler->countInteractions(), 1) << "onInteractionTerminate shouldn't be called " "when wantNonPerRequestCallbacks is false"; } TEST(TProcessorEventHandlerTest, MultipleInteractions) { auto eventHandler = std::make_shared(); TProcessorBase::addProcessorEventHandler(eventHandler); { ScopedServerInterfaceThread runner(std::make_shared()); auto client = runner.newClient>(); { auto add = client->sync_newAddition(); add.sync_accumulatePrimitive(7); EXPECT_EQ(add.sync_getPrimitive(), 7); add.sync_accumulatePrimitive(5); EXPECT_EQ(add.sync_getPrimitive(), 12); } { auto add = client->sync_newAddition(); add.sync_accumulatePrimitive(3); EXPECT_EQ(add.sync_getPrimitive(), 3); add.sync_accumulatePrimitive(7); EXPECT_EQ(add.sync_getPrimitive(), 10); } } EXPECT_EQ(eventHandler->countInteractions(), 0); } TEST(TProcessorEventHandlerTest, MultipleConcurrentInteractions) { auto eventHandler = std::make_shared(); TProcessorBase::addProcessorEventHandler(eventHandler); { ScopedServerInterfaceThread runner(std::make_shared()); auto client = runner.newClient>(); /* 1 */ auto add1 = client->sync_newAddition(); add1.sync_accumulatePrimitive(7); EXPECT_EQ(add1.sync_getPrimitive(), 7); add1.sync_accumulatePrimitive(5); EXPECT_EQ(add1.sync_getPrimitive(), 12); /* 2 */ auto add2 = client->sync_newAddition(); add2.sync_accumulatePrimitive(3); EXPECT_EQ(add2.sync_getPrimitive(), 3); add2.sync_accumulatePrimitive(7); EXPECT_EQ(add2.sync_getPrimitive(), 10); EXPECT_EQ(eventHandler->countInteractions(), 2); } EXPECT_EQ(eventHandler->countInteractions(), 0); } TEST(TProcessorEventHandlerTest, ConnectionClose) { using namespace std::chrono; auto eventHandler = std::make_shared(); TProcessorBase::addProcessorEventHandler(eventHandler); // server ScopedServerInterfaceThread runner(std::make_shared()); // client folly::EventBase evb; auto socket = folly::AsyncSocket::newSocket(&evb, runner.getAddress()); auto channel = RocketClientChannel::newChannel(std::move(socket)); auto channelPtr = channel.get(); auto client = std::make_unique>( std::move(channel)); // 1st interaction auto add1 = client->sync_newAddition(); add1.sync_accumulatePrimitive(7); EXPECT_EQ(add1.sync_getPrimitive(), 7); add1.sync_accumulatePrimitive(5); EXPECT_EQ(add1.sync_getPrimitive(), 12); // 2nd interaction auto add2 = client->sync_newAddition(); add2.sync_accumulatePrimitive(3); EXPECT_EQ(add2.sync_getPrimitive(), 3); add2.sync_accumulatePrimitive(7); EXPECT_EQ(add2.sync_getPrimitive(), 10); EXPECT_EQ(eventHandler->countInteractions(), 2); // drop connection to the server channelPtr->closeNow(); // wait for termination events for (auto n = 10; n && eventHandler->countInteractions(); n--) { /* sleep override */ std::this_thread::sleep_for(1s); } EXPECT_EQ(eventHandler->countInteractions(), 0); } TEST(TProcessorEventHandlerTest, RpcKind) { auto eventHandler = std::make_shared(); TProcessorBase::addProcessorEventHandler(eventHandler); { ScopedServerInterfaceThread runner(std::make_shared()); auto client = runner.newClient>(); auto add = client->sync_newAddition(); for (auto n = 100; n--;) { add.sync_noop(); } } EXPECT_EQ(eventHandler->countInteractions(), 0); }