/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use std::any::TypeId; use std::collections::HashMap; use std::collections::HashSet; use anyhow::anyhow; use anyhow::bail; use anyhow::ensure; use anyhow::Result; use fbthrift::binary_protocol; use fbthrift::compact_protocol; use fbthrift::simplejson_protocol; use fbthrift::GetUri; use protocol::StandardProtocol; use crate::universal_name::ensure_registered; use crate::universal_name::get_universal_hash; use crate::universal_name::get_universal_hash_prefix_sha_256; use crate::universal_name::UniversalHashAlgorithm; const UNIVERSAL_HASH_PREFIX_SHA_256_LEN: i8 = 16; pub struct AnyRegistry { uri_to_typeid: HashMap<&'static str, TypeId>, typeid_to_uri: HashMap, alg_to_hashes: HashMap>>, } pub trait SerializeRef = binary_protocol::SerializeRef + compact_protocol::SerializeRef + simplejson_protocol::SerializeRef; pub trait DeserializeSlice = binary_protocol::DeserializeSlice + compact_protocol::DeserializeSlice + simplejson_protocol::DeserializeSlice; impl AnyRegistry { pub fn new() -> Self { let mut alg_to_hashes = HashMap::new(); alg_to_hashes.insert(UniversalHashAlgorithm::Sha2_256, HashSet::new()); Self { uri_to_typeid: HashMap::new(), typeid_to_uri: HashMap::new(), alg_to_hashes, } } pub fn register_type(&mut self) -> Result { let uri = T::uri(); let type_id = TypeId::of::(); if self.uri_to_typeid.contains_key(uri) || self.typeid_to_uri.contains_key(&type_id) { return Ok(false); } self.uri_to_typeid.insert(uri, type_id); self.typeid_to_uri.insert(type_id, uri); for (alg, hashes) in self.alg_to_hashes.iter_mut() { let hash = get_universal_hash(*alg, uri)?; hashes.insert(hash); } Ok(true) } pub fn has_type(&self, obj: &any::Any) -> Result { let type_uri = T::uri(); let type_hash_prefix_sha2_256 = get_universal_hash_prefix_sha_256(type_uri, UNIVERSAL_HASH_PREFIX_SHA_256_LEN)?; if let Some(obj_hash_prefix_sha2_256) = &obj.typeHashPrefixSha2_256 { Ok(type_hash_prefix_sha2_256 == *obj_hash_prefix_sha2_256) } else if let Some(obj_uri) = obj.r#type.as_ref() { Ok(type_uri == obj_uri) } else { bail!("No type information found"); } } pub fn store( &mut self, obj: &T, protocol: StandardProtocol, ) -> Result { let type_id = TypeId::of::(); let uri = self .typeid_to_uri .get(&type_id) .ok_or_else(|| anyhow!("Type {:?} not registered", type_id))?; let hash_prefix = get_universal_hash_prefix_sha_256(uri, UNIVERSAL_HASH_PREFIX_SHA_256_LEN)?; Ok(any::Any { r#type: Some(uri.to_string()), typeHashPrefixSha2_256: Some(hash_prefix), protocol: Some(protocol), data: serialize(obj, protocol)?, ..Default::default() }) } pub fn load(&self, obj: &any::Any) -> Result { if let Some(type_hash_prefix_sha2_256) = &obj.typeHashPrefixSha2_256 { let hashes = self .alg_to_hashes .get(&UniversalHashAlgorithm::Sha2_256) .ok_or_else(|| anyhow!("Hash algorithm Sha2_256 not supported"))?; ensure_registered(hashes, type_hash_prefix_sha2_256)?; } else if let Some(uri) = obj.r#type.as_ref() { ensure!( self.uri_to_typeid.contains_key(uri.as_str()), "Type '{}' not registered", uri ); } else { bail!("No type information found"); } deserialize(&obj.data, obj.protocol.unwrap_or(StandardProtocol::Compact)) } } fn serialize(obj: &T, protocol: StandardProtocol) -> Result> { match protocol { StandardProtocol::Binary => Ok(binary_protocol::serialize(obj).to_vec()), StandardProtocol::Compact => Ok(compact_protocol::serialize(obj).to_vec()), StandardProtocol::SimpleJson => Ok(simplejson_protocol::serialize(obj).to_vec()), unsupported => Err(anyhow!("Unsupported protocol: {:?}", unsupported)), } } fn deserialize(data: &[u8], protocol: StandardProtocol) -> Result { match protocol { StandardProtocol::Binary => binary_protocol::deserialize(data), StandardProtocol::Compact => compact_protocol::deserialize(data), StandardProtocol::SimpleJson => simplejson_protocol::deserialize(data), unsupported => Err(anyhow!("Unsupported protocol: {:?}", unsupported)), } } #[cfg(test)] mod tests { use anyhow::Result; use maplit::btreemap; use testset::struct_map_string_i32; use super::*; fn get_test_object() -> struct_map_string_i32 { struct_map_string_i32 { field_1: btreemap! {"Answer to the Ultimate Question of Life, the Universe, and Everything.".to_owned() => 42}, ..Default::default() } } fn test_serialize_round_trip_for_protocol(protocol: StandardProtocol) -> Result<()> { let original = get_test_object(); let serialized = serialize(&original, protocol)?; let deserialized = deserialize(&serialized, protocol)?; assert_eq!(original, deserialized); Ok(()) } #[test] fn test_round_trip() -> Result<()> { let protocols = vec![ StandardProtocol::Binary, StandardProtocol::Compact, StandardProtocol::SimpleJson, ]; for protocol in protocols { test_serialize_round_trip_for_protocol(protocol)?; } Ok(()) } #[test] fn test_unsupported_protocol() { let test_object = get_test_object(); assert!(serialize(&test_object, StandardProtocol::Custom).is_err()); assert!(serialize(&test_object, StandardProtocol::Json).is_err()); assert!(deserialize::(b"", StandardProtocol::Custom).is_err()); assert!(deserialize::(b"", StandardProtocol::Json).is_err()); } }