-
-
Notifications
You must be signed in to change notification settings - Fork 437
Expand file tree
/
Copy pathenum.py
More file actions
160 lines (127 loc) · 5.18 KB
/
enum.py
File metadata and controls
160 lines (127 loc) · 5.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""Enumeration model generator.
Provides Enum, StrEnum, and specialized enum classes for code generation.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, ClassVar, Optional
from datamodel_code_generator.imports import IMPORT_ANY, IMPORT_ENUM, IMPORT_INT_ENUM, IMPORT_STR_ENUM, Import
from datamodel_code_generator.model import DataModel, DataModelFieldBase
from datamodel_code_generator.model.base import UNDEFINED, BaseClassDataType
from datamodel_code_generator.types import DataType, Types
if TYPE_CHECKING:
from collections import defaultdict
from pathlib import Path
from datamodel_code_generator.reference import Reference
# Names of the builtin types an enum model can additionally inherit from.
_INT: str = "int"
_FLOAT: str = "float"
_BYTES: str = "bytes"
_STR: str = "str"

# Maps a ``Types`` primitive to the builtin type name that ``Enum.__init__``
# prepends to the model's base classes when no explicit bases were given.
SUBCLASS_BASE_CLASSES: dict[Types, str] = {
    **dict.fromkeys((Types.int32, Types.int64, Types.integer), _INT),
    **dict.fromkeys((Types.float, Types.double, Types.number), _FLOAT),
    Types.byte: _BYTES,
    Types.string: _STR,
}
class Enum(DataModel):
    """DataModel implementation for Python enumerations.

    Renders through the ``Enum.jinja2`` template with ``enum.Enum`` as the
    base class; subclasses override ``BASE_CLASS`` and ``DEFAULT_IMPORTS`` to
    select a different stdlib enum base.
    """

    TEMPLATE_FILE_PATH: ClassVar[str] = "Enum.jinja2"
    BASE_CLASS: ClassVar[str] = "enum.Enum"
    DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = (IMPORT_ENUM,)
    SUPPORTS_GENERIC_BASE_CLASS: ClassVar[bool] = False

    def __init__(  # noqa: PLR0913
        self,
        *,
        reference: Reference,
        fields: list[DataModelFieldBase],
        decorators: list[str] | None = None,
        base_classes: list[Reference] | None = None,
        custom_base_class: str | list[str] | None = None,
        custom_template_dir: Path | None = None,
        extra_template_data: defaultdict[str, dict[str, Any]] | None = None,
        methods: list[str] | None = None,
        path: Path | None = None,
        description: str | None = None,
        type_: Types | None = None,
        default: Any = UNDEFINED,
        nullable: bool = False,
        keyword_only: bool = False,
        treat_dot_as_module: bool | None = None,
    ) -> None:
        """Initialize the enum model.

        Every argument except ``type_`` is forwarded unchanged to
        ``DataModel.__init__``.

        Args:
            type_: Optional ``Types`` value associated with the enum. When the
                caller passed no ``base_classes`` and ``type_`` has an entry in
                ``SUBCLASS_BASE_CLASSES`` (e.g. ``Types.string`` -> ``"str"``),
                that builtin name is prepended to ``self.base_classes``.
        """
        super().__init__(
            reference=reference,
            fields=fields,
            decorators=decorators,
            base_classes=base_classes,
            custom_base_class=custom_base_class,
            custom_template_dir=custom_template_dir,
            extra_template_data=extra_template_data,
            methods=methods,
            path=path,
            description=description,
            default=default,
            nullable=nullable,
            keyword_only=keyword_only,
            treat_dot_as_module=treat_dot_as_module,
        )
        # Only inject the builtin base when the caller did not choose bases
        # explicitly; it goes first so it precedes the enum base class.
        if not base_classes and type_ and (base_class := SUBCLASS_BASE_CLASSES.get(type_)):
            self.base_classes: list[BaseClassDataType] = [
                BaseClassDataType(type=base_class),
                *self.base_classes,
            ]

    @classmethod
    def get_data_type(cls, types: Types, **kwargs: Any) -> DataType:
        """Get data type for enum (not implemented).

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def get_member(self, field: DataModelFieldBase) -> Member:
        """Wrap *field* in a ``Member`` bound to this enum."""
        return Member(self, field)

    def find_member(self, value: Any) -> Member | None:
        """Find the enum member whose default matches *value*.

        Both the field default and *value* are compared as text with any
        surrounding quote characters stripped, so a quoted literal such as
        ``"'red'"`` matches ``"red"``, and a default of ``0`` matches ``0``.

        Args:
            value: The value to look up.

        Returns:
            A ``Member`` for the first matching field, or ``None`` if no field
            matches.
        """
        repr_value = repr(value)
        # Remove surrounding quotes from the string representation
        str_value = str(value).strip("'\"")
        for field in self.fields:
            default = field.default
            # Only a missing (None) default maps to "": falsy values such as
            # 0, 0.0 or False are real enum values and must keep their text,
            # otherwise e.g. find_member(0) could never match a 0 member.
            field_default = "" if default is None else str(default).strip("'\"")
            if field_default == str_value:
                return self.get_member(field)
            # Keep original comparison for backwards compatibility (exact match
            # against the repr, e.g. strings whose repr contains escapes).
            if default == repr_value:  # pragma: no cover
                return self.get_member(field)
        return None

    @property
    def imports(self) -> tuple[Import, ...]:
        """Return the parent model's imports with ``IMPORT_ANY`` filtered out."""
        return tuple(i for i in super().imports if i != IMPORT_ANY)
class StrEnum(Enum):
    """Enum model whose base class is ``enum.StrEnum`` (imported by default)."""

    DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = (IMPORT_STR_ENUM,)
    BASE_CLASS: ClassVar[str] = "enum.StrEnum"
class IntEnum(Enum):
    """Enum model whose base class is ``enum.IntEnum`` (imported by default)."""

    DEFAULT_IMPORTS: ClassVar[tuple[Import, ...]] = (IMPORT_INT_ENUM,)
    BASE_CLASS: ClassVar[str] = "enum.IntEnum"
#: Enum model subclass to use for each ``Types`` primitive that has a
#: specialized stdlib enum base: integer types -> ``IntEnum``,
#: ``Types.string`` -> ``StrEnum``.
SPECIALIZED_ENUM_TYPE_MATCH: dict[Types, type[Enum]] = {
    **dict.fromkeys((Types.int32, Types.int64, Types.integer), IntEnum),
    Types.string: StrEnum,
}
class Member:
    """A single member of an enum model, rendered as ``Owner.NAME``.

    ``alias`` starts unset; when it is given a non-empty value, it replaces
    the owning enum's class name in the rendered reference.
    """

    def __init__(self, enum: Enum, field: DataModelFieldBase) -> None:
        """Bind the member to its owning enum model and its field."""
        self.enum: Enum = enum
        self.field: DataModelFieldBase = field
        self.alias: Optional[str] = None  # noqa: UP045

    def __repr__(self) -> str:
        """Render as ``<owner>.<member name>``, preferring a truthy alias."""
        owner = self.alias if self.alias else self.enum.class_name
        member_name = self.field.name
        return f"{owner}.{member_name}"