/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

// NOTE(review): in this copy of the header every `<...>` span appears to have
// been stripped. The seven #include targets below are empty, the
// `template <...>` parameter lists are missing, and so are template arguments
// such as those of std::unique_ptr, static_cast, std::make_unique, Runner and
// Operation. As shown, the code cannot compile (e.g. `std::unique_ptr> ops_;`).
// Restore these from the canonical source; they are deliberately not guessed
// here.
#include
#include
#include
#include
#include
#include
#include

using apache::thrift::ClientConnectionIf;
using apache::thrift::ClientReceiveState;
using apache::thrift::RequestCallback;
using facebook::thrift::benchmarks::QPSStats;

// Forward declaration: Runner names LoadCallback as a friend and creates
// LoadCallback instances in run(), before the full definition below.
template class LoadCallback;

/**
 * Keeps a bounded number of asynchronous operations in flight.
 *
 * Each new operation type is chosen by sampling the owned distribution `d_`
 * with the std::mt19937 engine `gen_`. It is dispatched through
 * `ops_->async()` together with a freshly allocated callback. When a callback
 * completes, it calls finishCall(), which tops the window back up through
 * run(). This continues until loopUntilExit() sets `exiting_`.
 *
 * NOTE(review): no locking or atomics are used here; presumably every method
 * is called on the single EventBase thread that services the requests —
 * confirm.
 */
template class Runner {
 public:
  // NOTE(review): within this file LoadCallback only calls the public
  // finishCall(), so this friendship may be unnecessary — confirm against
  // other users before removing.
  friend class LoadCallback;

  /**
   * @param ops Executes operations and reports how many are outstanding;
   *        ownership is transferred to the Runner.
   * @param distribution Sampled in run() to pick each operation type;
   *        ownership is transferred to the Runner.
   * @param max_outstanding_ops run() stops issuing new operations once
   *        ops_->outstandingOps() reaches this value.
   */
  Runner(
      std::unique_ptr> ops,
      std::unique_ptr> distribution,
      int32_t max_outstanding_ops)
      : ops_(std::move(ops)),
        d_(std::move(distribution)),
        max_outstanding_ops_(max_outstanding_ops) {}

  /**
   * Issues asynchronous operations until ops_->outstandingOps() reaches
   * max_outstanding_ops_ or loopUntilExit() has set exiting_.
   *
   * NOTE(review): if ops_->async() can complete a request synchronously, the
   * callback's finishCall() re-enters run() from inside this loop. Confirm
   * that this cannot recurse deeply.
   */
  void run() {
    // TODO: Implement sync calls.
    while (ops_->outstandingOps() < max_outstanding_ops_ && !exiting_) {
      // Choose the next operation type by sampling the distribution.
      auto op = static_cast((*d_)(gen_));
      // The constructor arguments (runner, ops, op) match LoadCallback's. The
      // callback is handed to ops_->async() and calls back into finishCall()
      // when the request finishes.
      auto cb = std::make_unique>(this, ops_.get(), op);
      ops_->async(op, std::move(cb));
    }
  }

  /**
   * Stops issuing new operations, then runs `evb` one iteration at a time
   * (evb->loopOnce()) until ops_->outstandingOps() returns zero.
   *
   * exiting_ is never cleared anywhere in this class. After this call,
   * run() and finishCall() therefore issue no further operations.
   *
   * @param evb Event base driven while draining. NOTE(review): assumed to be
   *        the event base the outstanding requests complete on — confirm.
   */
  void loopUntilExit(folly::EventBase* evb) {
    exiting_ = true;
    while (ops_->outstandingOps()) {
      evb->loopOnce();
    }
  }

  /**
   * Completion hook called by LoadCallback after a request has been
   * reported: one-way send, reply or error. It refills the in-flight window
   * unless loopUntilExit() has been called.
   */
  void finishCall() {
    if (!exiting_) {
      run(); // Attempt to perform more async calls
    }
  }

 private:
  // Set by loopUntilExit(); stops run() and finishCall() from issuing work.
  bool exiting_{false};
  // Owned operation executor; also the source of outstandingOps().
  std::unique_ptr> ops_;
  // Owned distribution sampled in run() to choose each operation type.
  std::unique_ptr> d_;
  // Upper bound on outstanding operations that run() fills up to.
  int32_t max_outstanding_ops_;
  // Random engine for sampling d_, seeded once from std::random_device.
  std::mt19937 gen_{std::random_device()()};
};

/**
 * Per-request callback created by Runner::run().
 *
 * Reports the outcome of a single operation `op_` to the Operation object and
 * then calls Runner::finishCall(), so the runner can issue more work.
 *
 * `runner_` and `ops_` are raw, non-owning pointers.
 * NOTE(review): both presumably outlive every in-flight callback — confirm.
 */
template class LoadCallback : public RequestCallbackWithValidator {
 public:
  /**
   * @param runner Runner to notify when this request is finished (not owned).
   * @param ops Operation object that receives the result (not owned).
   * @param op Operation type this callback reports on.
   */
  LoadCallback(
      Runner* runner,
      Operation* ops,
      OP_TYPE op)
      : runner_(runner), ops_(ops), op_(op) {}

  // Marks this request as one-way. requestSent() then reports it through
  // ops_->onewaySent() and notifies the runner.
  void setIsOneway() { isOneway_ = true; }

  // TODO: Properly handle errors and exceptions

  // Acts only for one-way requests: reports the send to ops_->onewaySent()
  // and notifies the runner. Does nothing for other requests.
  void requestSent() override {
    if (isOneway_) {
      ops_->onewaySent(op_);
      runner_->finishCall();
    }
  }

  // Runs `validator` on the reply if one is set. Then hands the state to
  // ops_->asyncReceived() and notifies the runner.
  void replyReceived(ClientReceiveState&& rstate) override {
    // `validator` is not declared in this class; presumably it is a member of
    // RequestCallbackWithValidator — confirm.
    if (validator) {
      validator(rstate);
    }
    ops_->asyncReceived(op_, std::move(rstate));
    runner_->finishCall();
  }

  // Reports the failure to ops_->asyncErrorReceived() and then notifies the
  // runner.
  void requestError(ClientReceiveState&& rstate) override {
    ops_->asyncErrorReceived(op_, std::move(rstate));
    runner_->finishCall();
  }

 private:
  // Runner to notify via finishCall(); not owned.
  Runner* runner_;
  // Operation object that receives results; not owned.
  Operation* ops_;
  // Operation type this callback reports on.
  OP_TYPE op_;
  // Set by setIsOneway(); selects requestSent() as the reporting point.
  bool isOneway_{false};
};