/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

// NOTE(review): in this copy of the file the `<...>` template parameter and
// argument lists, and the #include targets, appear to have been stripped, so
// it will not compile as-is. Restore them from upstream before building.

namespace apache::thrift::protocol::detail {

// Validates the mask with the given Struct. Ensures that mask doesn't contain
// fields not in the Struct.
//
// Only the top level of `ref.mask` is checked. Each key of its includes map
// (or of its excludes map when includes is unset) must be a field id of the
// Struct. Nested masks are not inspected here; validate_fields does that.
// Returns false as soon as an unknown field id is found, true otherwise.
template bool validate_mask(MaskRef ref) {
  // Get the field ids in the thrift struct type.
  std::unordered_set ids;
  ids.reserve(op::size_v);
  op::for_each_ordinal(
      [&](auto ord) { ids.insert(op::get_field_id()); });
  // The per-field entries live in either includes or excludes; use whichever
  // is set. NOTE(review): if neither is set, excludes_ref().value() is reached
  // on an empty field. It presumably throws; confirm.
  const FieldIdToMask& map = ref.mask.includes_ref()
      ? ref.mask.includes_ref().value()
      : ref.mask.excludes_ref().value();
  for (auto& [id, _] : map) {
    // Mask contains a field not in the struct.
    if (ids.find(FieldId{id}) == ids.end()) {
      return false;
    }
  }
  return true;
}

// Validates the fields in the Struct with the MaskRef.
//
// Recursive check. First validate_mask must accept the top-level field ids.
// Then every field whose nested mask is neither allMask nor noneMask must be
// a thrift struct (judged by native type), and that nested mask must in turn
// validate against the field's struct type. A partial nested mask on any
// non-struct field makes the whole mask invalid.
template bool validate_fields(MaskRef ref) {
  if (!validate_mask(ref)) {
    return false;
  }
  // Validates each field in the struct.
  bool isValid = true;
  op::for_each_ordinal([&](auto ord) {
    if (!isValid) { // short circuit
      return;
    }
    using Ord = decltype(ord);
    MaskRef next = ref.get(op::get_field_id());
    // allMask / noneMask apply to any field type, so there is nothing to check.
    if (next.isAllMask() || next.isNoneMask()) {
      return;
    }
    // Check if the field is a thrift struct type. It uses native_type
    // as we don't support adapted struct fields in field mask.
    using FieldType = op::get_native_type;
    if constexpr (is_thrift_struct_v) {
      // Need to validate the struct type.
      isValid &= validate_fields(next);
      return;
    }
    // A partial mask on a non-struct field cannot be applied.
    isValid = false;
  });
  return isValid;
}

// Forward declaration so that the map overload of is_compatible_with_impl
// below can recurse into the nested value masks.
template bool is_compatible_with(const Mask&);

// Fallback for tags without a dedicated overload. A mask that is neither
// allMask nor noneMask (is_compatible_with handles those) is never compatible.
template bool is_compatible_with_impl(Tag, const Mask&) {
  return false;
}

// Struct tag: compatible iff the mask validates against the struct's fields.
template bool is_compatible_with_impl(type::struct_t, const Mask& mask) {
  return validate_fields({mask});
}

// Map tag: checks every nested mask of an integer-keyed or string-keyed map
// mask.
template bool is_compatible_with_impl(type::map, const Mask& mask) {
  // Map mask is compatible only if all nested masks are compatible with
  // `Value`.
  if (const auto* m = getIntegerMapMask(mask)) {
    return std::all_of(m->begin(), m->end(), [](const auto& pair) {
      return is_compatible_with(pair.second);
    });
  }
  if (const auto* m = getStringMapMask(mask)) {
    return std::all_of(m->begin(), m->end(), [](const auto& pair) {
      return is_compatible_with(pair.second);
    });
  }
  // Neither an integer-keyed nor a string-keyed map mask was found; this is
  // treated as compatible.
  return true;
}

// Returns whether `mask` can be applied to a value of the type described by
// Tag. allMask and noneMask are compatible with every type. Any other mask is
// dispatched on Tag{} to the matching is_compatible_with_impl overload.
template bool is_compatible_with(const Mask& mask) {
  if (isAllMask(mask) || isNoneMask(mask)) {
    return true;
  }
  return is_compatible_with_impl(Tag{}, mask);
}

// Throws an error if the type is not compatible with the mask.
// The error is raised via folly::throw_exception with the message
// "The mask and struct are incompatible." (the exception type is not visible
// in this copy).
template void errorIfNotCompatible(const Mask& mask) {
  // Fully qualified, presumably to keep ADL from selecting another
  // is_compatible_with; confirm.
  if (!::apache::thrift::protocol::detail::is_compatible_with(mask)) {
    folly::throw_exception(
        "The mask and struct are incompatible.");
  }
}

// This converts id list to a field mask with a single field.
template Mask path(const Mask& other) {
  // This is the base case as there is no more id.
  // `other` must be compatible with the type reached at the end of the path.
  errorIfNotCompatible(other);
  return other;
}

// Recursive case. Builds a mask whose includes map holds only the field named
// by the first id, and maps that field to the path built from the remaining
// ids (which ends in `other`).
template Mask path(const Mask& other) {
  using Struct = type::native_type;
  // Only thrift structs have fields that a path can descend into.
  static_assert(is_thrift_struct_v);
  Mask mask;
  using fieldId = op::get_field_id;
  // A default-constructed FieldId presumably means the id did not resolve to
  // a field of Struct; confirm.
  static_assert(fieldId::value != FieldId{});
  mask.includes_ref().emplace()[static_cast(fieldId::value)] =
      path, Ids...>(other);
  return mask;
}

// This converts a list of field names, starting at the given index, into a
// field mask that includes a single field path.
// When all names are consumed (index == fieldNames.size()), `other` is
// checked for compatibility with the final type and returned. Otherwise
// fieldNames[index] must name a field of the current struct, and the result
// includes only that field with the rest of the path nested under it.
// Throws "field doesn't exist" if no field has that name, and "Path contains
// a non thrift struct field." if the path descends into a non-struct type.
template Mask path(
    const std::vector& fieldNames, size_t index, const Mask& other) {
  if (index == fieldNames.size()) {
    errorIfNotCompatible(other);
    return other;
  }
  // static_assert doesn't work as it compiles this code for every field.
  using Struct = type::native_type;
  if constexpr (is_thrift_struct_v) {
    Mask mask;
    op::for_each_field_id([&](auto id) {
      using Id = decltype(id);
      // After a match the includes map is set, so the remaining fields are
      // skipped.
      if (mask.includes_ref()) { // already set
        return;
      }
      if (op::get_name_v == fieldNames[index]) {
        mask.includes_ref().emplace()[folly::to_underlying(id())] =
            path>(fieldNames, index + 1, other);
      }
    });
    if (!mask.includes_ref()) { // field not found
      folly::throw_exception("field doesn't exist");
    }
    return mask;
  }
  // Reached only when the current type is not a thrift struct.
  folly::throw_exception(
      "Path contains a non thrift struct field.");
}

// Ensures the masked fields in the given thrift struct.
//
// op::ensure is applied to every field whose nested mask is not noneMask.
// Struct-typed fields are then recursed into with their nested mask.
// Non-struct fields require an allMask.
// Throws "The mask and struct are incompatible." when:
//  - the mask names unknown field ids, or
//  - a non-struct field has a partial mask.
// Throws "Cannot ensure a const object" if Struct is const.
template void ensure_fields(MaskRef ref, Struct& t) {
  if (!validate_mask(ref)) {
    folly::throw_exception(
        "The mask and struct are incompatible.");
  }
  // A const Struct is rejected at runtime (else branch), not at compile time.
  if constexpr (!std::is_const_v>) {
    op::for_each_ordinal([&](auto ord) {
      using Ord = decltype(ord);
      MaskRef next = ref.get(op::get_field_id());
      if (next.isNoneMask()) {
        return;
      }
      using FieldTag = op::get_field_tag;
      auto&& field_ref = op::get(t);
      // The field is ensured before the mask shape is checked below, so `t`
      // may already be partially modified if an exception is thrown.
      op::ensure(field_ref, t);
      // Need to ensure the struct object.
      using FieldType = op::get_native_type;
      if constexpr (is_thrift_struct_v) {
        // op::ensure above is expected to have made the field present, so this
        // dereference is assumed non-null.
        auto& value = *op::getValueOrNull(field_ref);
        ensure_fields(next, value);
        return;
      }
      if (!next.isAllMask()) {
        folly::throw_exception(
            "The mask and struct are incompatible.");
      }
    });
  } else {
    folly::throw_exception("Cannot ensure a const object");
  }
}

// Clears the masked fields in the given thrift struct.
// Fields whose nested mask is noneMask are skipped. For every other field:
//  - allMask: the field is cleared with op::clear_field;
//  - absent field (getValueOrNull returns null): only the nested mask's
//    compatibility with the field type is checked; nothing is cleared;
//  - present struct field: cleared recursively with the nested mask;
//  - anything else: throws "The mask and struct are incompatible.".
// Throws "Cannot clear a const object" if Struct is const.
template void clear_fields(MaskRef ref, Struct& t) {
  if (!validate_mask(ref)) {
    folly::throw_exception(
        "The mask and struct are incompatible.");
  }
  // A const Struct is rejected at runtime (else branch), not at compile time.
  if constexpr (!std::is_const_v>) {
    op::for_each_ordinal([&](auto ord) {
      using Ord = decltype(ord);
      MaskRef next = ref.get(op::get_field_id());
      if (next.isNoneMask()) {
        return;
      }
      using FieldTag = op::get_field_tag;
      auto&& field_ref = op::get(t);
      if (next.isAllMask()) {
        op::clear_field(field_ref, t);
        return;
      }
      using FieldType = op::get_native_type;
      auto* field_value = op::getValueOrNull(field_ref);
      if (!field_value) {
        // Nothing to clear, but an incompatible nested mask is still an error.
        errorIfNotCompatible>(next.mask);
        return;
      }
      // Need to clear the struct object.
      if constexpr (is_thrift_struct_v) {
        clear_fields(next, *field_value);
        return;
      }
      folly::throw_exception(
          "The mask and struct are incompatible.");
    });
  } else {
    folly::throw_exception("Cannot clear a const object");
  }
}

// Copies the masked fields from src thrift struct to dst.
// Returns true if it copied a field from src to dst.
//
// Both arguments must be the same type up to cv-ref qualifiers (see the
// static_assert). Fields whose nested mask is noneMask are skipped. For every
// other field:
//  - absent in both: only the nested mask's compatibility is checked;
//  - allMask: src's value is copied to dst, or dst's field is cleared when
//    src has none;
//  - struct field with a partial mask: handled recursively, as commented
//    below;
//  - anything else: throws "The mask and struct are incompatible.".
// Clearing a dst field does not count as a copy for the return value.
// Throws "Cannot copy to a const field" if DstStruct is const.
template bool copy_fields(MaskRef ref, SrcStruct& src, DstStruct& dst) {
  static_assert(std::is_same_v< folly::remove_cvref_t, folly::remove_cvref_t>);
  if (!validate_mask(ref)) {
    folly::throw_exception(
        "The mask and struct are incompatible.");
  }
  // A const dst is rejected at runtime (else branch), not at compile time.
  if constexpr (!std::is_const_v>) {
    bool copied = false;
    op::for_each_ordinal([&](auto ord) {
      using Ord = decltype(ord);
      MaskRef next = ref.get(op::get_field_id());
      // Id doesn't exist in field mask, skip.
      if (next.isNoneMask()) {
        return;
      }
      using FieldTag = op::get_field_tag;
      using FieldType = op::get_native_type;
      auto&& src_ref = op::get(src);
      auto&& dst_ref = op::get(dst);
      bool srcHasValue = bool(op::getValueOrNull(src_ref));
      bool dstHasValue = bool(op::getValueOrNull(dst_ref));
      if (!srcHasValue && !dstHasValue) { // skip
        // Nothing to copy or clear, but an incompatible nested mask is still
        // an error.
        errorIfNotCompatible>(next.mask);
        return;
      }
      // Id that we want to copy.
      if (next.isAllMask()) {
        if (srcHasValue) {
          op::copy(src_ref, dst_ref);
          copied = true;
        } else {
          op::clear_field(dst_ref, dst);
        }
        return;
      }
      if constexpr (is_thrift_struct_v) {
        // Field doesn't exist in src, so just clear dst with the mask.
        // dst must hold a value here because the both-absent case returned
        // above.
        if (!srcHasValue) {
          clear_fields(next, *op::getValueOrNull(dst_ref));
          return;
        }
        // Field exists in both src and dst, so call copy recursively.
        if (dstHasValue) {
          copied |= copy_fields(
              next,
              *op::getValueOrNull(src_ref),
              *op::getValueOrNull(dst_ref));
          return;
        }
        // Field only exists in src. Need to construct object only if there's
        // a field to add.
        // Copy into a scratch object first, so the dst field is populated only
        // when at least one nested field was actually copied.
        FieldType newObject;
        bool constructObject =
            copy_fields(next, *op::getValueOrNull(src_ref), newObject);
        if (constructObject) {
          // NOTE(review): moveObject is not defined in this section; it
          // presumably moves newObject into dst's field. Confirm.
          moveObject(dst_ref, std::move(newObject));
          copied = true;
        }
        return;
      }
      folly::throw_exception(
          "The mask and struct are incompatible.");
    });
    return copied;
  } else {
    folly::throw_exception("Cannot copy to a const field");
  }
}
} // namespace apache::thrift::protocol::detail