#!/usr/bin/env python3 # Copyright (c) Meta Platforms, Inc. and affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # pyre-strict from __future__ import annotations import pickle import types import unittest from typing import cast, Type, TypeVar import python_test.enums.thrift_mutable_types as mutable_types import python_test.enums.thrift_types as immutable_types import thrift.python.mutable_serializer as mutable_serializer import thrift.python.serializer as immutable_serializer from parameterized import parameterized_class from python_test.enums.thrift_types import ( BadMembers, Color, ColorGroups, File, Kind, OptionalColorGroups, OptionalFile, Perm, ) from thrift.python.types import BadEnum, Enum, Flag _E = TypeVar("_E", bound=Enum) # tt = test_types, ser = serializer @parameterized_class( ("test_types", "serializer_module"), [ (immutable_types, immutable_serializer), (mutable_types, mutable_serializer), ], ) class EnumTests(unittest.TestCase): def setUp(self) -> None: # pyre-ignore[16]: has no attribute `test_types` self.BadMembers: Type[BadMembers] = self.test_types.BadMembers self.File: Type[File] = self.test_types.File self.Kind: Type[Kind] = self.test_types.Kind self.OptionalFile: Type[OptionalFile] = self.test_types.OptionalFile self.ColorGroups: Type[ColorGroups] = self.test_types.ColorGroups self.Color: Type[Color] = self.test_types.Color self.Perm: Type[Perm] = self.test_types.Perm self.OptionalColorGroups: Type[OptionalColorGroups] = ( self.test_types.OptionalColorGroups ) self.is_mutable_run: bool = self.test_types.__name__.endswith( "thrift_mutable_types" ) # pyre-ignore[16]: has no attribute `serializer_module` self.serializer: types.ModuleType = self.serializer_module def test_bad_member_names(self) -> None: self.assertIsInstance(self.BadMembers.name_, self.BadMembers) self.assertIsInstance(self.BadMembers.value_, self.BadMembers) self.assertIn("name_", self.BadMembers.__members__) self.assertIn("value_", self.BadMembers.__members__) def test_normal_enum(self) -> None: with self.assertRaises(TypeError): # Enums are not ints # pyre-ignore[6]: for tests self.File(name="/etc/motd", type=8) x = self.File(name="/etc", type=self.Kind.DIR) self.assertIsInstance(x.type, self.Kind) self.assertEqual(x.type, self.Kind.DIR) self.assertNotEqual(x.type, self.Kind.SOCK) self.assertNotIsInstance(4, self.Kind, "Ints are not Enums") self.assertIsInstance(self.Kind.DIR, int, "Enums are Ints") self.assertIn(x.type, self.Kind) self.assertEqual(x.type.value, 4) def test_enum_value_rename(self) -> None: """The value name is None but we auto rename it to None_""" x = self.serializer.deserialize( self.File, b'{"name":"blah", "type":0}', self.serializer.Protocol.JSON ) self.assertEqual(x.type, self.Kind.None_) def test_protocol_int_conversion(self) -> None: self.assertEqual(self.serializer.Protocol.BINARY.value, 0) self.assertEqual(self.serializer.Protocol.DEPRECATED_VERBOSE_JSON.value, 1) self.assertEqual(self.serializer.Protocol.COMPACT.value, 2) self.assertEqual(self.serializer.Protocol.JSON.value, 5) def test_bad_enum_hash_same(self) -> None: x = self.serializer.deserialize( self.File, b'{"name": "something", "type": 64}', self.serializer.Protocol.JSON, ) y = self.serializer.deserialize( self.File, b'{"name": "something", "type": 64}', self.serializer.Protocol.JSON, ) # Mutable types do not support hashing if not self.is_mutable_run: self.assertEqual(hash(x), hash(y)) self.assertEqual(hash(x.type), hash(y.type)) self.assertFalse(x.type is y.type) self.assertEqual(x.type, y.type) self.assertFalse(x.type != y.type) def test_bad_enum_in_struct(self) -> None: to_serialize = self.OptionalFile(name="something", type=64) serialized = self.serializer.serialize_iobuf(to_serialize) x = self.serializer.deserialize(self.File, serialized) self.assertBadEnum(cast(BadEnum, x.type), self.Kind, 64) def test_bad_enum_in_list_index(self) -> None: x = self.serializer.deserialize( self.ColorGroups, self.serializer.serialize_iobuf( self.OptionalColorGroups(color_list=[1, 5, 0]) ), ) self.assertEqual(len(x.color_list), 3) self.assertEqual(x.color_list[0], self.Color.blue) self.assertBadEnum(cast(BadEnum, x.color_list[1]), self.Color, 5) self.assertEqual(x.color_list[2], self.Color.red) def test_bad_enum_in_list_iter(self) -> None: x = self.serializer.deserialize( self.ColorGroups, self.serializer.serialize_iobuf( self.OptionalColorGroups(color_list=[1, 5, 0]) ), ) for idx, v in enumerate(x.color_list): if idx == 0: self.assertEqual(v, self.Color.blue) elif idx == 1: self.assertBadEnum(cast(BadEnum, v), self.Color, 5) else: self.assertEqual(v, self.Color.red) def test_bad_enum_in_list_reverse(self) -> None: x = self.serializer.deserialize( self.ColorGroups, self.serializer.serialize_iobuf( self.OptionalColorGroups(color_list=[1, 5, 0]) ), ) for idx, v in enumerate(reversed(x.color_list)): if idx == 0: self.assertEqual(v, self.Color.red) elif idx == 1: self.assertBadEnum(cast(BadEnum, v), self.Color, 5) else: self.assertEqual(v, self.Color.blue) def test_bad_enum_in_set_iter(self) -> None: x = self.serializer.deserialize( self.ColorGroups, self.serializer.serialize_iobuf( self.OptionalColorGroups(color_list=[1, 5, 0]) ), ) for v in x.color_set: if v not in (self.Color.blue, self.Color.red): self.assertBadEnum(cast(BadEnum, v), self.Color, 5) def test_bad_enum_in_map_lookup(self) -> None: x = self.serializer.deserialize( self.ColorGroups, self.serializer.serialize_iobuf( self.OptionalColorGroups(color_map={1: 2, 0: 5, 6: 1, 7: 8}) ), ) val = x.color_map[self.Color.red] self.assertBadEnum(cast(BadEnum, val), self.Color, 5) def test_bad_enum_in_map_iter(self) -> None: x = self.serializer.deserialize( self.ColorGroups, self.serializer.serialize_iobuf( self.OptionalColorGroups(color_map={1: 2, 0: 5, 6: 1, 7: 8}) ), ) s = set() for k in x.color_map: s.add(k) self.assertEqual(len(s), 4) s.discard(self.Color.red) s.discard(self.Color.blue) lst = sorted(s, key=lambda e: cast(BadEnum, e).value) self.assertBadEnum(cast(BadEnum, lst[0]), self.Color, 6) self.assertBadEnum(cast(BadEnum, lst[1]), self.Color, 7) def test_bad_enum_in_map_values(self) -> None: x = self.serializer.deserialize( self.ColorGroups, self.serializer.serialize_iobuf( self.OptionalColorGroups(color_map={1: 2, 0: 5, 6: 1, 7: 8}) ), ) s = set() for k in x.color_map.values(): s.add(k) self.assertEqual(len(s), 4) s.discard(self.Color.green) s.discard(self.Color.blue) lst = sorted(s, key=lambda e: cast(BadEnum, e).value) self.assertBadEnum(cast(BadEnum, lst[0]), self.Color, 5) self.assertBadEnum(cast(BadEnum, lst[1]), self.Color, 8) def test_bad_enum_in_map_items(self) -> None: x = self.serializer.deserialize( self.ColorGroups, self.serializer.serialize_iobuf( self.OptionalColorGroups(color_map={1: 2, 0: 5, 6: 1, 7: 8}) ), ) for k, v in x.color_map.items(): if k == self.Color.blue: self.assertEqual(v, self.Color.green) elif k == self.Color.red: self.assertBadEnum(cast(BadEnum, v), self.Color, 5) else: ck = cast(BadEnum, k) if ck.value == 6: self.assertEqual(v, self.Color.blue) else: self.assertBadEnum(cast(BadEnum, v), self.Color, 8) def assertBadEnum(self, e: BadEnum, cls: Type[_E], val: int) -> None: self.assertIsInstance(e, BadEnum) self.assertEqual(e.value, val) self.assertEqual(e.enum, cls) self.assertEqual(int(e), val) def test_format(self) -> None: self.assertEqual(f"{self.Color.red}", "Color.red") def test_bool_of_class(self) -> None: self.assertTrue(bool(self.Color)) def test_bool_of_members(self) -> None: self.assertTrue(self.Kind.None_) self.assertTrue(self.Color.red) def test_pickle(self) -> None: serialized = pickle.dumps(self.Color.green) green = pickle.loads(serialized) self.assertIs(green, self.Color.green) def test_adding_member(self) -> None: with self.assertRaises(AttributeError): # pyre-fixme[16]: `Type` has no attribute `black`. self.Color.black = 3 def test_delete(self) -> None: with self.assertRaises(AttributeError): del self.Color.red def test_changing_member(self) -> None: with self.assertRaises(AttributeError): # pyre-fixme[8]: Attribute has type `Color`; used as `str`. self.Color.red = "lol" def test_contains(self) -> None: self.assertIn(self.Color.blue, self.Color) self.assertIn(1, self.Color) def test_equal(self) -> None: self.assertEqual(self.Color.blue, self.Color.blue) self.assertNotEqual(self.Color.blue, self.Color.green) self.assertEqual(self.Color.blue, 1) self.assertEqual(2, self.Color.green) self.assertNotEqual(self.Color.blue, self.Kind.FIFO) def test_hash(self) -> None: colors = {} colors[self.Color.red] = 0xFF0000 colors[self.Color.blue] = 0x0000FF colors[self.Color.green] = 0x00FF00 self.assertEqual(colors[self.Color.green], 0x00FF00) self.assertTrue(self.Color.blue in colors) self.assertTrue(self.Kind.CHAR not in colors) self.assertTrue(1 in colors) values_to_names = {v.value: v.name for v in self.Color} self.assertEqual(values_to_names[self.Color.red], "red") def test_enum_in_enum_out(self) -> None: self.assertIs(self.Color(self.Color.blue), self.Color.blue) def test_enum_value(self) -> None: self.assertEqual(self.Color.red.value, 0) def test_enum(self) -> None: lst = list(self.Color) self.assertEqual(len(lst), len(self.Color)) self.assertEqual(len(self.Color), 3) self.assertEqual([self.Color.red, self.Color.blue, self.Color.green], lst) for i, color in enumerate("red blue green".split(), 0): e = self.Color(i) self.assertEqual(e, getattr(self.Color, color)) self.assertEqual(e.value, i) self.assertEqual(e, i) self.assertEqual(e.name, color) self.assertIn(e, self.Color) self.assertIs(type(e), self.Color) self.assertIsInstance(e, self.Color) self.assertEqual(str(e), "Color." + color) self.assertEqual(int(e), i) self.assertEqual(repr(e), f"") def test_insinstance_Enum(self) -> None: _ = list(self.Color) self.assertIsInstance(self.Color.red, Enum) self.assertTrue(issubclass(self.Color, Enum)) # tt = test_types, ser = serializer @parameterized_class( ("test_types", "serializer_module"), [ (immutable_types, immutable_serializer), (mutable_types, mutable_serializer), ], ) class FlagTests(unittest.TestCase): def setUp(self) -> None: # pyre-ignore[16]: has no attribute `test_types` self.BadMembers: Type[BadMembers] = self.test_types.BadMembers self.File: Type[File] = self.test_types.File self.Kind: Type[Kind] = self.test_types.Kind self.OptionalFile: Type[OptionalFile] = self.test_types.OptionalFile self.ColorGroups: Type[ColorGroups] = self.test_types.ColorGroups self.Color: Type[Color] = self.test_types.Color self.Perm: Type[Perm] = self.test_types.Perm self.OptionalColorGroups: Type[OptionalColorGroups] = ( self.test_types.OptionalColorGroups ) # pyre-ignore[16]: has no attribute `serializer_module` self.serializer: types.ModuleType = self.serializer_module def test_flag_enum(self) -> None: with self.assertRaises(TypeError): # pyre-ignore[6]: for tests self.File(name="/etc/motd", permissions=4) x = self.File(name="/bin/sh", permissions=self.Perm.read | self.Perm.execute) self.assertIsInstance(x.permissions, self.Perm) self.assertEqual(x.permissions, self.Perm.read | self.Perm.execute) self.assertTrue(x.permissions) self.assertNotIsInstance(2, self.Perm, "Flags are not ints") self.assertEqual(x.permissions.value, 5) x = self.File(name="") self.assertFalse(x.permissions) self.assertIsInstance(x.permissions, self.Perm) self.assertEqual(f"{self.Perm.read}", "Perm.read") self.assertTrue(self.Perm.read in self.Perm.read | self.Perm.execute) def test_flag_enum_serialization_roundtrip(self) -> None: x = self.File( name="/dev/null", type=self.Kind.CHAR, permissions=self.Perm.read | self.Perm.write, ) y = self.serializer.deserialize(self.File, self.serializer.serialize_iobuf(x)) self.assertEqual(x, y) self.assertEqual(x.permissions, self.Perm.read | self.Perm.write) self.assertIsInstance(x.permissions, self.Perm) def test_zero(self) -> None: zero = self.Perm(0) self.assertNotIn(zero, self.Perm) self.assertIsInstance(zero, self.Perm) def test_logical(self) -> None: self.assertEqual(self.Perm.read & self.Perm.write, self.Perm(0)) self.assertEqual(self.Perm.read ^ self.Perm.write, self.Perm(6)) self.assertEqual(~self.Perm.read, self.Perm(3)) def test_combination(self) -> None: combo = self.Perm(self.Perm.read.value | self.Perm.execute.value) self.assertNotIn(combo, self.Perm) self.assertIsInstance(combo, self.Perm) self.assertIs(combo, self.Perm.read | self.Perm.execute) def test_is(self) -> None: allp = self.Perm(7) self.assertIs(allp, self.Perm(7)) def test_invert(self) -> None: x = self.Perm(-2) self.assertIs(x, self.Perm.read | self.Perm.write) def test_insinstance_Flag(self) -> None: self.assertIsInstance(self.Perm.read, Flag) self.assertTrue(issubclass(self.Perm, Flag)) self.assertIsInstance(self.Perm.read, Enum) self.assertTrue(issubclass(self.Perm, Enum)) def test_combo_in_call(self) -> None: x = self.Perm(7) self.assertIs(x, self.Perm.read | self.Perm.write | self.Perm.execute) def test_combo_repr(self) -> None: x = self.Perm(7) self.assertEqual("", repr(x))