/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

// NOTE(review): in this view the #include targets, the template parameter
// lists and the explicit template arguments are not visible (bare `#include`,
// bare `template`, calls such as `is_compatible_with>(...)`). They appear to
// have been stripped by extraction -- confirm against the checked-in header
// before editing any code below.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

namespace apache::thrift::protocol {

// This is the return value of parseObject with mask.
// Masked fields are deserialized to included Object, and the other fields
// are stored in excluded MaskedProtocolData.
struct MaskedDecodeResult {
  Object included;
  MaskedProtocolData excluded;
};

namespace detail {

// Validates the mask with the given Struct. Ensures that mask doesn't contain
// fields not in the Struct.
//
// Returns true when every field id keyed in the mask's field map is one of
// the struct's field ids, false as soon as an unknown id is found.
template
bool validate_mask(MaskRef ref) {
  // Get the field ids in the thrift struct type.
  std::unordered_set ids;
  ids.reserve(op::size_v);
  op::for_each_ordinal(
      [&](auto ord) { ids.insert(op::get_field_id()); });
  // Use the includes map when it is set, otherwise the excludes map.
  // NOTE(review): this assumes ref.mask is a field mask, i.e. that
  // excludes_ref() holds a value whenever includes_ref() does not -- confirm
  // what .value() does for other mask kinds.
  const FieldIdToMask& map = ref.mask.includes_ref()
      ? ref.mask.includes_ref().value()
      : ref.mask.excludes_ref().value();
  for (auto& [id, _] : map) {
    // Mask contains a field not in the struct.
    if (ids.find(FieldId{id}) == ids.end()) {
      return false;
    }
  }
  return true;
}

// Forward declaration: is_compatible_with_structured and the map overload of
// is_compatible_with_impl call is_compatible_with before it is defined below.
template
bool is_compatible_with(const Mask&);

// Fallback overload: returns false for any tag that has no dedicated overload
// below. is_compatible_with only dispatches here after it has already
// accepted all-masks and none-masks.
template
bool is_compatible_with_impl(Tag, const Mask&) {
  return false;
}

// Shared implementation for struct and union tags: the mask must only name
// existing fields (validate_mask), and every nested mask that is neither an
// all-mask nor a none-mask must itself be compatible with its field.
template
bool is_compatible_with_structured(const Mask& mask) {
  MaskRef ref{mask};
  if (!validate_mask(ref)) {
    return false;
  }
  // Validates each field in the struct/union
  bool isValid = true;
  op::for_each_ordinal([&](auto ord) {
    // `return` only leaves the current invocation of the lambda, so the flag
    // is checked first to turn the remaining invocations into no-ops.
    if (!isValid) { // short circuit
      return;
    }
    using Ord = decltype(ord);
    MaskRef next = ref.get(op::get_field_id());
    if (next.isAllMask() || next.isNoneMask()) {
      return;
    }
    // Recurse
    isValid &= is_compatible_with>(next.mask);
  });
  return isValid;
}

// Struct tag: delegates to is_compatible_with_structured.
template
bool is_compatible_with_impl(type::struct_t, const Mask& mask) {
  return is_compatible_with_structured(mask);
}

// Union tag: delegates to is_compatible_with_structured.
template
bool is_compatible_with_impl(type::union_t, const Mask& mask) {
  return is_compatible_with_structured(mask);
}

// Map tag: checks every nested mask of an integer-keyed or string-keyed map
// mask.
template
bool is_compatible_with_impl(type::map, const Mask& mask) {
  // Map mask is compatible only if all nested masks are compatible with
  // `Value`.
  if (const auto* m = getIntegerMapMask(mask)) {
    return std::all_of(m->begin(), m->end(), [](const auto& pair) {
      return is_compatible_with(pair.second);
    });
  }
  if (const auto* m = getStringMapMask(mask)) {
    return std::all_of(m->begin(), m->end(), [](const auto& pair) {
      return is_compatible_with(pair.second);
    });
  }
  // NOTE(review): a mask for which both getters return null (and which is
  // not an all/none mask) is reported as compatible here -- confirm that this
  // fall-through is intended.
  return true;
}

// Entry point of the compatibility check: all-masks and none-masks are
// compatible with anything; every other mask is dispatched on the type tag.
template
bool is_compatible_with(const Mask& mask) {
  if (isAllMask(mask) || isNoneMask(mask)) {
    return true;
  }
  return is_compatible_with_impl(Tag{}, mask);
}

// Throws an error if a thrift struct type is not compatible with the mask.
// The exception is raised through folly::throw_exception with the message
// "The mask and struct are incompatible.".
template
void errorIfNotCompatible(const Mask& mask) {
  if (!::apache::thrift::protocol::detail::is_compatible_with(mask)) {
    folly::throw_exception(
        "The mask and struct are incompatible.");
  }
}

// This converts id list to a field mask with a single field.
// Base case: checks `other` for compatibility (throwing if it is not) and
// returns it unchanged.
template
Mask path(const Mask& other) {
  // This is the base case as there is no more id.
  errorIfNotCompatible(other);
  return other;
}

// Recursive case: builds an includes mask with exactly one entry, keyed by
// the resolved field id, whose nested mask is the path built from the
// remaining ids. Both static_asserts reject bad paths at compile time: the
// current type must be a thrift class and the field id must not be the
// default-constructed FieldId{}.
template
Mask path(const Mask& other) {
  using T = type::native_type;
  static_assert(is_thrift_class_v);
  Mask mask;
  using fieldId = op::get_field_id;
  static_assert(fieldId::value != FieldId{});
  mask.includes_ref().emplace()[static_cast(fieldId::value)] =
      path, Ids...>(other);
  return mask;
}

// This converts field name list from the given index to a field mask with a
// single field.
//
// fieldNames: the path, one field name per nesting level.
// index: position in fieldNames handled by this call; when it equals
//   fieldNames.size() the path is exhausted and `other` is returned after a
//   compatibility check.
// other: the mask placed at the end of the path.
// Throws (via folly::throw_exception) when the name at `index` matches no
// field, or when the path continues through a type that is not a thrift
// struct/union.
template
Mask path(
    const std::vector& fieldNames, size_t index, const Mask& other) {
  if (index == fieldNames.size()) {
    errorIfNotCompatible(other);
    return other;
  }
  // static_assert doesn't work as it compiles this code for every field.
  using T = type::native_type;
  if constexpr (is_thrift_class_v) {
    Mask mask;
    op::for_each_field_id([&](auto id) {
      using Id = decltype(id);
      // Once one field has matched, the remaining invocations do nothing.
      if (mask.includes_ref()) { // already set
        return;
      }
      if (op::get_name_v == fieldNames[index]) {
        mask.includes_ref().emplace()[folly::to_underlying(id())] =
            path>(fieldNames, index + 1, other);
      }
    });
    if (!mask.includes_ref()) { // field not found
      folly::throw_exception("field doesn't exist");
    }
    return mask;
  }
  folly::throw_exception(
      "Path contains a non thrift struct/union field.");
}

// Ensures the masked fields in the given thrift struct.
//
// For every field whose mask is not a none-mask the field is ensured via
// op::ensure; fields of thrift class type are then processed recursively with
// the nested mask. Throws when the mask names unknown fields, when more than
// one field of a union is requested, when a non-class field carries a mask
// other than the all-mask, or when `t` is const.
template
void ensure_fields(MaskRef ref, T& t) {
  if (!validate_mask(ref)) {
    folly::throw_exception(
        "The mask and struct are incompatible.");
  }
  if (is_thrift_union_v && ref.numFieldsSet() > 1) {
    folly::throw_exception(
        "Ensuring more than one field in union");
  }
  if constexpr (!std::is_const_v>) {
    op::for_each_ordinal([&](auto ord) {
      using Ord = decltype(ord);
      MaskRef next = ref.get(op::get_field_id());
      if (next.isNoneMask()) {
        return;
      }
      using FieldTag = op::get_field_tag;
      auto&& field_ref = op::get(t);
      op::ensure(field_ref, t);
      // Need to ensure the struct object.
      using FieldType = op::get_native_type;
      if constexpr (is_thrift_class_v) {
        // Dereferenced without a null check; presumably safe because the
        // field was just ensured above -- TODO confirm.
        auto& value = *op::getValueOrNull(field_ref);
        ensure_fields(next, value);
        return;
      }
      if (!next.isAllMask()) {
        folly::throw_exception(
            "The mask and struct are incompatible.");
      }
    });
  } else {
    folly::throw_exception("Cannot ensure a const object");
  }
}

// Clears the masked fields in the given thrift struct.
//
// Fields with an all-mask are cleared through op::clear_field; fields with a
// partial mask are cleared recursively when they hold a thrift class value.
// Throws when the mask names unknown fields, when a partial mask targets a
// field that is not a thrift class, or when `t` is const.
template
void clear_fields(MaskRef ref, T& t) {
  if (!validate_mask(ref)) {
    folly::throw_exception(
        "The mask and struct are incompatible.");
  }
  if constexpr (!std::is_const_v>) {
    op::for_each_ordinal([&](auto ord) {
      using Ord = decltype(ord);
      MaskRef next = ref.get(op::get_field_id());
      if (next.isNoneMask()) {
        return;
      }
      using FieldTag = op::get_field_tag;
      auto&& field_ref = op::get(t);
      if (next.isAllMask()) {
        op::clear_field(field_ref, t);
        return;
      }
      using FieldType = op::get_native_type;
      auto* field_value = op::getValueOrNull(field_ref);
      if (!field_value) {
        // Nothing to clear, but the nested mask is still checked so an
        // incompatible mask is reported even when the field has no value.
        errorIfNotCompatible>(next.mask);
        return;
      }
      // Need to clear the struct/union object.
      if constexpr (is_thrift_class_v) {
        clear_fields(next, *field_value);
        return;
      }
      folly::throw_exception(
          "The mask and struct are incompatible.");
    });
  } else {
    folly::throw_exception("Cannot clear a const object");
  }
}

// Writes masked fields from src (as specified by ref) into ret (ret must be
// empty). Returns true if any masked field was written into ret.
//
// Throws when the mask names unknown fields, or when a partial mask targets a
// field with a value that is not a thrift class.
template
bool filter_fields(MaskRef ref, const T& src, T& ret) {
  if (!validate_mask(ref)) {
    folly::throw_exception(
        "The mask and struct are incompatible.");
  }
  bool retained = false;
  op::for_each_ordinal([&](auto ord) {
    using Ord = decltype(ord);
    MaskRef next = ref.get(op::get_field_id());
    // Id doesn't exist in field mask, skip.
    if (next.isNoneMask()) {
      return;
    }
    using FieldType = op::get_native_type;
    auto&& src_ref = op::get(src);
    auto&& ret_ref = op::get(ret);
    bool srcHasValue = bool(op::getValueOrNull(src_ref));
    if (!srcHasValue) { // skip
      // Nothing is copied, but the nested mask is still checked for
      // compatibility with the field.
      errorIfNotCompatible>(next.mask);
    } else if (next.isAllMask()) {
      if constexpr (is_thrift_union_v) {
        // Simply copy the entire union over
        ret = src;
      } else {
        op::copy(src_ref, ret_ref);
      }
      retained = true;
    } else if constexpr (is_thrift_class_v) {
      // Partial mask on a nested struct/union: filter into a temporary first
      // so the field in ret is only set when something was kept.
      FieldType nested;
      // If no masked fields are retained, leave this field unset (will leave
      // optional fields unset)
      if (filter_fields(next, *src_ref, nested)) {
        moveObject(ret_ref, std::move(nested));
        retained = true;
      }
    } else {
      folly::throw_exception(
          "The mask and struct are incompatible.");
    }
  });
  return retained;
}

// Result of parseValueWithMask for a single value: the deserialized part in
// `included` and the bookkeeping for the part left serialized in `excluded`.
struct MaskedDecodeResultValue {
  Value included;
  MaskedData excluded;
};

// Stores the serialized data of the given type in maskedData and protocolData.
//
// Appends a new entry to protocolData.values(), records the wire type derived
// from arg_type, skips the value in `prot` and clones the skipped bytes into
// the entry. maskedData.full_ref() is set to a ValueId built from the
// zigzag-encoded index of that new entry.
template
void setMaskedDataFull(
    Protocol& prot,
    TType arg_type,
    MaskedData& maskedData,
    MaskedProtocolData& protocolData) {
  auto& values = protocolData.values().ensure();
  auto& encodedValue = values.emplace_back();
  encodedValue.wireType() = type::toBaseType(arg_type);
  // get the serialized data from cursor
  // The cursor is captured before skipping; the difference between the cursor
  // after the skip and the one before is used as the length to clone.
  auto cursor = prot.getCursor();
  apache::thrift::skip(prot, arg_type);
  cursor.clone(encodedValue.data().emplace(), prot.getCursor() - cursor);
  const auto pos = folly::to(values.size() - 1);
  maskedData.full_ref() = type::ValueId{apache::thrift::util::i32ToZigzag(pos)};
}

// parseValue with readMaskRef and writeMaskRef
//
// Reads one value of wire type arg_type from `prot`.
//  - readMaskRef is an all-mask: the value is fully parsed into
//    result.included.
//  - readMaskRef is a none-mask: nothing is deserialized. Without
//    KeepExcludedData the value is skipped. With it, a none write mask stores
//    the raw bytes (setMaskedDataFull), an all write mask skips the value,
//    and any other write mask falls through to the recursive handling below.
//  - otherwise structs and maps are walked entry by entry with the nested
//    read/write masks, and any other type is fully parsed.
// protocolData collects the serialized values and map keys referenced from
// result.excluded.
template
MaskedDecodeResultValue parseValueWithMask(
    Protocol& prot,
    TType arg_type,
    MaskRef readMaskRef,
    MaskRef writeMaskRef,
    MaskedProtocolData& protocolData,
    bool string_to_binary = true) {
  MaskedDecodeResultValue result;
  if (readMaskRef.isAllMask()) { // serialize all
    parseValueInplace(prot, arg_type, result.included, string_to_binary);
    return result;
  }
  if (readMaskRef.isNoneMask()) { // do not deserialize
    if constexpr (!KeepExcludedData) { // no need to store
      apache::thrift::skip(prot, arg_type);
      return result;
    }
    if (writeMaskRef.isNoneMask()) { // store the serialized data
      setMaskedDataFull(prot, arg_type, result.excluded, protocolData);
      return result;
    }
    if (writeMaskRef.isAllMask()) { // no need to store
      apache::thrift::skip(prot, arg_type);
      return result;
    }
    // Need to recursively store the result not in writeMaskRef.
  }
  switch (arg_type) {
    case protocol::T_STRUCT: {
      auto& object = result.included.ensure_object();
      std::string name;
      int16_t fid;
      TType ftype;
      prot.readStructBegin(name);
      // Read fields until the T_STOP marker.
      while (true) {
        prot.readFieldBegin(name, ftype, fid);
        if (ftype == protocol::T_STOP) {
          break;
        }
        MaskRef nextRead = readMaskRef.get(FieldId{fid});
        MaskRef nextWrite = writeMaskRef.get(FieldId{fid});
        MaskedDecodeResultValue nestedResult =
            parseValueWithMask(
                prot, ftype, nextRead, nextWrite, protocolData, string_to_binary);
        // Set nested MaskedDecodeResult if not empty.
        if (!apache::thrift::empty(nestedResult.included)) {
          object[FieldId{fid}] = std::move(nestedResult.included);
        }
        if constexpr (KeepExcludedData) {
          if (!apache::thrift::empty(nestedResult.excluded)) {
            result.excluded.fields_ref().ensure()[FieldId{fid}] =
                std::move(nestedResult.excluded);
          }
        }
        prot.readFieldEnd();
      }
      prot.readStructEnd();
      return result;
    }
    case protocol::T_MAP: {
      auto& map = result.included.ensure_map();
      TType keyType;
      TType valType;
      uint32_t size;
      prot.readMapBegin(keyType, valType, size);
      // Empty map: result.included already holds the (empty) map ensured
      // above.
      if (!size) {
        prot.readMapEnd();
        return result;
      }
      // The indexes are built once per map and reused for every key lookup.
      auto readValueIndex = buildValueIndex(readMaskRef.mask);
      auto writeValueIndex = buildValueIndex(writeMaskRef.mask);
      for (uint32_t i = 0; i < size; i++) {
        // Keys are always fully parsed; only the values are masked.
        auto keyValue = parseValue(prot, keyType, string_to_binary);
        MaskRef nextRead = readMaskRef.get(
            getMapIdValueAddressFromIndex(readValueIndex, keyValue));
        MaskRef nextWrite = writeMaskRef.get(
            getMapIdValueAddressFromIndex(writeValueIndex, keyValue));
        MaskedDecodeResultValue nestedResult =
            parseValueWithMask(
                prot, valType, nextRead, nextWrite, protocolData, string_to_binary);
        // Set nested MaskedDecodeResult if not empty.
        if (!apache::thrift::empty(nestedResult.included)) {
          map[keyValue] = std::move(nestedResult.included);
        }
        if constexpr (KeepExcludedData) {
          if (!apache::thrift::empty(nestedResult.excluded)) {
            // The key is appended to protocolData.keys(); the excluded data
            // is stored under a ValueId built from the zigzag-encoded index
            // of that key.
            auto& keys = protocolData.keys().ensure();
            keys.push_back(keyValue);
            const auto pos = folly::to(keys.size() - 1);
            type::ValueId id =
                type::ValueId{apache::thrift::util::i32ToZigzag(pos)};
            result.excluded.values_ref().ensure()[id] =
                std::move(nestedResult.excluded);
          }
        }
      }
      prot.readMapEnd();
      return result;
    }
    default: {
      // Any other wire type is parsed in full.
      parseValueInplace(prot, arg_type, result.included, string_to_binary);
      return result;
    }
  }
}

// Parses the serialized struct in `buf` with a read mask and a write mask.
//
// Returns a MaskedDecodeResult whose `included` is the Object produced by
// parseValueWithMask and whose `excluded` carries the protocol, the excluded
// data tree and the values/keys collected while parsing.
// NOTE(review): the meaning of the `false` passed as second MaskRef argument
// is not visible here -- presumably an "is exclusion" flag; confirm.
template
MaskedDecodeResult parseObject(
    const folly::IOBuf& buf,
    const Mask& readMask,
    const Mask& writeMask,
    bool string_to_binary = true) {
  Protocol prot;
  prot.setInput(&buf);
  MaskedDecodeResult result;
  // protocolData aliases result.excluded, so everything recorded through it
  // during parsing ends up in the returned result.
  MaskedProtocolData& protocolData = result.excluded;
  protocolData.protocol() = get_standard_protocol;
  MaskedDecodeResultValue parseValueResult =
      parseValueWithMask(
          prot,
          T_STRUCT,
          MaskRef{readMask, false},
          MaskRef{writeMask, false},
          protocolData,
          string_to_binary);
  protocolData.data() = std::move(parseValueResult.excluded);
  // Calling ensure as it is possible that the value is not set.
  result.included = std::move(parseValueResult.included.ensure_object());
  return result;
}

} // namespace detail
} // namespace apache::thrift::protocol