/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

/**
 * Specializations of `protocol_methods` encapsulate a collection of
 * read / readWithContext / write / serializedSize methods that can be
 * performed on Thrift objects and primitives. serializedSize additionally
 * takes a compile-time ZeroCopy flag (the struct specialization, for example,
 * uses it to choose between Cpp2Ops::serializedSizeZC and
 * Cpp2Ops::serializedSize).
 *
 * TypeClass (see apache::thrift::type_class) refers to the general type of
 * data that Type is, and is passed around for two reasons:
 *  - to provide support for generic containers which have a common interface
 *    for building collections (e.g. a `std::vector<T>` and a `std::deque<T>`
 *    can both back a Thrift list, and thus have
 *    `type_class::list<TypeClass>`, while a `std::map<K, V>` would have
 *    `type_class::map<KeyTypeClass, ValueTypeClass>`).
 *  - to differentiate between Thrift types that are represented with the
 *    same C++ type: e.g. both Thrift binaries and strings are represented
 *    with a `std::string`, so TypeClass is used to decide which protocol
 *    method to call.
 *
 * Every method takes the protocol as its first argument and the value as its
 * second.
 *
 * Example:
 *
 *   // MyModule.thrift:
 *   struct MyStruct {
 *     1: list<set<string>> fieldA
 *   }
 *
 *   // C++
 *
 *   using methods = protocol_methods<
 *       type_class::list<type_class::set<type_class::string>>,
 *       std::vector<std::set<std::string>>>;
 *
 *   MyStruct struct_instance;
 *   CompactProtocolReader reader;
 *   methods::read(reader, struct_instance.fieldA);
 */

namespace apache {
namespace thrift {
namespace type {
namespace detail {
// Forward declaration; the struct specialization at the bottom of this file
// accepts values wrapped in this type and unwraps them via toThrift().
template class Wrap;
}
} // namespace type
namespace detail {
namespace pm {

// Detection alias: well-formed when `C` has a `reserve` member callable with
// arguments of types `A...`.
template
using detect_reserve =
    decltype(FOLLY_DECLVAL(C).reserve(FOLLY_DECLVAL(A)...));

// Calls `t->reserve(size)` when the container supports it (the detection
// presumably goes through detect_reserve above). Returns std::true_type when
// the reservation was made and std::false_type otherwise, so the outcome is
// visible at compile time.
template
auto reserve_if_possible(Container* t, Size size) {
  if constexpr (folly::is_detected_v) {
    t->reserve(size);
    return std::true_type{};
  } else {
    return std::false_type{};
  }
}

// Appends one element, constructed from `detail::default_set_element(c)`, to
// `c` and returns a reference to it so the caller can deserialize into it in
// place.
template
typename Container::reference emplace_back_default(Container& c) {
  return c.emplace_back(detail::default_set_element(c));
}

// Map counterpart of emplace_back_default: appends an element built from
// `detail::default_map_key(m)` and `detail::default_map_value(m)` to `c` and
// returns a reference to it. In this file `c` is the map's temporary
// `container_type` (see the first deserialize_known_length_map overload).
template
typename Container::reference emplace_back_default_map(Container& c, Map& m) {
  return c.emplace_back(
      detail::default_map_key(m), detail::default_map_value(m));
}

// Reads one key/value pair with `kr` / `mr` and inserts it into `m`.
// This overload builds the key and value from detail::default_map_key /
// detail::default_map_value (presumably so they pick up the map's allocator,
// cf. the "without allocator awareness" comment in the other overload) and
// then emplaces them.
// NOTE(review): for std::map-like containers, emplace keeps the existing
// entry on a duplicate key, i.e. the *first* value read wins here.
template
std::enable_if_t> deserialize_key_val_into_map(
    Map& m, const KeyDeserializer& kr, const MappedDeserializer& mr) {
  typename Map::key_type key = detail::default_map_key(m);
  typename Map::mapped_type value = detail::default_map_value(m);
  kr(key);
  mr(value);
  m.emplace(std::move(key), std::move(value));
}

// Alternative overload (its enable_if condition is presumably the negation of
// the one above): reads the key into a default-constructed `key_type`, then
// deserializes the mapped value directly into `m[key]`.
// NOTE(review): for std::map-like containers a duplicate key overwrites the
// mapped value, i.e. the *last* value read wins here, unlike the overload
// above.
template
std::enable_if_t> deserialize_key_val_into_map(
    Map& m, const KeyDeserializer& kr, const MappedDeserializer& mr) {
  typename Map::key_type key; // Create key/val without allocator awareness.
  kr(key);
  mr(m[std::move(key)]);
}

// sorted_unique_constructible_v<T> is true when T can be constructed both
// from `(folly::sorted_unique, typename T::container_type)` and from a plain
// `typename T::container_type` (presumably folly's sorted-vector style
// containers).
template
inline constexpr bool sorted_unique_constructible_ = false;
template
inline constexpr bool sorted_unique_constructible_<
    folly::void_t<
        decltype(T(folly::sorted_unique, typename T::container_type())),
        decltype(T(typename T::container_type()))>,
    T> = true;
template
inline constexpr bool sorted_unique_constructible_v =
    sorted_unique_constructible_;

// Invoker for the member function `emplace_hint`, used by the traits below.
FOLLY_CREATE_MEMBER_INVOKER(emplace_hint_invoker, emplace_hint);

// Detection alias: well-formed when T declares a nested `key_compare` type.
template
using detect_key_compare = typename T::key_compare;

// True when `t.emplace_hint(iterator, key_type, mapped_type)` is invocable.
template
constexpr bool map_emplace_hint_is_invocable_v = folly::is_invocable_v<
    emplace_hint_invoker,
    T,
    typename T::iterator,
    typename T::key_type,
    typename T::mapped_type>;

// True when `t.emplace_hint(iterator, value_type)` is invocable.
template
constexpr bool set_emplace_hint_is_invocable_v = folly::is_invocable_v<
    emplace_hint_invoker,
    T,
    typename T::iterator,
    typename T::value_type>;

// deserialize_known_length_map: reads exactly `map_size` entries into `map`.
// Three overloads, selected by the traits above:
//   1. sorted_unique-constructible maps: read everything into a temporary
//      `container_type`, then build the map in one shot.
//   2. otherwise, if emplace_hint(iterator, key, mapped) is available:
//      insert each entry with an end() hint.
//   3. otherwise: insert each entry via deserialize_key_val_into_map.
//
// Overload 1: entries are appended to `tmp` in wire order while tracking
// whether each key compares strictly less (under map.key_comp()) than the
// next one. If so, the map is built with folly::sorted_unique, presumably so
// it can adopt `tmp` without re-sorting; otherwise the plain container
// constructor is used (presumably responsible for sorting/de-duplicating).
// An empty input returns early and leaves `map` untouched; a non-empty input
// replaces the previous contents of `map` via assignment.
template
typename std::enable_if>::type
deserialize_known_length_map(
    Map& map,
    std::uint32_t map_size,
    const KeyDeserializer& kr,
    const MappedDeserializer& mr) {
  if (map_size == 0) {
    return;
  }

  bool sorted = true;
  typename Map::container_type tmp(map.get_allocator());
  reserve_if_possible(&tmp, map_size);
  {
    decltype(auto) elem0 = emplace_back_default_map(tmp, map);
    kr(elem0.first);
    mr(elem0.second);
  }
  for (size_t i = 1; i < map_size; ++i) {
    decltype(auto) elem = emplace_back_default_map(tmp, map);
    kr(elem.first);
    mr(elem.second);
    // Strict comparison: equal adjacent keys also clear `sorted`.
    sorted = sorted && map.key_comp()(tmp[i - 1].first, elem.first);
  }

  using folly::sorted_unique;
  map = sorted ? Map(sorted_unique, std::move(tmp)) : Map(std::move(tmp));
}

// Overload 2: inserts each entry with map.end() as the hint (presumably
// chosen because serialized keys often arrive in ascending order). Existing
// contents of `map` are kept.
template
typename std::enable_if<
    !sorted_unique_constructible_v &&
    map_emplace_hint_is_invocable_v>::type
deserialize_known_length_map(
    Map& map,
    std::uint32_t map_size,
    const KeyDeserializer& kr,
    const MappedDeserializer& mr) {
  reserve_if_possible(&map, map_size);

  // Runs exactly map_size times.
  for (auto i = map_size; i--;) {
    typename Map::key_type key = detail::default_map_key(map);
    typename Map::mapped_type value = detail::default_map_value(map);
    kr(key);
    mr(value);
    map.emplace_hint(map.end(), std::move(key), std::move(value));
  }
}

// Overload 3: fallback for maps with neither sorted_unique construction nor
// a usable emplace_hint.
template
typename std::enable_if<
    !sorted_unique_constructible_v &&
    !map_emplace_hint_is_invocable_v>::type
deserialize_known_length_map(
    Map& map,
    std::uint32_t map_size,
    const KeyDeserializer& kr,
    const MappedDeserializer& mr) {
  reserve_if_possible(&map, map_size);

  for (auto i = map_size; i--;) {
    deserialize_key_val_into_map(map, kr, mr);
  }
}

// deserialize_known_length_set: reads exactly `set_size` elements into `set`,
// with the same three-way overload split as deserialize_known_length_map.
//
// Overload 1 (sorted_unique-constructible sets): elements are appended to a
// temporary `container_type`; if every element compares strictly less than
// the next under set.key_comp(), the set is built with folly::sorted_unique,
// otherwise from the plain container. An empty input leaves `set` untouched;
// a non-empty input replaces its previous contents via assignment.
template
typename std::enable_if>::type
deserialize_known_length_set(
    Set& set, std::uint32_t set_size, const ValDeserializer& vr) {
  if (set_size == 0) {
    return;
  }

  bool sorted = true;
  typename Set::container_type tmp(set.get_allocator());
  reserve_if_possible(&tmp, set_size);
  {
    auto& elem0 = emplace_back_default(tmp);
    vr(elem0);
  }
  for (size_t i = 1; i < set_size; ++i) {
    auto& elem = emplace_back_default(tmp);
    vr(elem);
    // Strict comparison: equal adjacent elements also clear `sorted`.
    sorted = sorted && set.key_comp()(tmp[i - 1], elem);
  }

  using folly::sorted_unique;
  set = sorted ? Set(sorted_unique, std::move(tmp)) : Set(std::move(tmp));
}

// Overload 2: inserts each element with set.end() as the hint. Existing
// contents of `set` are kept.
template
typename std::enable_if<
    !sorted_unique_constructible_v &&
    set_emplace_hint_is_invocable_v>::type
deserialize_known_length_set(
    Set& set, std::uint32_t set_size, const ValDeserializer& vr) {
  reserve_if_possible(&set, set_size);

  // Runs exactly set_size times.
  for (auto i = set_size; i--;) {
    typename Set::value_type value = detail::default_set_element(set);
    vr(value);
    set.emplace_hint(set.end(), std::move(value));
  }
}

// Overload 3: fallback using plain insert().
template
typename std::enable_if<
    !sorted_unique_constructible_v &&
    !set_emplace_hint_is_invocable_v>::type
deserialize_known_length_set(
    Set& set, std::uint32_t set_size, const ValDeserializer& vr) {
  reserve_if_possible(&set, set_size);

  for (auto i = set_size; i--;) {
    typename Set::value_type value = detail::default_set_element(set);
    vr(value);
    set.insert(std::move(value));
  }
}

// Narrows a container size to the uint32_t used on the wire. When `size`
// exceeds the limit (presumably the uint32_t maximum), it calls
// TProtocolException::throwExceededSizeLimit (which presumably throws)
// instead of silently truncating.
inline uint32_t checked_container_size(size_t size) {
  const size_t limit = std::numeric_limits::max();
  if (size > limit) {
    TProtocolException::throwExceededSizeLimit(size, limit);
  }
  return static_cast(size);
}

/*
 * Primitive Types Specialization
 */

// Primary template: declared only; every supported TypeClass/Type pair gets a
// specialization below.
template struct protocol_methods;

// Defines read / readWithContext / write for a primitive `Type` by forwarding
// to protocol.read##Method, protocol.read##Method##WithContext and
// protocol.write##Method. readWithContext passes the context through only
// when Context::kAcceptsContext is set and otherwise falls back to the plain
// read. For the types named in the std::is_same checks, write() routes the
// value returned by the protocol through checked_container_size.
#define THRIFT_PROTOCOL_METHODS_REGISTER_RW_COMMON(Class, Type, Method) \
  template \
  static void read(Protocol& protocol, Type& out) { \
    protocol.read##Method(out); \
  } \
  template \
  static void readWithContext(Protocol& protocol, Type& out, Context& ctx) { \
    if constexpr (Context::kAcceptsContext) { \
      protocol.read##Method##WithContext(out, ctx); \
    } else { \
      protocol.read##Method(out); \
    } \
  } \
  template \
  static std::size_t write(Protocol& protocol, const Type& in) { \
    if constexpr ( \
        std::is_same::value || \
        std::is_same::value) { \
      return checked_container_size(protocol.write##Method(in)); \
    } else { \
      return protocol.write##Method(in); \
    } \
  }

// Defines serializedSize for a primitive `Type` by forwarding to
// protocol.serializedSize##Method.
#define THRIFT_PROTOCOL_METHODS_REGISTER_SS_COMMON(Class, Type, Method) \
  template \
  static std::size_t serializedSize(Protocol& protocol, const Type& in) { \
    return protocol.serializedSize##Method(in); \
  }

// stamp out specializations for primitive types
#define THRIFT_PROTOCOL_METHODS_REGISTER_OVERLOAD(Class, Type, Method) \
  template <> \
  struct protocol_methods { \
    THRIFT_PROTOCOL_METHODS_REGISTER_RW_COMMON(Class, Type, Method) \
    THRIFT_PROTOCOL_METHODS_REGISTER_SS_COMMON(Class, Type, Method) \
  }

THRIFT_PROTOCOL_METHODS_REGISTER_OVERLOAD(integral, std::int8_t, Byte);
THRIFT_PROTOCOL_METHODS_REGISTER_OVERLOAD(integral, std::int16_t, I16);
THRIFT_PROTOCOL_METHODS_REGISTER_OVERLOAD(integral, std::int32_t, I32);
THRIFT_PROTOCOL_METHODS_REGISTER_OVERLOAD(integral, std::int64_t, I64);

// Macros for defining protocol_methods for unsigned integers.
// Special macros are needed because of the casts: values are read into the
// corresponding signed type (std::make_signed_t) and converted with
// folly::to_unsigned; write and serializedSize convert with folly::to_signed
// first.
#define THRIFT_PROTOCOL_METHODS_REGISTER_RW_UI(Class, Type, Method) \
  using SignedType = std::make_signed_t; \
  template \
  static void read(Protocol& protocol, Type& out) { \
    SignedType tmp; \
    protocol.read##Method(tmp); \
    out = folly::to_unsigned(tmp); \
  } \
  template \
  static void readWithContext(Protocol& protocol, Type& out, Context& ctx) { \
    SignedType tmp; \
    if constexpr (Context::kAcceptsContext) { \
      protocol.read##Method##WithContext(tmp, ctx); \
    } else { \
      protocol.read##Method(tmp); \
    } \
    out = folly::to_unsigned(tmp); \
  } \
  template \
  static std::size_t write(Protocol& protocol, const Type& in) { \
    return protocol.write##Method(folly::to_signed(in)); \
  }

#define THRIFT_PROTOCOL_METHODS_REGISTER_SS_UI(Class, Type, Method) \
  template \
  static std::size_t serializedSize(Protocol& protocol, const Type& in) { \
    return protocol.serializedSize##Method(folly::to_signed(in)); \
  }

// stamp out specializations for unsigned integer primitive types
#define THRIFT_PROTOCOL_METHODS_REGISTER_UI(Class, Type, Method) \
  template <> \
  struct protocol_methods { \
    THRIFT_PROTOCOL_METHODS_REGISTER_RW_UI(Class, Type, Method) \
    THRIFT_PROTOCOL_METHODS_REGISTER_SS_UI(Class, Type, Method) \
  }

THRIFT_PROTOCOL_METHODS_REGISTER_UI(integral, std::uint8_t, Byte);
THRIFT_PROTOCOL_METHODS_REGISTER_UI(integral, std::uint16_t, I16);
THRIFT_PROTOCOL_METHODS_REGISTER_UI(integral, std::uint32_t, I32);
THRIFT_PROTOCOL_METHODS_REGISTER_UI(integral, std::uint64_t, I64);

#undef THRIFT_PROTOCOL_METHODS_REGISTER_UI
#undef THRIFT_PROTOCOL_METHODS_REGISTER_RW_UI
#undef THRIFT_PROTOCOL_METHODS_REGISTER_SS_UI

// std::vector<bool> isn't actually a container: its elements are proxy
// objects rather than `bool&`. So, besides the regular bool methods, define a
// special read() overload that takes the specialized proxy type, reads into a
// temporary bool and assigns it through the proxy.
template <>
struct protocol_methods {
  THRIFT_PROTOCOL_METHODS_REGISTER_RW_COMMON(integral, bool, Bool)
  THRIFT_PROTOCOL_METHODS_REGISTER_SS_COMMON(integral, bool, Bool)

  template
  static void read(Protocol& protocol, std::vector::reference out) {
    bool tmp;
    read(protocol, tmp);
    out = tmp;
  }
};

THRIFT_PROTOCOL_METHODS_REGISTER_OVERLOAD(floating_point, double, Double);
THRIFT_PROTOCOL_METHODS_REGISTER_OVERLOAD(floating_point, float, Float);

#undef THRIFT_PROTOCOL_METHODS_REGISTER_OVERLOAD

// Thrift strings: forwards to readString / readStringWithContext /
// writeString / serializedSizeString via the common macros.
template
struct protocol_methods {
  THRIFT_PROTOCOL_METHODS_REGISTER_RW_COMMON(string, Type, String)
  THRIFT_PROTOCOL_METHODS_REGISTER_SS_COMMON(string, Type, String)
};

// Thrift binaries: read/write forward to readBinary / writeBinary via the
// common macro; serializedSize is selected by enable_if (presumably on the
// ZeroCopy flag) between serializedSizeZCBinary and serializedSizeBinary.
template
struct protocol_methods {
  THRIFT_PROTOCOL_METHODS_REGISTER_RW_COMMON(binary, Type, Binary)

  template
  static typename std::enable_if::type serializedSize(
      Protocol& protocol, const Type& in) {
    return protocol.serializedSizeZCBinary(in);
  }
  template
  static typename std::enable_if::type serializedSize(
      Protocol& protocol, const Type& in) {
    return protocol.serializedSizeBinary(in);
  }
};

#undef THRIFT_PROTOCOL_METHODS_REGISTER_SS_COMMON
#undef THRIFT_PROTOCOL_METHODS_REGISTER_RW_COMMON

/*
 * Enum Specialization
 *
 * enum_protocol_methods is the shared implementation: values are converted
 * with static_cast to/from `int_type` and (de)serialized with the integral
 * protocol_methods (`int_methods`). The static_assert rejects non-enum types.
 */
template >
struct enum_protocol_methods {
  static_assert(std::is_enum::value, "must be enum");

  using int_methods = protocol_methods;

  template
  static void read(Protocol& protocol, Type& out) {
    int_type tmp;
    int_methods::read(protocol, tmp);
    out = static_cast(tmp);
  }

  template
  static void readWithContext(Protocol& protocol, Type& out, Context& ctx) {
    int_type tmp;
    int_methods::readWithContext(protocol, tmp, ctx);
    out = static_cast(tmp);
  }

  template
  static std::size_t write(Protocol& protocol, const Type& in) {
    int_type tmp = static_cast(in);
    return int_methods::template write(protocol, tmp);
  }

  template
  static std::size_t serializedSize(Protocol& protocol, const Type& in) {
    int_type tmp = static_cast(in);
    return int_methods::template serializedSize(protocol, tmp);
  }
};

// Thrift enums are always read as int32_t
template
struct protocol_methods : enum_protocol_methods {};

// Strong integral types keep their precision.
// (Enum-typed values classified as type_class::integral reuse
// enum_protocol_methods; the enable_if condition presumably tests
// std::is_enum.)
template
struct protocol_methods<
    type_class::integral,
    Type,
    std::enable_if_t::value>> : enum_protocol_methods {
};

/*
 * List Specialization
 *
 * read(): if the protocol omits container sizes, elements are read until
 * peekList() returns false. Otherwise, if the element wire type reported by
 * the protocol differs from the expected one, the `list_size` elements are
 * skipped with skip_n and nothing is added to `out`. If the types match,
 * canReadNElements is checked first and throwTruncatedData is raised when it
 * fails. Elements are appended to `out`; existing elements are not cleared.
 */
template
struct protocol_methods, Type> {
  static_assert(
      !std::is_same(),
      "Unable to serialize unknown list element");

  using elem_type = folly::remove_cvref_t;
  using elem_methods = protocol_methods;
  using elem_ttype = protocol_type;

 private:
  // Reads one element and appends it to `out`. When the std::is_const_v
  // condition holds (presumably: the container's elements cannot be
  // deserialized in place), the element is built in a local and handed to
  // emplace_back via folly::invocable_to. Otherwise a default element is
  // appended with emplace_back_default and read in place.
  template
  FOLLY_ERASE static void read_one(Protocol& protocol, Type& out) {
    if constexpr ( //
        std::is_const_v>) {
      out.emplace_back(folly::invocable_to([&] {
        elem_type elem;
        elem_methods::read(protocol, elem);
        return elem;
      }));
    } else {
      elem_methods::read(protocol, emplace_back_default(out));
    }
  }

 public:
  template
  static void read(Protocol& protocol, Type& out) {
    std::uint32_t list_size = -1;
    using WireTypeInfo = ProtocolReaderWireTypeInfo;
    using WireType = typename WireTypeInfo::WireType;

    WireType reported_type = WireTypeInfo::defaultValue();

    protocol.readListBegin(reported_type, list_size);
    if (protocol.kOmitsContainerSizes()) {
      // list size unknown, SimpleJSON protocol won't know type, either
      // so let's just hope that it spits out something that makes sense
      while (protocol.peekList()) {
        read_one(protocol, out);
      }
    } else {
      if (reported_type != WireTypeInfo::fromTType(elem_ttype::value)) {
        apache::thrift::skip_n(protocol, list_size, {reported_type});
      } else {
        if (!canReadNElements(protocol, list_size, {reported_type})) {
          protocol::TProtocolException::throwTruncatedData();
        }
        reserve_if_possible(&out, list_size);
        while (list_size--) {
          read_one(protocol, out);
        }
      }
    }
    protocol.readListEnd();
  }

  // The context is not forwarded to element reads.
  template
  static void readWithContext(Protocol& protocol, Type& out, Context&) {
    read(protocol, out);
  }

  // Writes the list header (size narrowed via checked_container_size), each
  // element, and the footer; returns the sum of the protocol calls' results.
  template
  static std::size_t write(Protocol& protocol, const Type& out) {
    std::size_t xfer = 0;

    xfer += protocol.writeListBegin(
        elem_ttype::value, checked_container_size(out.size()));
    for (const auto& elem : out) {
      xfer += elem_methods::write(protocol, elem);
    }
    xfer += protocol.writeListEnd();
    return xfer;
  }

  // NOTE(review): unlike write(), the size here goes through folly::to_narrow
  // rather than checked_container_size, so an oversized list is not rejected
  // by this function — confirm this is intentional.
  template
  static std::size_t serializedSize(Protocol& protocol, const Type& out) {
    std::size_t xfer = 0;

    xfer += protocol.serializedSizeListBegin(
        elem_ttype::value, folly::to_narrow(folly::to_unsigned(out.size())));
    for (const auto& elem : out) {
      xfer += elem_methods::template serializedSize(protocol, elem);
    }
    xfer += protocol.serializedSizeListEnd();
    return xfer;
  }
};

/*
 * Set Specialization
 *
 * read() follows the same protocol as the list specialization (peekSet()
 * loop when sizes are omitted, skip_n on a wire-type mismatch,
 * canReadNElements check otherwise). Known-length reads are delegated to
 * deserialize_known_length_set.
 */
template
struct protocol_methods, Type> {
  static_assert(
      !std::is_same(),
      "Unable to serialize unknown type");

  using elem_type = typename Type::value_type;
  using elem_methods = protocol_methods;
  using elem_ttype = protocol_type;

 private:
  // Reads one element into a temporary and inserts it into `out`.
  template
  static void consume_elem(Protocol& protocol, Type& out) {
    elem_type tmp;
    elem_methods::read(protocol, tmp);
    out.insert(std::move(tmp));
  }

 public:
  template
  static void read(Protocol& protocol, Type& out) {
    std::uint32_t set_size = -1;

    using WireTypeInfo = ProtocolReaderWireTypeInfo;
    using WireType = typename WireTypeInfo::WireType;

    WireType reported_type = WireTypeInfo::defaultValue();

    protocol.readSetBegin(reported_type, set_size);
    if (protocol.kOmitsContainerSizes()) {
      while (protocol.peekSet()) {
        consume_elem(protocol, out);
      }
    } else {
      if (reported_type != WireTypeInfo::fromTType(elem_ttype::value)) {
        apache::thrift::skip_n(protocol, set_size, {reported_type});
      } else {
        if (!canReadNElements(protocol, set_size, {reported_type})) {
          protocol::TProtocolException::throwTruncatedData();
        }
        const auto vreader = [&protocol](auto& value) {
          elem_methods::read(protocol, value);
        };
        deserialize_known_length_set(out, set_size, vreader);
      }
    }
    protocol.readSetEnd();
  }

  // The context is not forwarded to element reads.
  template
  static void readWithContext(Protocol& protocol, Type& out, Context&) {
    read(protocol, out);
  }

  // When the protocol asks for sorted keys (kSortKeys()) and the is_detected_v
  // check fails (presumably detect_key_compare, i.e. the container has no
  // key_compare and so no inherent order), elements are written in
  // `operator<` order through a sorted vector of iterators. Otherwise they
  // are written in the container's own iteration order.
  template
  static std::size_t write(Protocol& protocol, const Type& out) {
    std::size_t xfer = 0;

    xfer += protocol.writeSetBegin(
        elem_ttype::value, checked_container_size(out.size()));

    if (!folly::is_detected_v &&
        protocol.kSortKeys()) {
      std::vector iters;
      iters.reserve(out.size());
      for (auto it = out.begin(); it != out.end(); ++it) {
        iters.push_back(it);
      }
      std::sort(
          iters.begin(), iters.end(), [](auto a, auto b) { return *a < *b; });
      for (auto it : iters) {
        xfer += elem_methods::write(protocol, *it);
      }
    } else {
      // Support containers with defined but non-FIFO iteration order.
      auto get_view = folly::order_preserving_reinsertion_view_or_default;
      for (const auto& elem : get_view(out)) {
        xfer += elem_methods::write(protocol, elem);
      }
    }
    xfer += protocol.writeSetEnd();
    return xfer;
  }

  // NOTE(review): uses folly::to_narrow rather than checked_container_size,
  // unlike write() — confirm this is intentional.
  template
  static std::size_t serializedSize(Protocol& protocol, const Type& out) {
    std::size_t xfer = 0;

    xfer += protocol.serializedSizeSetBegin(
        elem_ttype::value, folly::to_narrow(folly::to_unsigned(out.size())));
    for (const auto& elem : out) {
      xfer += elem_methods::template serializedSize(protocol, elem);
    }
    xfer += protocol.serializedSizeSetEnd();
    return xfer;
  }
};

/*
 * Map Specialization
 *
 * read(): entries are read until peekMap() returns false when sizes are
 * omitted. Otherwise a non-empty map whose reported key or mapped wire type
 * differs from the expected one is skipped with skip_n. If the types match,
 * canReadNElements is checked and the entries are read via
 * deserialize_known_length_map. write() brackets every entry with
 * writeMapValueBegin / writeMapValueEnd, which only call into the protocol
 * when it provides those hooks.
 */
template
struct protocol_methods, Type> {
  static_assert(
      !std::is_same(),
      "Unable to serialize unknown key type in map");
  static_assert(
      !std::is_same(),
      "Unable to serialize unknown mapped type in map");

  using key_type = typename Type::key_type;
  using mapped_type = typename Type::mapped_type;
  using key_methods = protocol_methods;
  using mapped_methods = protocol_methods;
  using key_ttype = protocol_type;
  using mapped_ttype = protocol_type;

 protected:
  // Used when the protocol omits container sizes: reads a key, then reads the
  // mapped value directly into `out[key]` (so, for std::map-like containers,
  // a repeated key overwrites the earlier value).
  template
  static void consume_elem(Protocol& protocol, U& out) {
    key_type key_tmp;
    key_methods::read(protocol, key_tmp);
    mapped_methods::read(protocol, out[std::move(key_tmp)]);
  }

 public:
  template
  static void read(Protocol& protocol, U& out) {
    std::uint32_t map_size = -1;
    using WireTypeInfo = ProtocolReaderWireTypeInfo;
    using WireType = typename WireTypeInfo::WireType;

    WireType rpt_key_type = WireTypeInfo::defaultValue(),
             rpt_mapped_type = WireTypeInfo::defaultValue();

    protocol.readMapBegin(rpt_key_type, rpt_mapped_type, map_size);
    if (protocol.kOmitsContainerSizes()) {
      while (protocol.peekMap()) {
        consume_elem(protocol, out);
      }
    } else {
      // CompactProtocol does not transmit key/mapped types if
      // the map is empty
      if (map_size > 0 &&
          (WireTypeInfo::fromTType(key_ttype::value) != rpt_key_type ||
           WireTypeInfo::fromTType(mapped_ttype::value) != rpt_mapped_type)) {
        apache::thrift::skip_n(
            protocol, map_size, {rpt_key_type, rpt_mapped_type});
      } else {
        if (!canReadNElements(
                protocol, map_size, {rpt_key_type, rpt_mapped_type})) {
          protocol::TProtocolException::throwTruncatedData();
        }
        const auto kreader = [&protocol](auto& key) {
          key_methods::read(protocol, key);
        };
        const auto vreader = [&protocol](auto& value) {
          mapped_methods::read(protocol, value);
        };
        deserialize_known_length_map(out, map_size, kreader, vreader);
      }
    }
    protocol.readMapEnd();
  }

  // The context is not forwarded to key/value reads.
  template
  static void readWithContext(Protocol& protocol, U& out, Context&) {
    read(protocol, out);
  }

  // Same ordering rule as the set specialization: when kSortKeys() is set and
  // the is_detected_v check fails (presumably no key_compare), entries are
  // written in `operator<` order of their keys; otherwise in iteration order.
  template
  static std::size_t write(Protocol& protocol, const U& out) {
    std::size_t xfer = 0;

    xfer += protocol.writeMapBegin(
        key_ttype::value,
        mapped_ttype::value,
        checked_container_size(out.size()));

    if (!folly::is_detected_v &&
        protocol.kSortKeys()) {
      std::vector iters;
      iters.reserve(out.size());
      for (auto it = out.begin(); it != out.end(); ++it) {
        iters.push_back(it);
      }
      std::sort(iters.begin(), iters.end(), [](auto a, auto b) {
        return (*a).first < (*b).first;
      });
      for (auto it : iters) {
        xfer += writeMapValueBegin(protocol);
        xfer += key_methods::write(protocol, (*it).first);
        xfer += mapped_methods::write(protocol, (*it).second);
        xfer += writeMapValueEnd(protocol);
      }
    } else {
      // Support containers with defined but non-FIFO iteration order.
      auto get_view = folly::order_preserving_reinsertion_view_or_default;
      for (const auto& elem_pair : get_view(out)) {
        xfer += writeMapValueBegin(protocol);
        xfer += key_methods::write(protocol, elem_pair.first);
        xfer += mapped_methods::write(protocol, elem_pair.second);
        xfer += writeMapValueEnd(protocol);
      }
    }
    xfer += protocol.writeMapEnd();
    return xfer;
  }

  // NOTE(review): uses folly::to_narrow rather than checked_container_size,
  // unlike write() — confirm this is intentional.
  template
  static std::size_t serializedSize(Protocol& protocol, const U& out) {
    std::size_t xfer = protocol.serializedSizeMapBegin(
        key_ttype::value,
        mapped_ttype::value,
        folly::to_narrow(folly::to_unsigned(out.size())));
    for (const auto& elem_pair : out) {
      xfer += key_methods::template serializedSize(
          protocol, elem_pair.first);
      xfer += mapped_methods::template serializedSize(
          protocol, elem_pair.second);
    }
    xfer += protocol.serializedSizeMapEnd();
    return xfer;
  }

 private:
  // Detection aliases for the optional per-entry hooks
  // writeMapValueBegin() / writeMapValueEnd().
  template
  using map_value_begin_t =
      decltype(std::declval().writeMapValueBegin());

  template
  using map_value_end_t = decltype(std::declval().writeMapValueEnd());

  // True when both hooks are detected (presumably via the two aliases above).
  template
  static constexpr bool map_value_api_v =
      folly::is_detected_v &&
      folly::is_detected_v;

  // Compile-time dispatch: std::get picks one of two generic lambdas from the
  // pair using the detection result as the index (presumably
  // map_value_api_v<Protocol>). The first returns 0, the second calls
  // protocol.writeMapValueBegin(). Only the selected lambda is ever called,
  // so its body is only instantiated for protocols that have the hook.
  template
  static std::size_t writeMapValueBegin(Protocol& protocol) {
    const auto writeMapValueBeginFunc =
        std::get>(std::make_pair(
            [](auto&) { return 0u; },
            [](auto& protocolWithMapValueApi) {
              return protocolWithMapValueApi.writeMapValueBegin();
            }));

    return writeMapValueBeginFunc(protocol);
  }

  // Counterpart of writeMapValueBegin for writeMapValueEnd().
  template
  static std::size_t writeMapValueEnd(Protocol& protocol) {
    const auto writeMapValueEndFunc =
        std::get>(std::make_pair(
            [](auto&) { return 0u; },
            [](auto& protocolWithMapValueApi) {
              return protocolWithMapValueApi.writeMapValueEnd();
            }));

    return writeMapValueEndFunc(protocol);
  }
};

/*
 * Struct with Indirection Specialization
 *
 * Every operation is forwarded to the protocol_methods of the element type
 * obtained by applying `Indirection{}` to the value.
 */
template
struct protocol_methods, Type> {
  using indirection = Indirection;
  using elem_type = std::remove_reference_t>;
  using elem_methods = protocol_methods;

  template
  static void read(Protocol& protocol, Type& out) {
    elem_methods::read(protocol, indirection{}(out));
  }

  template
  static void readWithContext(Protocol& protocol, Type& out, Context& ctx) {
    elem_methods::readWithContext(protocol, indirection{}(out), ctx);
  }

  template
  static std::size_t write(Protocol& protocol, const Type& in) {
    return elem_methods::write(protocol, indirection{}(in));
  }

  template
  static std::size_t serializedSize(Protocol& protocol, const Type& in) {
    return elem_methods::template serializedSize(
        protocol, indirection{}(in));
  }
};

/*
 * Struct Specialization
 * Forwards to Cpp2Ops wrapper around member read/write/etc.
 *
 * The `U` parameter may be either `Type` itself or a `type::detail::Wrap`
 * around it; unwrap() returns the underlying object (via toThrift() for the
 * wrapped case). readWithContext ignores the context. serializedSize uses
 * Cpp2Ops::serializedSizeZC when ZeroCopy is set.
 */
template
struct protocol_methods {
  template
  using Wrap = type::detail::Wrap;

  static Type& unwrap(Type& inst) { return inst; }
  static const Type& unwrap(const Type& inst) { return inst; }
  template
  static Type& unwrap(Wrap& inst) {
    return inst.toThrift();
  }
  template
  static const Type& unwrap(const Wrap& inst) {
    return inst.toThrift();
  }

  template
  static void read(Protocol& protocol, U& out) {
    Cpp2Ops::read(&protocol, &unwrap(out));
  }

  template
  static void readWithContext(Protocol& protocol, U& out, Context&) {
    read(protocol, out);
  }

  template
  static std::size_t write(Protocol& protocol, const U& in) {
    return Cpp2Ops::write(&protocol, &unwrap(in));
  }

  template
  static std::size_t serializedSize(Protocol& protocol, const U& in) {
    if (ZeroCopy) {
      return Cpp2Ops::serializedSizeZC(&protocol, &unwrap(in));
    } else {
      return Cpp2Ops::serializedSize(&protocol, &unwrap(in));
    }
  }
};

/*
 * Union Specialization
 * Forwards to Cpp2Ops wrapper around member read/write/etc.
 * (Implemented by inheriting the struct specialization.)
 */
template
struct protocol_methods : protocol_methods {};

} // namespace pm
} // namespace detail
} // namespace thrift
} // namespace apache