/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Functions only intended to be called from code generated by the Rust Thrift // code generator. use std::ffi::CStr; use std::fmt; use std::fmt::Display; use std::future::Future; use std::pin::Pin; use anyhow::bail; use anyhow::Context; use anyhow::Result; use bytes::Buf; use futures::future::FutureExt; use crate::serialize; use crate::ApplicationException; use crate::BufExt; use crate::ContextStack; use crate::Deserialize; use crate::MessageType; use crate::Protocol; use crate::ProtocolEncodedFinal; use crate::ProtocolReader; use crate::ProtocolWriter; use crate::RequestContext; use crate::ResultInfo; use crate::ResultType; use crate::Serialize; use crate::SerializedMessage; // Note: `variants_by_number` must be sorted by the i32 values. pub fn enum_display( variants_by_number: &[(&str, i32)], formatter: &mut fmt::Formatter, number: i32, ) -> fmt::Result { match variants_by_number.binary_search_by_key(&number, |entry| entry.1) { Ok(i) => variants_by_number[i].0.fmt(formatter), Err(_) => number.fmt(formatter), } } // Note: `variants_by_name` must be sorted by the string values. 
/// Parses `value` as the name of an enum variant and returns that variant's
/// `i32` value, as listed in `variants_by_name`.
///
/// Matching is an exact string comparison. `type_name` is used only to build
/// the error message.
///
/// # Errors
///
/// Returns an error when `value` is not one of the names in
/// `variants_by_name`.
///
/// Precondition: `variants_by_name` must be sorted by name, because the lookup
/// is a binary search. The ordering is not checked here.
// NOTE(review): in this copy the return type reads `Result` with no type
// argument. Given the `Ok(variants_by_name[i].1)` value it is presumably
// `Result<i32>`; confirm against the canonical source.
pub fn enum_from_str(
    variants_by_name: &[(&str, i32)],
    value: &str,
    type_name: &'static str,
) -> Result {
    match variants_by_name.binary_search_by_key(&value, |entry| entry.0) {
        Ok(i) => Ok(variants_by_name[i].1),
        Err(_) => bail!("Unable to parse {} as {}", value, type_name),
    }
}

/// Returns the name of the argument's type, via `std::any::type_name`.
///
/// The argument's value is never read; only its static type matters. Per the
/// `std::any::type_name` docs, the string is meant for diagnostics and its
/// exact format is not guaranteed.
// NOTE(review): `T` is used without being declared, and the turbofish below is
// empty in this copy. It is presumably `type_name_of_val<T>` and
// `type_name::<T>()`; confirm.
pub fn type_name_of_val(_: &T) -> &'static str {
    std::any::type_name::()
}

/// Returns `b.remaining()` as a `u32`. Assuming `B: Buf`, that is the count of
/// bytes left in the buffer.
///
/// # Errors
///
/// Fails if the remaining byte count does not fit in a `u32`. The error's
/// context message includes the length.
// NOTE(review): `B` is undeclared and the `anyhow::Result` type argument is
// missing in this copy. It is presumably `buf_len<B: Buf>` returning
// `anyhow::Result<u32>`, since the body returns `Ok(length)` with
// `length: u32`.
pub fn buf_len(b: &B) -> anyhow::Result {
    let length: usize = b.remaining();
    // Checked narrowing: a `usize` can exceed `u32::MAX` on 64-bit targets.
    let length: u32 = length.try_into().with_context(|| {
        format!("Unable to report a buffer length of {length} bytes as a `u32`")
    })?;
    Ok(length)
}

/// Serialize a result as encoded into a generated *Exn type, wrapped in an envelope.
///
/// In order, this function:
/// 1. For `ResultType::Error` / `ResultType::Exception` results, asserts that
///    `res.exn_is_declared()` is true exactly when the type is `Error`. It then
///    calls `rctxt.set_user_exception_header` with the exception name and value.
/// 2. Calls `ctx_stack.pre_write()`.
/// 3. Serializes an envelope with protocol `P`: message-begin (with `name`,
///    the result type's message type, and `seqid`), then `res`, then
///    message-end.
/// 4. Passes a copy of the serialized buffer to `ctx_stack.on_write_data`,
///    then passes its length to `ctx_stack.post_write`.
///
/// # Errors
///
/// Propagates failures from `set_user_exception_header`, the `ctx_stack`
/// hooks, and `buf_len`.
///
/// # Panics
///
/// Panics if, for an `Error`/`Exception` result, `exn_is_declared()`
/// disagrees with `res_type == ResultType::Error`.
// NOTE(review): several pieces appear stripped in this copy:
// - the generic parameters (`P`, `CTXT`, `RES`);
// - the qualified type before `::Name`;
// - the `anyhow::Result` type argument;
// - the type arguments of the `Serialize`, `RequestContext`, `ContextStack`
//   and (presumably) `ProtocolEncodedFinal` bounds.
// They cannot be inferred from here; restore them from the canonical source.
pub fn serialize_result_envelope(
    name: &str,
    name_cstr: &::Name,
    seqid: u32,
    rctxt: &CTXT,
    ctx_stack: &mut CTXT::ContextStack,
    res: RES,
) -> anyhow::Result>
where
    P: Protocol,
    RES: ResultInfo + Serialize + Serialize,
    CTXT: RequestContext,
    ::ContextStack: ContextStack,
    ProtocolEncodedFinal: Clone + Buf + BufExt,
{
    let res_type = res.result_type();

    // Only error/exception results report an exception header. A declared
    // exception must be classified as `Error`, an undeclared one as
    // `Exception`.
    if matches!(res_type, ResultType::Error | ResultType::Exception) {
        assert_eq!(res.exn_is_declared(), res_type == ResultType::Error);
        rctxt.set_user_exception_header(res.exn_name(), &res.exn_value())?;
    }

    ctx_stack.pre_write()?;
    let envelope = serialize!(P, |p| {
        p.write_message_begin(name, res_type.message_type(), seqid);
        res.write(p);
        p.write_message_end();
    });

    // The `SerializedMessage` handed to `on_write_data` owns its buffer, so it
    // gets a clone. `envelope` itself is kept to be measured and returned.
    ctx_stack.on_write_data(SerializedMessage {
        protocol: P::PROTOCOL_ID,
        method_name: name_cstr,
        buffer: envelope.clone(),
    })?;
    let bytes_written = buf_len(&envelope)?;
    ctx_stack.post_write(bytes_written)?;

    Ok(envelope)
}

/// Serializes a single stream item `res` with protocol `P`.
///
/// Unlike `serialize_result_envelope`, this writes no message envelope (only
/// `res.write`) and invokes no request-context or context-stack hooks. No
/// error path is visible in this function itself.
// NOTE(review): `P`/`RES` are undeclared and type arguments are stripped in
// this copy (`anyhow::Result>`, `Serialize + Serialize`); restore them from
// the canonical source.
pub fn serialize_stream_item(res: RES) -> anyhow::Result>
where
    P: Protocol,
    RES: ResultInfo + Serialize + Serialize,
{
    Ok(serialize!(P, |p| {
        res.write(p);
    }))
}

/// Serialize a request with envelope.
///
/// Writes, with protocol `P`, a `MessageType::Call` message named `name` with
/// sequence id 0 whose body is `args`. No error path is visible in this
/// function itself.
// NOTE(review): `P`/`ARGS` are undeclared and type arguments are stripped in
// this copy (`anyhow::Result>`, `Serialize + Serialize`); restore them from
// the canonical source.
pub fn serialize_request_envelope(
    name: &str,
    args: &ARGS,
) -> anyhow::Result>
where
    P: Protocol,
    ARGS: Serialize + Serialize,
{
    let envelope = serialize!(P, |p| {
        // Note: we send a 0 message sequence ID from clients because
        // this field should not be used by the server (except for some
        // language implementations).
        p.write_message_begin(name, MessageType::Call, 0);
        args.write(p);
        p.write_message_end();
    });
    Ok(envelope)
}

/// Deserialize a client response. This deserializes the envelope then
/// deserializes either a reply or an ApplicationException.
///
/// The outer `anyhow::Result` reports read failures and unexpected message
/// types. On success, the inner value is `Ok(reply)` for a `MessageType::Reply`
/// message and `Err(exception)` for a `MessageType::Exception` message.
///
/// # Errors
///
/// Fails in either of these cases:
/// - reading the envelope, the body, or the message end fails;
/// - the message type is `Call`, `Oneway` or `InvalidMessageType`. In that
///   case the body and the message end are not read.
// NOTE(review): `P`/`T` are undeclared, and the `anyhow::Result` and
// `Deserialize` type arguments are stripped in this copy; restore them from
// the canonical source.
pub fn deserialize_response_envelope(
    de: &mut P::Deserializer,
) -> anyhow::Result>
where
    P: Protocol,
    T: Deserialize,
{
    // Only the message type is used. The other two components returned by
    // `read_message_begin` are discarded.
    let (_, message_type, _) = de.read_message_begin(|_| ())?;
    let res = match message_type {
        MessageType::Reply => Ok(T::read(de)?),
        MessageType::Exception => Err(ApplicationException::read(de)?),
        // Every other message type is rejected as a response.
        MessageType::Call | MessageType::Oneway | MessageType::InvalidMessageType => {
            bail!("Unwanted message type `{:?}`", message_type)
        }
    };
    de.read_message_end()?;
    Ok(res)
}

/// Abstract spawning some potentially CPU-heavy work onto a CPU thread.
///
/// `spawn` is an associated function with no receiver. An implementation is
/// therefore selected by type rather than by value, e.g. via the `S: Spawner`
/// parameter of `async_deserialize_response_envelope`.
// NOTE(review): the generic parameter list of `spawn` and the return type's
// arguments are stripped in this copy (`Pin + Send>>`). Presumably this is
// `spawn<F, R>` returning `Pin<Box<dyn Future<Output = R> + Send>>`; confirm.
pub trait Spawner: 'static {
    fn spawn(func: F) -> Pin + Send>>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static;
}

/// No-op implementation of Spawner - just run on current thread.
///
/// `spawn` does not call `func` itself. It returns an `async` block, so `func`
/// runs synchronously the first time the returned future is polled, on
/// whichever thread polls it. Heavy work in `func` therefore blocks that
/// thread.
pub struct NoopSpawner;

impl Spawner for NoopSpawner {
    #[inline]
    fn spawn(func: F) -> Pin + Send>>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        // Defer the call into a lazy future; `.boxed()` (from `FutureExt`)
        // erases the async block's anonymous type.
        async { func() }.boxed()
    }
}

/// Runs `deserialize_response_envelope` through the spawner `S`.
///
/// The deserializer is moved into the spawned closure and handed back
/// alongside the decoded response on success.
///
/// # Errors
///
/// Same as `deserialize_response_envelope`. On error, the deserializer is
/// dropped inside the closure rather than returned.
// NOTE(review): several pieces are stripped in this copy: the generic
// parameters (`P`, `T`, `S`), the inner `Result`'s type arguments, the
// `Deserialize` bound's arguments, and the turbofish on
// `deserialize_response_envelope::`. Restore them from the canonical source.
pub async fn async_deserialize_response_envelope(
    de: P::Deserializer,
) -> anyhow::Result<(Result, P::Deserializer)>
where
    P: Protocol,
    P::Deserializer: Send,
    T: Deserialize + Send + 'static,
    S: Spawner,
{
    S::spawn(move || {
        // Rebind as mutable: the closure owns `de` and needs `&mut` access.
        let mut de = de;
        let res = deserialize_response_envelope::(&mut de);
        res.map(|res| (res, de))
    })
    .await
}