/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include #include #include #include namespace field_mask_constants = apache::thrift::protocol::field_mask_constants; namespace apache::thrift::protocol::detail { const FieldIdToMask* FOLLY_NULLABLE getFieldMask(const Mask& mask) { if (mask.includes_ref()) { return &*mask.includes_ref(); } if (mask.excludes_ref()) { return &*mask.excludes_ref(); } return nullptr; } const MapIdToMask* FOLLY_NULLABLE getIntegerMapMask(const Mask& mask) { if (mask.includes_map_ref()) { return &*mask.includes_map_ref(); } if (mask.excludes_map_ref()) { return &*mask.excludes_map_ref(); } return nullptr; } const MapStringToMask* FOLLY_NULLABLE getStringMapMask(const Mask& mask) { if (mask.includes_string_map_ref()) { return &*mask.includes_string_map_ref(); } if (mask.excludes_string_map_ref()) { return &*mask.excludes_string_map_ref(); } return nullptr; } [[nodiscard]] const MapTypeToMask* FOLLY_NULLABLE getTypeMapMask(const Mask& mask) { if (mask.includes_type_ref()) { return &*mask.includes_type_ref(); } if (mask.excludes_type_ref()) { return &*mask.excludes_type_ref(); } return nullptr; } ArrayKey getArrayKeyFromValue(const Value& v) { if (v.is_byte() || v.is_i16() || v.is_i32() || v.is_i64()) { return ArrayKey::Integer; } if (v.is_binary() || v.is_string()) { return ArrayKey::String; } folly::throw_exception( "Value contains a non-integer or non-string key."); } MapId getMapIdFromValue(const Value& v) { if (v.is_byte()) { return MapId{v.as_byte()}; } if (v.is_i16()) { return MapId{v.as_i16()}; } if (v.is_i32()) { return MapId{v.as_i32()}; } if (v.is_i64()) { return MapId{v.as_i64()}; } folly::throw_exception( "Value contains a non-integer key."); } std::string getStringFromValue(const Value& v) { if (v.is_binary()) { return v.as_binary().to(); } if (v.is_string()) { return v.as_string(); } folly::throw_exception( "Value contains a non-string key."); } void throwIfContainsMapMask(const Mask& mask) { if (mask.includes_map_ref() || mask.excludes_map_ref() || mask.includes_string_map_ref() || mask.excludes_string_map_ref()) { folly::throw_exception("map mask is not implemented"); } if (auto* typeMapPtr = getTypeMapMask(mask)) { for (const auto& [_, nestedMask] : *typeMapPtr) { throwIfContainsMapMask(nestedMask); } return; } for (const auto& [_, nestedMask] : *CHECK_NOTNULL(getFieldMask(mask))) { throwIfContainsMapMask(nestedMask); } } MapId findMapIdByValueAddress(const Mask& mask, const Value& newKey) { MapId mapId = MapId{reinterpret_cast(&newKey)}; if (!(mask.includes_map_ref() || mask.excludes_map_ref())) { return mapId; } const auto& mapIdToMask = mask.includes_map_ref() ? *mask.includes_map_ref() : *mask.excludes_map_ref(); auto it = std::find_if( mapIdToMask.begin(), mapIdToMask.end(), [&newKey](const auto& kv) { return *(reinterpret_cast(kv.first)) == newKey; }); return it == mapIdToMask.end() ? mapId : MapId{it->first}; } MapId getMapIdValueAddressFromIndex( const ValueIndex& index, const Value& newKey) { if (auto it = index.find(newKey); it != index.end()) { return MapId{reinterpret_cast(&(it->get()))}; } return MapId{reinterpret_cast(&newKey)}; } ValueIndex buildValueIndex(const Mask& mask) { ValueIndex index; const auto* mapIdToMask = getIntegerMapMask(mask); if (!mapIdToMask) { return index; } index.reserve(mapIdToMask->size()); for (auto& [key, _] : *mapIdToMask) { index.insert(std::cref(*reinterpret_cast(key))); } return index; } } // namespace apache::thrift::protocol::detail