From 30ead6fff8c8e1c1ee4ef3ff6e4ba0f40e2291dc Mon Sep 17 00:00:00 2001 From: Graziella Lima Date: Tue, 30 Jun 2026 01:33:16 -0300 Subject: [PATCH 01/13] refactor: split Table.__init_subclass__ responsibilities into helper methods (R0912) --- piccolo/table.py | 91 +++++++++++++++++++++++++++++++++++------------- 1 file changed, 66 insertions(+), 25 deletions(-) diff --git a/piccolo/table.py b/piccolo/table.py index 8ff955d59..4e9bbc4b0 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -236,6 +236,21 @@ def session_auth( return cls.__name__ +@dataclass +class _ClassifiedColumns: + columns: list[Column] = field(default_factory=list) + default_columns: list[Column] = field(default_factory=list) + non_default_columns: list[Column] = field(default_factory=list) + array_columns: list[Array] = field(default_factory=list) + foreign_key_columns: list[ForeignKey] = field(default_factory=list) + secret_columns: list[Column] = field(default_factory=list) + json_columns: list[Union[JSON, JSONB]] = field(default_factory=list) + email_columns: list[Email] = field(default_factory=list) + auto_update_columns: list[Column] = field(default_factory=list) + primary_key: Optional[Column] = None + m2m_relationships: list[M2M] = field(default_factory=list) + + class Table(metaclass=TableMetaclass): """ The class represents a database table. An instance represents a row. @@ -276,6 +291,46 @@ def __init_subclass__( """ if tags is None: tags = [] + tablename, schema = cls._process_tablename(tablename, schema) + classified = cls._classify_columns() + + if not classified.primary_key: + primary_key = cls._create_serial_primary_key() + setattr(cls, "id", primary_key) + classified.columns.insert(0, primary_key) + classified.default_columns.append(primary_key) + else: + primary_key = classified.primary_key + + cls._meta = TableMeta( + tablename=tablename, + columns=classified.columns, + default_columns=classified.default_columns, + non_default_columns=classified.non_default_columns, + array_columns=classified.array_columns, + email_columns=classified.email_columns, + primary_key=primary_key, + foreign_key_columns=classified.foreign_key_columns, + json_columns=classified.json_columns, + secret_columns=classified.secret_columns, + auto_update_columns=classified.auto_update_columns, + tags=tags, + help_text=help_text, + _db=db, + m2m_relationships=classified.m2m_relationships, + schema=schema, + ) + + cls._setup_foreign_keys() + + TABLE_REGISTRY.append(cls) + + @classmethod + def _process_tablename( + cls, + tablename: Optional[str], + schema: Optional[str], + ) -> tuple[str, Optional[str]]: tablename = tablename or _camel_to_snake(cls.__name__) if "." in tablename: @@ -288,6 +343,10 @@ def __init_subclass__( if tablename in PROTECTED_TABLENAMES: warnings.warn(TABLENAME_WARNING.format(tablename=tablename)) + return tablename, schema + + @classmethod + def _classify_columns(cls) -> _ClassifiedColumns: columns: list[Column] = [] default_columns: list[Column] = [] non_default_columns: list[Column] = [] @@ -311,10 +370,6 @@ def __init_subclass__( attribute = getattr(cls, attribute_name) if isinstance(attribute, Column): - # We have to copy, then override the existing column - # definition, in case this column is inheritted from a mixin. - # Otherwise, when we set attributes on that column, it will - # effect all other users of that mixin. column = attribute.copy() setattr(cls, attribute_name, column) @@ -351,35 +406,23 @@ def __init_subclass__( attribute._meta._table = cls m2m_relationships.append(attribute) - if not primary_key: - primary_key = cls._create_serial_primary_key() - setattr(cls, "id", primary_key) - - columns.insert(0, primary_key) # PK should be added first - default_columns.append(primary_key) - - cls._meta = TableMeta( - tablename=tablename, + return _ClassifiedColumns( columns=columns, default_columns=default_columns, non_default_columns=non_default_columns, array_columns=array_columns, - email_columns=email_columns, - primary_key=primary_key, foreign_key_columns=foreign_key_columns, - json_columns=json_columns, secret_columns=secret_columns, + json_columns=json_columns, + email_columns=email_columns, auto_update_columns=auto_update_columns, - tags=tags, - help_text=help_text, - _db=db, + primary_key=primary_key, m2m_relationships=m2m_relationships, - schema=schema, ) - for foreign_key_column in foreign_key_columns: - # ForeignKey columns require additional setup based on their - # parent Table. + @classmethod + def _setup_foreign_keys(cls) -> None: + for foreign_key_column in cls._meta.foreign_key_columns: foreign_key_setup_response = foreign_key_column._setup( table_class=cls ) @@ -388,8 +431,6 @@ def __init_subclass__( foreign_key_column ) - TABLE_REGISTRY.append(cls) - def __init__( self, _data: Optional[dict[Column, Any]] = None, From 2ab626511552f1037b1c7aaed07820fe7e539adb Mon Sep 17 00:00:00 2001 From: Graziella Lima Date: Tue, 30 Jun 2026 01:53:40 -0300 Subject: [PATCH 02/13] refactor: split Query._process_results logic into helper methods (R0912) --- piccolo/query/base.py | 122 ++++++++++++++++++++++-------------------- 1 file changed, 64 insertions(+), 58 deletions(-) diff --git a/piccolo/query/base.py b/piccolo/query/base.py index daa894efc..b3583482b 100644 --- a/piccolo/query/base.py +++ b/piccolo/query/base.py @@ -56,77 +56,83 @@ async def _process_results(self, results) -> QueryResponseType: if hasattr(self, "_raw_response_callback"): self._raw_response_callback(raw) + raw = self._load_json_columns(raw) + raw = await self.response_handler(raw) + return self._process_output(raw) + + def _load_json_columns(self, raw: list[dict]) -> list[dict]: output: Optional[OutputDelegate] = getattr( self, "output_delegate", None ) + if not (output and output._output.load_json): + return raw - ####################################################################### + columns_delegate: Optional[ColumnsDelegate] = getattr( + self, "columns_delegate", None + ) - if output and output._output.load_json: - columns_delegate: Optional[ColumnsDelegate] = getattr( - self, "columns_delegate", None - ) + json_column_names: list[str] = [] - json_column_names: list[str] = [] + if columns_delegate is not None: + json_columns: list[Union[JSON, JSONB]] = [] - if columns_delegate is not None: - json_columns: list[Union[JSON, JSONB]] = [] + for column in columns_delegate.selected_columns: + if isinstance(column, (JSON, JSONB)): + json_columns.append(column) + elif isinstance(column, JSONQueryString): + if alias := column._alias: + json_column_names.append(alias) + else: + json_columns = self.table._meta.json_columns - for column in columns_delegate.selected_columns: - if isinstance(column, (JSON, JSONB)): - json_columns.append(column) - elif isinstance(column, JSONQueryString): - if alias := column._alias: - json_column_names.append(alias) + for column in json_columns: + if column._alias is not None: + json_column_names.append(column._alias) + elif len(column._meta.call_chain) > 0: + json_column_names.append(column._meta.get_default_alias()) else: - json_columns = self.table._meta.json_columns - - for column in json_columns: - if column._alias is not None: - json_column_names.append(column._alias) - elif len(column._meta.call_chain) > 0: - json_column_names.append(column._meta.get_default_alias()) - else: - json_column_names.append(column._meta.name) + json_column_names.append(column._meta.name) - processed_raw = [] + processed_raw = [] - for row in raw: - new_row = {**row} - for json_column_name in json_column_names: - value = new_row.get(json_column_name) - if value is not None: - new_row[json_column_name] = load_json(value) - processed_raw.append(new_row) + for row in raw: + new_row = {**row} + for json_column_name in json_column_names: + value = new_row.get(json_column_name) + if value is not None: + new_row[json_column_name] = load_json(value) + processed_raw.append(new_row) - raw = processed_raw + return processed_raw - ####################################################################### - - raw = await self.response_handler(raw) - - if output: - if output._output.as_objects: - if output._output.nested: - return cast( - QueryResponseType, - [ - make_nested_object( - row, - self.table, - load_json=output._output.load_json, - ) - for row in raw - ], - ) - else: - return cast( - QueryResponseType, - [ - self.table(**columns, _exists_in_db=True) - for columns in raw - ], - ) + def _process_output(self, raw: list[dict]) -> QueryResponseType: + output: Optional[OutputDelegate] = getattr( + self, "output_delegate", None + ) + if not output: + return cast(QueryResponseType, raw) + + if output._output.as_objects: + if output._output.nested: + return cast( + QueryResponseType, + [ + make_nested_object( + row, + self.table, + load_json=output._output.load_json, + ) + for row in raw + ], + ) + else: + return cast( + QueryResponseType, + [ + self.table(**columns, _exists_in_db=True) + for columns in raw + ], + ) return cast(QueryResponseType, raw) From 7bd6253e2ca9e447c25af909257411b2cd93a5a5 Mon Sep 17 00:00:00 2001 From: Graziella Lima Date: Tue, 30 Jun 2026 02:05:36 -0300 Subject: [PATCH 03/13] refactor: split Select.response_handler M2M logic into helpers (R0912) --- piccolo/query/methods/select.py | 183 ++++++++++++++++---------------- 1 file changed, 92 insertions(+), 91 deletions(-) diff --git a/piccolo/query/methods/select.py b/piccolo/query/methods/select.py index 4266c3c7d..5c3ef6f7f 100644 --- a/piccolo/query/methods/select.py +++ b/piccolo/query/methods/select.py @@ -295,6 +295,93 @@ async def _splice_m2m_rows( row[m2m_name] = [extra_rows_map.get(i) for i in row[m2m_name]] return response + async def _handle_m2m_sqlite( + self, response: list, m2m_select: M2MSelect + ) -> list: + m2m_name = m2m_select.m2m._meta.name + secondary_table = m2m_select.m2m._meta.secondary_table + secondary_table_pk = secondary_table._meta.primary_key + + value_type = ( + m2m_select.columns[0].__class__.value_type + if m2m_select.as_list and m2m_select.serialisation_safe + else secondary_table_pk.value_type + ) + try: + for row in response: + data = row[m2m_name] + row[m2m_name] = ( + [value_type(i) for i in row[m2m_name]] + if data + else [] + ) + except ValueError: + colored_warning( + "Unable to do type conversion for the " + f"{m2m_name} relation" + ) + + if m2m_select.as_list: + if not m2m_select.serialisation_safe: + response = await self._splice_m2m_rows( + response, + secondary_table, + secondary_table_pk, + m2m_name, + m2m_select, + as_list=True, + ) + else: + if ( + len(m2m_select.columns) == 1 + and m2m_select.serialisation_safe + ): + column_name = m2m_select.columns[0]._meta.name + for row in response: + row[m2m_name] = [ + {column_name: i} for i in row[m2m_name] + ] + else: + response = await self._splice_m2m_rows( + response, + secondary_table, + secondary_table_pk, + m2m_name, + m2m_select, + ) + + return response + + async def _handle_m2m_postgres( + self, response: list, m2m_select: M2MSelect + ) -> list: + m2m_name = m2m_select.m2m._meta.name + secondary_table = m2m_select.m2m._meta.secondary_table + secondary_table_pk = secondary_table._meta.primary_key + + if m2m_select.as_list: + if ( + type(m2m_select.columns[0]) in (JSON, JSONB) + and m2m_select.load_json + ): + for row in response: + data = row[m2m_name] + row[m2m_name] = [load_json(i) for i in data] + elif m2m_select.serialisation_safe: + for row in response: + data = row[m2m_name] + row[m2m_name] = load_json(data) if data else [] + else: + response = await self._splice_m2m_rows( + response, + secondary_table, + secondary_table_pk, + m2m_name, + m2m_select, + ) + + return response + async def response_handler(self, response): m2m_selects = [ i @@ -302,101 +389,15 @@ async def response_handler(self, response): if isinstance(i, M2MSelect) ] for m2m_select in m2m_selects: - m2m_name = m2m_select.m2m._meta.name - secondary_table = m2m_select.m2m._meta.secondary_table - secondary_table_pk = secondary_table._meta.primary_key - if self.engine_type == "sqlite": - # With M2M queries in SQLite, we always get the value back as a - # list of strings, so we need to do some type conversion. - value_type = ( - m2m_select.columns[0].__class__.value_type - if m2m_select.as_list and m2m_select.serialisation_safe - else secondary_table_pk.value_type + response = await self._handle_m2m_sqlite( + response, m2m_select ) - try: - for row in response: - data = row[m2m_name] - row[m2m_name] = ( - [value_type(i) for i in row[m2m_name]] - if data - else [] - ) - except ValueError: - colored_warning( - "Unable to do type conversion for the " - f"{m2m_name} relation" - ) - - # If the user requested a single column, we just return that - # from the database. Otherwise we request the primary key - # value, so we can fetch the rest of the data in a subsequent - # SQL query - see below. - if m2m_select.as_list: - if m2m_select.serialisation_safe: - pass - else: - response = await self._splice_m2m_rows( - response, - secondary_table, - secondary_table_pk, - m2m_name, - m2m_select, - as_list=True, - ) - else: - if ( - len(m2m_select.columns) == 1 - and m2m_select.serialisation_safe - ): - column_name = m2m_select.columns[0]._meta.name - for row in response: - row[m2m_name] = [ - {column_name: i} for i in row[m2m_name] - ] - else: - response = await self._splice_m2m_rows( - response, - secondary_table, - secondary_table_pk, - m2m_name, - m2m_select, - ) - elif self.engine_type in ("postgres", "cockroach"): - if m2m_select.as_list: - # We get the data back as an array, and can just return it - # unless it's JSON. - if ( - type(m2m_select.columns[0]) in (JSON, JSONB) - and m2m_select.load_json - ): - for row in response: - data = row[m2m_name] - row[m2m_name] = [load_json(i) for i in data] - elif m2m_select.serialisation_safe: - # If the columns requested can be safely serialised, they - # are returned as a JSON string, so we need to deserialise - # it. - for row in response: - data = row[m2m_name] - row[m2m_name] = load_json(data) if data else [] - else: - # If the data can't be safely serialised as JSON, we get - # back an array of primary key values, and need to - # splice in the correct values using Python. - response = await self._splice_m2m_rows( - response, - secondary_table, - secondary_table_pk, - m2m_name, - m2m_select, - ) - - ####################################################################### + response = await self._handle_m2m_postgres( + response, m2m_select + ) - # If no columns were specified, it's a select *, so we know that - # no columns were selected from related tables. was_select_star = len(self.columns_delegate.selected_columns) == 0 if self.output_delegate._output.nested and not was_select_star: From 0138ffaa9dde37ffd5eeea088ee8948b4b810d1f Mon Sep 17 00:00:00 2001 From: Graziella Lima Date: Tue, 30 Jun 2026 02:26:00 -0300 Subject: [PATCH 04/13] refactor: extract Select.default_querystrings clause handling into helper methods (R0912) --- piccolo/query/methods/select.py | 48 ++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/piccolo/query/methods/select.py b/piccolo/query/methods/select.py index 5c3ef6f7f..0fd25a906 100644 --- a/piccolo/query/methods/select.py +++ b/piccolo/query/methods/select.py @@ -571,7 +571,6 @@ def _check_valid_call_chain(self, keys: Sequence[Selectable]) -> bool: @property def default_querystrings(self) -> Sequence[QueryString]: - # JOIN self._check_valid_call_chain(self.columns_delegate.selected_columns) select_joins = self._get_joins(self.columns_delegate.selected_columns) @@ -583,21 +582,15 @@ def default_querystrings(self) -> Sequence[QueryString]: self.order_by_delegate.get_order_by_columns() ) - # Combine all joins, and remove duplicates joins: list[str] = list( OrderedDict.fromkeys( select_joins + where_joins + having_joins + order_by_joins ) ) - ####################################################################### - - # If no columns have been specified for selection, select all columns - # on the table: if len(self.columns_delegate.selected_columns) == 0: self.columns_delegate.selected_columns = self.table._meta.columns - # If secret fields need to be omitted, remove them from the list. if self.exclude_secrets: self.columns_delegate.remove_secret_columns() @@ -608,8 +601,6 @@ def default_querystrings(self) -> Sequence[QueryString]: for c in self.columns_delegate.selected_columns ] - ####################################################################### - args: list[Any] = [] query = "SELECT" @@ -627,6 +618,17 @@ def default_querystrings(self) -> Sequence[QueryString]: for join in joins: query += f" {join}" + query, args = self._append_clauses(query, args) + self._validate_sqlite_offset(engine_type) + query, args = self._append_lock_rows(query, args, engine_type) + + querystring = QueryString(query, *args) + + return [querystring] + + def _append_clauses( + self, query: str, args: list[Any] + ) -> tuple[str, list[Any]]: if self.as_of_delegate._as_of: query += "{}" args.append(self.as_of_delegate._as_of.querystring) @@ -647,6 +649,17 @@ def default_querystrings(self) -> Sequence[QueryString]: query += "{}" args.append(self.order_by_delegate._order_by.querystring) + if self.limit_delegate._limit: + query += "{}" + args.append(self.limit_delegate._limit.querystring) + + if self.offset_delegate._offset: + query += "{}" + args.append(self.offset_delegate._offset.querystring) + + return query, args + + def _validate_sqlite_offset(self, engine_type: str) -> None: if ( engine_type == "sqlite" and self.offset_delegate._offset @@ -657,27 +670,18 @@ def default_querystrings(self) -> Sequence[QueryString]: "SQLite." ) - if self.limit_delegate._limit: - query += "{}" - args.append(self.limit_delegate._limit.querystring) - - if self.offset_delegate._offset: - query += "{}" - args.append(self.offset_delegate._offset.querystring) - + def _append_lock_rows( + self, query: str, args: list[Any], engine_type: str + ) -> tuple[str, list[Any]]: if self.lock_rows_delegate._lock_rows: if engine_type == "sqlite": raise NotImplementedError( "SQLite doesn't support row locking e.g. SELECT ... FOR " "UPDATE" ) - query += "{}" args.append(self.lock_rows_delegate._lock_rows.querystring) - - querystring = QueryString(query, *args) - - return [querystring] + return query, args async def run( self, From e2b8009c27ba1252f9c0e6cc40b86f2392ecc121 Mon Sep 17 00:00:00 2001 From: Graziella Lima Date: Tue, 30 Jun 2026 02:41:17 -0300 Subject: [PATCH 05/13] refactor: split Column.get_sql_value engine-specific logic into helpers (R0912) --- piccolo/columns/base.py | 85 ++++++++++++++++++++++++----------------- 1 file changed, 50 insertions(+), 35 deletions(-) diff --git a/piccolo/columns/base.py b/piccolo/columns/base.py index 879e4088f..263c2bb84 100644 --- a/piccolo/columns/base.py +++ b/piccolo/columns/base.py @@ -924,45 +924,60 @@ def get_sql_value( elif isinstance(value, bytes): return f"{delimiter}{value.hex()}{delimiter}" - # SQLite specific + # Engine-specific if self._meta.engine_type == "sqlite": - if adapter := sqlite_adapters.get(type(value)): - sqlite_value = adapter(value) - return ( - f"{delimiter}{sqlite_value}{delimiter}" - if isinstance(sqlite_value, str) - else sqlite_value - ) + return self._get_sqlite_sql_value(value, delimiter, sqlite_adapters) - # Postgres and Cockroach - if self._meta.engine_type in ["postgres", "cockroach"]: - if isinstance(value, datetime.datetime): - return f"{delimiter}{value.isoformat().replace('T', ' ')}{delimiter}" # noqa: E501 - elif isinstance(value, datetime.date): - return f"{delimiter}{value.isoformat()}{delimiter}" - elif isinstance(value, datetime.time): - return f"{delimiter}{value.isoformat()}{delimiter}" - elif isinstance(value, datetime.timedelta): - interval = IntervalCustom.from_timedelta(value) - return getattr(interval, self._meta.engine_type) - elif isinstance(value, uuid.UUID): - return f"{delimiter}{value}{delimiter}" - elif isinstance(value, list): - # Convert to the array syntax. - return ( - delimiter - + "{" - + ",".join( - self.get_sql_value( - i, - delimiter="" if isinstance(i, list) else '"', - ) - for i in value + if self._meta.engine_type in ("postgres", "cockroach"): + return self._get_postgres_sql_value(value, delimiter) + + return str(value) + + def _get_sqlite_sql_value( + self, + value: Any, + delimiter: str, + sqlite_adapters: dict, + ) -> str: + if adapter := sqlite_adapters.get(type(value)): + sqlite_value = adapter(value) + return ( + f"{delimiter}{sqlite_value}{delimiter}" + if isinstance(sqlite_value, str) + else sqlite_value + ) + return str(value) + + def _get_postgres_sql_value( + self, + value: Any, + delimiter: str, + ) -> str: + if isinstance(value, datetime.datetime): + return f"{delimiter}{value.isoformat().replace('T', ' ')}{delimiter}" + elif isinstance(value, datetime.date): + return f"{delimiter}{value.isoformat()}{delimiter}" + elif isinstance(value, datetime.time): + return f"{delimiter}{value.isoformat()}{delimiter}" + elif isinstance(value, datetime.timedelta): + interval = IntervalCustom.from_timedelta(value) + return getattr(interval, self._meta.engine_type) + elif isinstance(value, uuid.UUID): + return f"{delimiter}{value}{delimiter}" + elif isinstance(value, list): + return ( + delimiter + + "{" + + ",".join( + self.get_sql_value( + i, + delimiter="" if isinstance(i, list) else '"', ) - + "}" - + delimiter + for i in value ) - + + "}" + + delimiter + ) return str(value) @property From 76ac0081ac1d523bc3140e9adc6a547fa3fbcbfa Mon Sep 17 00:00:00 2001 From: Wania Santos Date: Tue, 30 Jun 2026 03:31:16 -0300 Subject: [PATCH 06/13] refactor: replace if/elif chain with dispatch table in get_column_default to reduce branches (R0912) --- piccolo/apps/schema/commands/generate.py | 143 ++++++++++++++--------- 1 file changed, 86 insertions(+), 57 deletions(-) diff --git a/piccolo/apps/schema/commands/generate.py b/piccolo/apps/schema/commands/generate.py index d4e94ef16..c94ececca 100644 --- a/piccolo/apps/schema/commands/generate.py +++ b/piccolo/apps/schema/commands/generate.py @@ -382,6 +382,88 @@ def __add__(self, value: OutputSchema) -> OutputSchema: } +def _handle_boolean(value: dict[str, str]) -> bool: + return value["value"] == "true" + + +def _handle_interval(value: dict[str, str]) -> IntervalCustom: + kwargs = {} + for period in [ + "years", + "months", + "weeks", + "days", + "hours", + "minutes", + "seconds", + ]: + period_match = value.get(period, 0) + if period_match: + kwargs[period] = int(period_match) + digits = value["digits"] + if digits: + kwargs.update( + dict( + zip( + ["hours", "minutes", "seconds"], + [int(v) for v in digits.split(":")], + ) + ) + ) + return IntervalCustom(**kwargs) + + +def _handle_json(value: dict[str, str]) -> Any: + return json.loads(value["value"]) + + +def _handle_uuid(value: dict[str, str]) -> Any: + return uuid.uuid4 + + +def _handle_date(value: dict[str, str]) -> Any: + return ( + date.today + if value["value"] == "CURRENT_DATE" + else defaults.date.DateCustom( + *[int(v) for v in value["value"].split("-")] + ) + ) + + +def _handle_bytea(value: dict[str, str]) -> bytes: + return value["value"].encode("utf8") + + +def _handle_timestamp(value: dict[str, str]) -> Any: + return ( + datetime.now + if value["value"] == "CURRENT_TIMESTAMP" + else datetime.fromtimestamp(float(value["value"])) + ) + + +def _handle_timestamptz(value: dict[str, str]) -> Any: + return ( + datetime.now + if value["value"] == "CURRENT_TIMESTAMP" + else datetime.fromtimestamp(float(value["value"])) + ) + + +_COLUMN_DEFAULT_HANDLERS: dict[type[Column], Any] = { + Boolean: _handle_boolean, + Interval: _handle_interval, + JSON: _handle_json, + JSONB: _handle_json, + UUID: _handle_uuid, + Date: _handle_date, + Bytea: _handle_bytea, + Timestamp: _handle_timestamp, + Timestamptz: _handle_timestamptz, +} + + def get_column_default( column_type: type[Column], column_default: str, engine_type: str ) -> Any: @@ -399,63 +481,10 @@ def get_column_default( match = re.match(pat, column_default) if match is not None: value = match.groupdict() - - if column_type is Boolean: - return value["value"] == "true" - elif column_type is Interval: - kwargs = {} - for period in [ - "years", - "months", - "weeks", - "days", - "hours", - "minutes", - "seconds", - ]: - period_match = value.get(period, 0) - if period_match: - kwargs[period] = int(period_match) - digits = value["digits"] - if digits: - kwargs.update( - dict( - zip( - ["hours", "minutes", "seconds"], - [int(v) for v in digits.split(":")], - ) - ) - ) - - return IntervalCustom(**kwargs) - elif column_type is JSON or column_type is JSONB: - return json.loads(value["value"]) - elif column_type is UUID: - return uuid.uuid4 - elif column_type is Date: - return ( - date.today - if value["value"] == "CURRENT_DATE" - else defaults.date.DateCustom( - *[int(v) for v in value["value"].split("-")] - ) - ) - elif column_type is Bytea: - return value["value"].encode("utf8") - elif column_type is Timestamp: - return ( - datetime.now - if value["value"] == "CURRENT_TIMESTAMP" - else datetime.fromtimestamp(float(value["value"])) - ) - elif column_type is Timestamptz: - return ( - datetime.now - if value["value"] == "CURRENT_TIMESTAMP" - else datetime.fromtimestamp(float(value["value"])) - ) - else: - return column_type.value_type(value["value"]) + handler = _COLUMN_DEFAULT_HANDLERS.get(column_type) + if handler is not None: + return handler(value) + return column_type.value_type(value["value"]) INDEX_METHOD_MAP: dict[str, IndexMethod] = { From d18cc97191d5b6e77b9aca6edc4a1600da148a3a Mon Sep 17 00:00:00 2001 From: Graziella Lima Date: Tue, 30 Jun 2026 08:48:09 -0300 Subject: [PATCH 07/13] refactor: split create_pydantic_model field building into helpers (R0912) --- piccolo/utils/pydantic.py | 243 +++++++++++++++++++++++--------------- 1 file changed, 147 insertions(+), 96 deletions(-) diff --git a/piccolo/utils/pydantic.py b/piccolo/utils/pydantic.py index 794406f17..6eb4770c0 100644 --- a/piccolo/utils/pydantic.py +++ b/piccolo/utils/pydantic.py @@ -108,6 +108,133 @@ def get_pydantic_value_type(column: Column) -> type: return value_type +def _validate_pydantic_params( + table: type[Table], + exclude_columns: tuple[Column, ...], + include_columns: tuple[Column, ...], + recursion_depth: int, +) -> None: + if exclude_columns and include_columns: + raise ValueError( + "`include_columns` and `exclude_columns` can't be used at the " + "same time." + ) + + if recursion_depth == 0: + if exclude_columns: + if not validate_columns(columns=exclude_columns, table=table): + raise ValueError( + f"`exclude_columns` are invalid: {exclude_columns!r}" + ) + + if include_columns: + if not validate_columns(columns=include_columns, table=table): + raise ValueError( + f"`include_columns` are invalid: {include_columns!r}" + ) + + +def _get_value_type_for_column( + column: Column, + deserialize_json: bool, + is_optional: bool, + validators: dict[str, Callable], + column_name: str, +) -> type: + if isinstance(column, (JSON, JSONB)): + if deserialize_json: + return pydantic.Json + validator = partial( + pydantic_json_validator, required=not is_optional + ) + validators[ + f"{column_name}_is_json" + ] = pydantic.field_validator(column_name)( + validator # type: ignore + ) + return column.value_type + return get_pydantic_value_type(column=column) + + +def _build_extra_for_column( + column: Column, + model_name: str, + nested: Union[bool, tuple[ForeignKey, ...]], + include_readable: bool, + recursion_depth: int, + max_recursion_depth: int, + _type: type, + columns: dict[str, Any], + column_name: str, + include_columns: tuple[Column, ...], + exclude_columns: tuple[Column, ...], + include_default_columns: bool, + all_optional: bool, + deserialize_json: bool, +) -> tuple[JsonDict, type]: + extra: JsonDict = { + "help_text": column._meta.help_text, + "choices": column._meta.get_choices_dict(), + "secret": column._meta.secret, + "nullable": column._meta.null, + "unique": column._meta.unique, + } + + if isinstance(column, ForeignKey): + if recursion_depth < max_recursion_depth and ( + (nested is True) + or ( + isinstance(nested, tuple) + and any( + column._equals(i) + for i in itertools.chain( + nested, *[i._meta.call_chain for i in nested] + ) + ) + ) + ): + nested_model_name = f"{model_name}.{column._meta.name}" + _type = create_pydantic_model( + table=column._foreign_key_meta.resolved_references, + nested=nested, + include_columns=include_columns, + exclude_columns=exclude_columns, + include_default_columns=include_default_columns, + include_readable=include_readable, + all_optional=all_optional, + deserialize_json=deserialize_json, + recursion_depth=recursion_depth + 1, + max_recursion_depth=max_recursion_depth, + model_name=nested_model_name, + ) + + tablename = ( + column._foreign_key_meta.resolved_references._meta.tablename + ) + target_column = ( + column._foreign_key_meta.resolved_target_column._meta.name + ) + extra["foreign_key"] = { + "to": tablename, + "target_column": target_column, + } + + if include_readable: + columns[f"{column_name}_readable"] = (str, None) + else: + if isinstance(column, Text): + extra["widget"] = "text-area" + elif isinstance(column, (JSON, JSONB)): + extra["widget"] = "json" + elif isinstance(column, Timestamptz): + extra["widget"] = "timestamptz" + + if isinstance(column, Array): + extra["dimensions"] = column._get_dimensions() + + return extra, _type + + def create_pydantic_model( table: type[Table], nested: Union[bool, tuple[ForeignKey, ...]] = False, @@ -181,24 +308,7 @@ def create_pydantic_model( A Pydantic model. """ # noqa: E501 - if exclude_columns and include_columns: - raise ValueError( - "`include_columns` and `exclude_columns` can't be used at the " - "same time." - ) - - if recursion_depth == 0: - if exclude_columns: - if not validate_columns(columns=exclude_columns, table=table): - raise ValueError( - f"`exclude_columns` are invalid: {exclude_columns!r}" - ) - - if include_columns: - if not validate_columns(columns=include_columns, table=table): - raise ValueError( - f"`include_columns` are invalid: {include_columns!r}" - ) + _validate_pydantic_params(table, exclude_columns, include_columns, recursion_depth) ########################################################################### @@ -245,21 +355,9 @@ def create_pydantic_model( ####################################################################### # Work out the column type - if isinstance(column, (JSON, JSONB)): - if deserialize_json: - value_type = pydantic.Json - else: - value_type = column.value_type - validator = partial( - pydantic_json_validator, required=not is_optional - ) - validators[ - f"{column_name}_is_json" - ] = pydantic.field_validator(column_name)( - validator # type: ignore - ) - else: - value_type = get_pydantic_value_type(column=column) + value_type = _get_value_type_for_column( + column, deserialize_json, is_optional, validators, column_name + ) _type = Optional[value_type] if is_optional else value_type @@ -272,69 +370,22 @@ def create_pydantic_model( if column._meta.db_column_name != column._meta.name: params["alias"] = column._meta.db_column_name - extra: JsonDict = { - "help_text": column._meta.help_text, - "choices": column._meta.get_choices_dict(), - "secret": column._meta.secret, - "nullable": column._meta.null, - "unique": column._meta.unique, - } - - if isinstance(column, ForeignKey): - if recursion_depth < max_recursion_depth and ( - (nested is True) - or ( - isinstance(nested, tuple) - and any( - column._equals(i) - for i in itertools.chain( - nested, *[i._meta.call_chain for i in nested] - ) - ) - ) - ): - nested_model_name = f"{model_name}.{column._meta.name}" - _type = create_pydantic_model( - table=column._foreign_key_meta.resolved_references, - nested=nested, - include_columns=include_columns, - exclude_columns=exclude_columns, - include_default_columns=include_default_columns, - include_readable=include_readable, - all_optional=all_optional, - deserialize_json=deserialize_json, - recursion_depth=recursion_depth + 1, - max_recursion_depth=max_recursion_depth, - model_name=nested_model_name, - ) - - tablename = ( - column._foreign_key_meta.resolved_references._meta.tablename - ) - target_column = ( - column._foreign_key_meta.resolved_target_column._meta.name - ) - extra["foreign_key"] = { - "to": tablename, - "target_column": target_column, - } - - if include_readable: - columns[f"{column_name}_readable"] = (str, None) - else: - # This is used to tell Piccolo Admin that we want to display these - # values using a specific widget. - if isinstance(column, Text): - extra["widget"] = "text-area" - elif isinstance(column, (JSON, JSONB)): - extra["widget"] = "json" - elif isinstance(column, Timestamptz): - extra["widget"] = "timestamptz" - - # It is useful for Piccolo API and Piccolo Admin to easily know - # how many dimensions the array has. - if isinstance(column, Array): - extra["dimensions"] = column._get_dimensions() + extra, _type = _build_extra_for_column( + column, + model_name, + nested, + include_readable, + recursion_depth, + max_recursion_depth, + _type, + columns, + column_name, + include_columns, + exclude_columns, + include_default_columns, + all_optional, + deserialize_json, + ) field = pydantic.Field( json_schema_extra={"extra": extra}, From a01026fc1ef9007b2b05f738a571bd199b9c71a8 Mon Sep 17 00:00:00 2001 From: Graziella Lima Date: Tue, 30 Jun 2026 08:59:42 -0300 Subject: [PATCH 08/13] refactor: split ASGI new command file generation into helpers (R0912) --- piccolo/apps/asgi/commands/new.py | 102 ++++++++++++++++-------------- 1 file changed, 54 insertions(+), 48 deletions(-) diff --git a/piccolo/apps/asgi/commands/new.py b/piccolo/apps/asgi/commands/new.py index f85b7b980..c8c9322a2 100644 --- a/piccolo/apps/asgi/commands/new.py +++ b/piccolo/apps/asgi/commands/new.py @@ -43,6 +43,58 @@ def get_server() -> str: return SERVERS[int(server)] +def _process_directory(output_dir_path: str, sub_dir_names: list[str]): + for sub_dir_name in sub_dir_names: + if sub_dir_name.startswith("_"): + continue + sub_dir_path = os.path.join(output_dir_path, sub_dir_name) + if not os.path.exists(sub_dir_path): + os.mkdir(sub_dir_path) + + +def _process_file( + dir_path: str, + output_dir_path: str, + file_name: str, + template_context: dict, +): + if file_name.startswith("_") and file_name != "__init__.py.jinja": + return + extension = file_name.rsplit(".")[0] + if extension in ("pyc",): + return + if file_name.endswith(".jinja"): + output_file_name = file_name.replace(".jinja", "") + template = Environment( + loader=FileSystemLoader(searchpath=dir_path) + ).get_template(file_name) + output_contents = template.render(**template_context) + if output_file_name.endswith(".py"): + try: + output_contents = black.format_str( + output_contents, + mode=black.FileMode(line_length=80), + ) + except Exception as exception: + print(f"Problem processing {output_file_name}") + raise exception from exception + with open( + os.path.join(output_dir_path, output_file_name), "w" + ) as f: + f.write(output_contents) + else: + if file_name.endswith(".jinja_raw"): + output_file_name = file_name.replace( + ".jinja_raw", ".jinja" + ) + else: + output_file_name = file_name + shutil.copy( + os.path.join(dir_path, file_name), + os.path.join(output_dir_path, output_file_name), + ) + + def new(root: str = ".", name: str = "piccolo_project"): """ Create a basic ASGI app, including Piccolo, routing, and an admin. @@ -77,56 +129,10 @@ def new(root: str = ".", name: str = "piccolo_project"): continue os.mkdir(dir_path) - for sub_dir_name in sub_dir_names: - if sub_dir_name.startswith("_"): - continue - - sub_dir_path = os.path.join(output_dir_path, sub_dir_name) - if not os.path.exists(sub_dir_path): - os.mkdir(sub_dir_path) + _process_directory(output_dir_path, sub_dir_names) for file_name in file_names: - if file_name.startswith("_") and file_name != "__init__.py.jinja": - continue - - extension = file_name.rsplit(".")[0] - if extension in ("pyc",): - continue - - if file_name.endswith(".jinja"): - output_file_name = file_name.replace(".jinja", "") - template = Environment( - loader=FileSystemLoader(searchpath=dir_path) - ).get_template(file_name) - - output_contents = template.render(**template_context) - - if output_file_name.endswith(".py"): - try: - output_contents = black.format_str( - output_contents, - mode=black.FileMode(line_length=80), - ) - except Exception as exception: - print(f"Problem processing {output_file_name}") - raise exception from exception - - with open( - os.path.join(output_dir_path, output_file_name), "w" - ) as f: - f.write(output_contents) - else: - if file_name.endswith(".jinja_raw"): - output_file_name = file_name.replace( - ".jinja_raw", ".jinja" - ) - else: - output_file_name = file_name - - shutil.copy( - os.path.join(dir_path, file_name), - os.path.join(output_dir_path, output_file_name), - ) + _process_file(dir_path, output_dir_path, file_name, template_context) print( "Run `pip install -r requirements.txt` and `python main.py` to get " From bb8edf87cced4af036533584f6c3df69b0cc6d0e Mon Sep 17 00:00:00 2001 From: Graziella Lima Date: Tue, 30 Jun 2026 09:11:51 -0300 Subject: [PATCH 09/13] refactor: split create_table_class_from_db column processing into helpers (R0912) --- piccolo/apps/schema/commands/generate.py | 210 ++++++++++++++--------- 1 file changed, 130 insertions(+), 80 deletions(-) diff --git a/piccolo/apps/schema/commands/generate.py b/piccolo/apps/schema/commands/generate.py index c94ececca..fdb2843e9 100644 --- a/piccolo/apps/schema/commands/generate.py +++ b/piccolo/apps/schema/commands/generate.py @@ -683,6 +683,114 @@ async def get_foreign_key_reference( return ConstraintTable() +def _get_column_type( + data_type: str, engine_type: str +) -> Optional[type[Column]]: + if engine_type == "cockroach": + return COLUMN_TYPE_MAP_COCKROACH.get(data_type) + return COLUMN_TYPE_MAP.get(data_type) + + +def _apply_primary_key_override( + column_type: type[Column], +) -> type[Column]: + if column_type == Integer: + return Serial + if column_type == BigInt: + return Serial + return column_type + + +def _apply_index_params( + indexes: TableIndexes, column_name: str, kwargs: dict[str, Any] +) -> None: + index = indexes.get_column_index(column_name=column_name) + if index is not None: + kwargs["index"] = True + kwargs["index_method"] = index.method + + +def _apply_type_specific_params( + column_type: type[Column], pg_row_meta: RowMeta +) -> dict[str, Any]: + params: dict[str, Any] = {} + if column_type is Varchar: + params["length"] = pg_row_meta.character_maximum_length + elif isinstance(column_type, Numeric): + radix = pg_row_meta.numeric_precision_radix + if radix: + precision = int(str(pg_row_meta.numeric_precision), radix) + scale = int(str(pg_row_meta.numeric_scale), radix) + params["digits"] = (precision, scale) + else: + params["digits"] = None + return params + + +async def _resolve_foreign_key( + column_name: str, + constraints: TableConstraints, + triggers: TableTriggers, + table_class: type[Table], + tablename: str, + engine_type: str, + output_schema: OutputSchema, + kwargs: dict[str, Any], +) -> tuple[type[Column], dict[str, Any], OutputSchema]: + fk_constraint_table = constraints.get_foreign_key_constraint_name( + column_name=column_name + ) + column_type: type[Column] = ForeignKey + constraint_table = await get_foreign_key_reference( + table_class=table_class, + constraint_name=fk_constraint_table.name, + constraint_schema=fk_constraint_table.schema, + ) + if constraint_table.name: + if constraint_table.name == tablename: + referenced_output_schema = output_schema + referenced_table: Union[str, Optional[type[Table]]] = "self" + else: + referenced_output_schema = ( + await create_table_class_from_db( + table_class=table_class, + tablename=constraint_table.name, + schema_name=constraint_table.schema, + engine_type=engine_type, + ) + ) + referenced_table = ( + referenced_output_schema.get_table_with_name( + tablename=constraint_table.name + ) + ) + + kwargs["references"] = ( + referenced_table + if referenced_table is not None + else ForeignKeyPlaceholder + ) + + trigger = triggers.get_column_ref_trigger( + column_name, constraint_table.name + ) + if trigger: + kwargs["on_update"] = OnUpdate(trigger.on_update) + kwargs["on_delete"] = OnDelete(trigger.on_delete) + else: + output_schema.trigger_warnings.append( + f"{tablename}.{column_name}" + ) + + output_schema = sum( # type: ignore + [output_schema, referenced_output_schema] # type: ignore + ) # type: ignore + else: + kwargs["references"] = ForeignKeyPlaceholder + + return column_type, kwargs, output_schema + + async def create_table_class_from_db( table_class: type[Table], tablename: str, @@ -709,18 +817,15 @@ async def create_table_class_from_db( columns: dict[str, Column] = {} for pg_row_meta in table_schema: - data_type = pg_row_meta.data_type - - if engine_type == "cockroach": - column_type = COLUMN_TYPE_MAP_COCKROACH.get(data_type, None) - else: - column_type = COLUMN_TYPE_MAP.get(data_type, None) - column_name = pg_row_meta.column_name - column_default = pg_row_meta.column_default + + column_type = _get_column_type( + data_type=pg_row_meta.data_type, + engine_type=engine_type, + ) if not column_type: output_schema.warnings.append( - f"{tablename}.{column_name} ['{data_type}']" + f"{tablename}.{column_name} ['{pg_row_meta.data_type}']" ) column_type = Column @@ -729,91 +834,36 @@ async def create_table_class_from_db( "unique": constraints.is_unique(column_name=column_name), } - index = indexes.get_column_index(column_name=column_name) - if index is not None: - kwargs["index"] = True - kwargs["index_method"] = index.method + _apply_index_params(indexes, column_name, kwargs) if constraints.is_primary_key(column_name=column_name): kwargs["primary_key"] = True - if column_type == Integer: - column_type = Serial - if column_type == BigInt: - column_type = Serial - # column_type = BigSerial + column_type = _apply_primary_key_override(column_type) if constraints.is_foreign_key(column_name=column_name): - fk_constraint_table = constraints.get_foreign_key_constraint_name( - column_name=column_name - ) - column_type = ForeignKey - constraint_table = await get_foreign_key_reference( + column_type, kwargs, output_schema = await _resolve_foreign_key( + column_name=column_name, + constraints=constraints, + triggers=triggers, table_class=table_class, - constraint_name=fk_constraint_table.name, - constraint_schema=fk_constraint_table.schema, + tablename=tablename, + engine_type=engine_type, + output_schema=output_schema, + kwargs=kwargs, ) - if constraint_table.name: - referenced_table: Union[str, Optional[type[Table]]] - - if constraint_table.name == tablename: - referenced_output_schema = output_schema - referenced_table = "self" - else: - referenced_output_schema = ( - await create_table_class_from_db( - table_class=table_class, - tablename=constraint_table.name, - schema_name=constraint_table.schema, - engine_type=engine_type, - ) - ) - referenced_table = ( - referenced_output_schema.get_table_with_name( - tablename=constraint_table.name - ) - ) - kwargs["references"] = ( - referenced_table - if referenced_table is not None - else ForeignKeyPlaceholder - ) - - trigger = triggers.get_column_ref_trigger( - column_name, constraint_table.name - ) - if trigger: - kwargs["on_update"] = OnUpdate(trigger.on_update) - kwargs["on_delete"] = OnDelete(trigger.on_delete) - else: - output_schema.trigger_warnings.append( - f"{tablename}.{column_name}" - ) - - output_schema = sum( # type: ignore - [output_schema, referenced_output_schema] # type: ignore - ) # type: ignore - else: - kwargs["references"] = ForeignKeyPlaceholder output_schema.imports.append( "from piccolo.columns.column_types import " + column_type.__name__ # type: ignore ) - if column_type is Varchar: - kwargs["length"] = pg_row_meta.character_maximum_length - elif isinstance(column_type, Numeric): - radix = pg_row_meta.numeric_precision_radix - if radix: - precision = int(str(pg_row_meta.numeric_precision), radix) - scale = int(str(pg_row_meta.numeric_scale), radix) - kwargs["digits"] = (precision, scale) - else: - kwargs["digits"] = None - - if column_default: + kwargs.update( + _apply_type_specific_params(column_type, pg_row_meta) + ) + + if pg_row_meta.column_default: default_value = get_column_default( - column_type, column_default, engine_type + column_type, pg_row_meta.column_default, engine_type ) if default_value: kwargs["default"] = default_value From d71230ffb2135e6cf777955e08cad69fccf2f65f Mon Sep 17 00:00:00 2001 From: Graziella Lima Date: Tue, 30 Jun 2026 09:19:43 -0300 Subject: [PATCH 10/13] refactor: split ForwardsMigrationManager migration flow into helpers (R0912) --- piccolo/apps/migrations/commands/forwards.py | 97 +++++++++++--------- 1 file changed, 55 insertions(+), 42 deletions(-) diff --git a/piccolo/apps/migrations/commands/forwards.py b/piccolo/apps/migrations/commands/forwards.py index adc7e657d..6a284ec37 100644 --- a/piccolo/apps/migrations/commands/forwards.py +++ b/piccolo/apps/migrations/commands/forwards.py @@ -26,6 +26,42 @@ def __init__( self.preview = preview super().__init__() + def _get_migration_subset( + self, havent_run: list[str] + ) -> list[str] | None: + if self.migration_id == "all": + return havent_run + elif self.migration_id == "1": + return havent_run[:1] + else: + try: + index = havent_run.index(self.migration_id) + except ValueError: + return None + else: + return havent_run[: index + 1] + + async def _run_single_migration( + self, + _id: str, + response: object, + app_name: str, + ) -> None: + if isinstance(response, MigrationManager): + if self.fake or response.fake: + print(f"- {_id}: faked! ⏭️") + else: + if self.preview: + response.preview = True + await response.run() + + print("ok! ✔️") + + if not self.preview: + await Migration.insert().add( + Migration(name=_id, app_name=app_name) + ).run() + async def run_migrations(self, app_config: AppConfig) -> MigrationResult: already_ran = await Migration.get_migrations_which_ran( app_name=app_config.app_name @@ -42,52 +78,29 @@ async def run_migrations(self, app_config: AppConfig) -> MigrationResult: print(f"👍 {n} migration{'s' if n != 1 else ''} already complete") havent_run = sorted(set(ids) - set(already_ran)) - if len(havent_run) == 0: - # Make sure this still appears successful, as we don't want this - # to appear as an error in automated scripts. + if not havent_run: message = "🏁 No migrations need to be run" print(message) return MigrationResult(success=True, message=message) - else: - n = len(havent_run) - print(f"⏩ {n} migration{'s' if n != 1 else ''} not yet run") - if self.migration_id == "all": - subset = havent_run - elif self.migration_id == "1": - subset = havent_run[:1] - else: - try: - index = havent_run.index(self.migration_id) - except ValueError: - message = f"{self.migration_id} is unrecognised" - print(message, file=sys.stderr) - return MigrationResult(success=False, message=message) - else: - subset = havent_run[: index + 1] - - if subset: - n = len(subset) - print(f"🚀 Running {n} migration{'s' if n != 1 else ''}:") - - for _id in subset: - migration_module = migration_modules[_id] - response = await migration_module.forwards() - - if isinstance(response, MigrationManager): - if self.fake or response.fake: - print(f"- {_id}: faked! ⏭️") - else: - if self.preview: - response.preview = True - await response.run() - - print("ok! ✔️") - - if not self.preview: - await Migration.insert().add( - Migration(name=_id, app_name=app_config.app_name) - ).run() + n = len(havent_run) + print(f"⏩ {n} migration{'s' if n != 1 else ''} not yet run") + + subset = self._get_migration_subset(havent_run) + if subset is None: + message = f"{self.migration_id} is unrecognised" + print(message, file=sys.stderr) + return MigrationResult(success=False, message=message) + + n = len(subset) + print(f"🚀 Running {n} migration{'s' if n != 1 else ''}:") + + for _id in subset: + migration_module = migration_modules[_id] + response = await migration_module.forwards() + await self._run_single_migration( + _id, response, app_config.app_name + ) return MigrationResult(success=True, message="migration succeeded") From 5c631114056a48074a55f1c80da77fffc1d45920 Mon Sep 17 00:00:00 2001 From: Graziella Lima Date: Tue, 30 Jun 2026 09:26:59 -0300 Subject: [PATCH 11/13] refactor: split SchemaSnapshot.get_snapshot operations into helpers (R0912) --- .../apps/migrations/auto/schema_snapshot.py | 174 ++++++++++-------- 1 file changed, 99 insertions(+), 75 deletions(-) diff --git a/piccolo/apps/migrations/auto/schema_snapshot.py b/piccolo/apps/migrations/auto/schema_snapshot.py index 5bf343063..d5549c28f 100644 --- a/piccolo/apps/migrations/auto/schema_snapshot.py +++ b/piccolo/apps/migrations/auto/schema_snapshot.py @@ -30,85 +30,109 @@ def get_table_from_snapshot(self, table_class_name: str) -> DiffableTable: def get_snapshot(self) -> list[DiffableTable]: tables: list[DiffableTable] = [] - # Make sure the managers are sorted correctly: sorted_managers = sorted(self.managers, key=lambda x: x.migration_id) for manager in sorted_managers: - for table in manager.add_tables: - tables.append(table) - - for drop_table in manager.drop_tables: - tables = [ - i for i in tables if i.class_name != drop_table.class_name - ] - - for rename_table in manager.rename_tables: - for table in tables: - if table.class_name == rename_table.old_class_name: - table.class_name = rename_table.new_class_name - table.tablename = rename_table.new_tablename - break - - for change_table_schema in manager.change_table_schemas: - for table in tables: - if table.tablename == change_table_schema.tablename: - table.schema = change_table_schema.new_schema - break + self._apply_add_tables(tables, manager) + self._apply_drop_tables(tables, manager) + self._apply_rename_tables(tables, manager) + self._apply_change_table_schemas(tables, manager) for table in tables: - add_columns = manager.add_columns.columns_for_table_class_name( - table.class_name - ) - table.columns.extend(add_columns) - - ############################################################### - - drop_columns = manager.drop_columns.for_table_class_name( - table.class_name - ) - for drop_column in drop_columns: - table.columns = [ - i - for i in table.columns - if i._meta.name != drop_column.column_name - ] - - ############################################################### - - alter_columns = manager.alter_columns.for_table_class_name( - table.class_name - ) - for alter_column in alter_columns: - for index, column in enumerate(table.columns): - if column._meta.name == alter_column.column_name: - for key, value in alter_column.params.items(): - setattr(column._meta, key, value) - column._meta.params.update({key: value}) - - # If the column type has changed, we need to update - # it. - if ( - alter_column.column_class - != alter_column.old_column_class - ) and alter_column.column_class is not None: - new_column = alter_column.column_class( - **column._meta.params - ) - new_column._meta = column._meta - table.columns[index] = new_column - - ############################################################### - - for ( - rename_column - ) in manager.rename_columns.for_table_class_name( - table.class_name - ): - for column in table.columns: - if column._meta.name == rename_column.old_column_name: - column._meta.name = rename_column.new_column_name - column._meta.db_column_name = ( - rename_column.new_db_column_name - ) + self._apply_add_columns(table, manager) + self._apply_drop_columns(table, manager) + self._apply_alter_columns(table, manager) + self._apply_rename_columns(table, manager) return tables + + ########################################################################### + + def _apply_add_tables( + self, tables: list[DiffableTable], manager: MigrationManager + ): + for table in manager.add_tables: + tables.append(table) + + def _apply_drop_tables( + self, tables: list[DiffableTable], manager: MigrationManager + ): + for drop_table in manager.drop_tables: + tables[:] = [ + i for i in tables if i.class_name != drop_table.class_name + ] + + def _apply_rename_tables( + self, tables: list[DiffableTable], manager: MigrationManager + ): + for rename_table in manager.rename_tables: + for table in tables: + if table.class_name == rename_table.old_class_name: + table.class_name = rename_table.new_class_name + table.tablename = rename_table.new_tablename + break + + def _apply_change_table_schemas( + self, tables: list[DiffableTable], manager: MigrationManager + ): + for change_table_schema in manager.change_table_schemas: + for table in tables: + if table.tablename == change_table_schema.tablename: + table.schema = change_table_schema.new_schema + break + + def _apply_add_columns( + self, table: DiffableTable, manager: MigrationManager + ): + add_columns = manager.add_columns.columns_for_table_class_name( + table.class_name + ) + table.columns.extend(add_columns) + + def _apply_drop_columns( + self, table: DiffableTable, manager: MigrationManager + ): + drop_columns = manager.drop_columns.for_table_class_name( + table.class_name + ) + for drop_column in drop_columns: + table.columns = [ + i + for i in table.columns + if i._meta.name != drop_column.column_name + ] + + def _apply_alter_columns( + self, table: DiffableTable, manager: MigrationManager + ): + alter_columns = manager.alter_columns.for_table_class_name( + table.class_name + ) + for alter_column in alter_columns: + for index, column in enumerate(table.columns): + if column._meta.name == alter_column.column_name: + for key, value in alter_column.params.items(): + setattr(column._meta, key, value) + column._meta.params.update({key: value}) + if ( + alter_column.column_class + != alter_column.old_column_class + ) and alter_column.column_class is not None: + new_column = alter_column.column_class( + **column._meta.params + ) + new_column._meta = column._meta + table.columns[index] = new_column + + def _apply_rename_columns( + self, table: DiffableTable, manager: MigrationManager + ): + for rename_column in manager.rename_columns.for_table_class_name( + table.class_name + ): + for column in table.columns: + if column._meta.name == rename_column.old_column_name: + column._meta.name = rename_column.new_column_name + column._meta.db_column_name = ( + rename_column.new_db_column_name + ) From 8f8fc41513959ccd97fe8a2ac60ab8e1738dac53 Mon Sep 17 00:00:00 2001 From: Graziella Lima Date: Tue, 30 Jun 2026 09:40:03 -0300 Subject: [PATCH 12/13] refactor: split serialise_params type handling into handlers (R0912) --- piccolo/apps/migrations/auto/serialisation.py | 450 +++++++++++------- 1 file changed, 268 insertions(+), 182 deletions(-) diff --git a/piccolo/apps/migrations/auto/serialisation.py b/piccolo/apps/migrations/auto/serialisation.py index e3d166353..81aa35157 100644 --- a/piccolo/apps/migrations/auto/serialisation.py +++ b/piccolo/apps/migrations/auto/serialisation.py @@ -522,61 +522,51 @@ def __repr__(self): ############################################################################### -def serialise_params( - params: dict[str, Any], inline_enums: bool = True -) -> SerialisedParams: - """ - When writing column params to a migration file, or outputting to the - playground, we need to serialise some of the values. - - :param inline_enums: - If ``True``, enum value are inlined, for example:: - - value=Enum('MyEnum', {'some_value': 'some_value'})) - - Otherwise, it is reproduced as:: - - value=MyEnum - - And the enum definition is added to - ``SerialisedParams.extra_definitions``. - - """ - params = deepcopy(params) - extra_imports: list[Import] = [] - extra_definitions: list[Definition] = [] +@dataclass +class _SerialisationResult: + value: Any + imports: list[Import] + definitions: list[Definition] - for key, value in params.items(): - # Builtins, such as str, list and dict. - if inspect.getmodule(value) == builtins: - params[key] = SerialisedBuiltin(builtin=value) - continue - - # Column instances - if isinstance(value, Column): - # For target_column (which is used by ForeignKey), we can just - # serialise it as the column name: - if key == "target_column": - params[key] = value._meta.name - continue - ################################################################### +_Handler = Callable[[str, Any, bool], _SerialisationResult | None] - # For Array definitions, we want to serialise the full column - # definition: - column: Column = value - serialised_params: SerialisedParams = serialise_params( - params=column._meta.params +def _handle_builtin( + key: str, value: Any, inline_enums: bool = True +) -> _SerialisationResult | None: + if inspect.getmodule(value) == builtins: + return _SerialisationResult( + value=SerialisedBuiltin(builtin=value), + imports=[], + definitions=[], + ) + return None + + +def _handle_column( + key: str, value: Any, inline_enums: bool = True +) -> _SerialisationResult | None: + if isinstance(value, Column): + if key == "target_column": + return _SerialisationResult( + value=value._meta.name, + imports=[], + definitions=[], ) - # Include the extra imports and definitions required for the - # column params. - extra_imports.extend(serialised_params.extra_imports) - extra_definitions.extend(serialised_params.extra_definitions) + column: Column = value + serialised_params: SerialisedParams = serialise_params( + params=column._meta.params + ) - column_class_name = column.__class__.__name__ - extra_imports.append( + column_class_name = column.__class__.__name__ + return _SerialisationResult( + value=SerialisedColumnInstance( + instance=value, serialised_params=serialised_params + ), + imports=[ + *serialised_params.extra_imports, Import( module=column.__class__.__module__, target=column_class_name, @@ -585,149 +575,185 @@ def serialise_params( f"COLUMN_{column_class_name.upper()}", None, ), - ) - ) - params[key] = SerialisedColumnInstance( - instance=value, serialised_params=serialised_params - ) - continue + ), + ], + definitions=list(serialised_params.extra_definitions), + ) + return None + - # Class instances - if isinstance(value, Default): - params[key] = SerialisedClassInstance(instance=value) - extra_imports.append( +def _handle_default( + key: str, value: Any, inline_enums: bool = True +) -> _SerialisationResult | None: + if isinstance(value, Default): + return _SerialisationResult( + value=SerialisedClassInstance(instance=value), + imports=[ Import( module=value.__class__.__module__, target=value.__class__.__name__, expect_conflict_with_global_name=UniqueGlobalNames.DEFAULT, ) - ) - continue + ], + definitions=[], + ) + return None - # Dates and times - if isinstance( - value, (datetime.time, datetime.datetime, datetime.date) - ): - # Already has a good __repr__. - extra_imports.append( + +def _handle_datetime( + key: str, value: Any, inline_enums: bool = True +) -> _SerialisationResult | None: + if isinstance(value, (datetime.time, datetime.datetime, datetime.date)): + return _SerialisationResult( + value=value, + imports=[ Import( module=value.__class__.__module__, target=value.__class__.__name__, ) - ) - continue + ], + definitions=[], + ) + return None - # UUIDs - if isinstance(value, uuid.UUID): - params[key] = SerialisedUUID(instance=value) - extra_imports.append( + +def _handle_uuid( + key: str, value: Any, inline_enums: bool = True +) -> _SerialisationResult | None: + if isinstance(value, uuid.UUID): + return _SerialisationResult( + value=SerialisedUUID(instance=value), + imports=[ Import( module=UniqueGlobalNames.EXTERNAL_MODULE_UUID, expect_conflict_with_global_name=( UniqueGlobalNames.EXTERNAL_MODULE_UUID ), ) - ) - continue + ], + definitions=[], + ) + return None - # Decimals - if isinstance(value, decimal.Decimal): - params[key] = SerialisedDecimal(instance=value) - extra_imports.append( + +def _handle_decimal( + key: str, value: Any, inline_enums: bool = True +) -> _SerialisationResult | None: + if isinstance(value, decimal.Decimal): + return _SerialisationResult( + value=SerialisedDecimal(instance=value), + imports=[ Import( module=UniqueGlobalNames.STD_LIB_MODULE_DECIMAL, expect_conflict_with_global_name=( UniqueGlobalNames.STD_LIB_MODULE_DECIMAL ), ) - ) - continue - - # Enum instances - if isinstance(value, Enum): - if value.__module__.startswith("piccolo"): - # It's an Enum defined within Piccolo, so we can safely import - # it. - params[key] = SerialisedEnumInstance(instance=value) - extra_imports.append( + ], + definitions=[], + ) + return None + + +def _handle_enum_instance( + key: str, value: Any, inline_enums: bool = True +) -> _SerialisationResult | None: + if isinstance(value, Enum): + if value.__module__.startswith("piccolo"): + return _SerialisationResult( + value=SerialisedEnumInstance(instance=value), + imports=[ Import( module=value.__module__, target=value.__class__.__name__, ) + ], + definitions=[], + ) + else: + enum_serialised_params: SerialisedParams = serialise_params( + params={key: value.value} + ) + return _SerialisationResult( + value=enum_serialised_params.params[key], + imports=list(enum_serialised_params.extra_imports), + definitions=list(enum_serialised_params.extra_definitions), + ) + return None + + +def _handle_enum_type( + key: str, value: Any, inline_enums: bool = True +) -> _SerialisationResult | None: + if inspect.isclass(value) and issubclass(value, Enum): + extra_imports: list[Import] = [ + Import( + module="enum", + target=UniqueGlobalNames.STD_LIB_ENUM, + expect_conflict_with_global_name=( + UniqueGlobalNames.STD_LIB_ENUM + ), + ) + ] + for member in value: + type_ = type(member.value) + module = inspect.getmodule(type_) + if module and module != builtins: + module_name = module.__name__ + extra_imports.append( + Import(module=module_name, target=type_.__name__) ) - else: - # It's a user defined Enum, so we'll insert the raw value. - enum_serialised_params: SerialisedParams = serialise_params( - params={key: value.value} - ) - params[key] = enum_serialised_params.params[key] - extra_imports.extend(enum_serialised_params.extra_imports) - extra_definitions.extend( - enum_serialised_params.extra_definitions - ) - - continue - # Enum types - if inspect.isclass(value) and issubclass(value, Enum): - extra_imports.append( - Import( - module="enum", - target=UniqueGlobalNames.STD_LIB_ENUM, - expect_conflict_with_global_name=( - UniqueGlobalNames.STD_LIB_ENUM - ), - ) + if inline_enums: + return _SerialisationResult( + value=InlineSerialisedEnumType(enum_type=value), + imports=extra_imports, + definitions=[], ) - for member in value: - type_ = type(member.value) - module = inspect.getmodule(type_) - - if module and module != builtins: - module_name = module.__name__ - extra_imports.append( - Import(module=module_name, target=type_.__name__) - ) - - if inline_enums: - params[key] = InlineSerialisedEnumType(enum_type=value) - else: - params[key] = SerialisedReference(name=value.__name__) - extra_definitions.append( + else: + return _SerialisationResult( + value=SerialisedReference(name=value.__name__), + imports=extra_imports, + definitions=[ SerialisedEnumTypeDefinition(enum_type=value) - ) + ], + ) + return None - # Functions - if inspect.isfunction(value): - if value.__name__ == "": - raise ValueError("Lambdas can't be serialised") - params[key] = SerialisedCallable(callable_=value) - extra_imports.append( - Import(module=value.__module__, target=value.__name__) - ) - continue - - # Lazy imports - we need to resolve these now, in case the target - # table class gets deleted in the future. - if isinstance(value, LazyTableReference): - table_type = value.resolve() - params[key] = SerialisedCallable(callable_=table_type) - extra_definitions.append( - SerialisedTableType(table_type=table_type) - ) - extra_imports.append( +def _handle_function( + key: str, value: Any, inline_enums: bool = True +) -> _SerialisationResult | None: + if inspect.isfunction(value): + if value.__name__ == "": + raise ValueError("Lambdas can't be serialised") + + return _SerialisationResult( + value=SerialisedCallable(callable_=value), + imports=[ + Import( + module=value.__module__, target=value.__name__ + ) + ], + definitions=[], + ) + return None + + +def _handle_lazy_table_ref( + key: str, value: Any, inline_enums: bool = True +) -> _SerialisationResult | None: + if isinstance(value, LazyTableReference): + table_type = value.resolve() + primary_key_class = table_type._meta.primary_key.__class__ + return _SerialisationResult( + value=SerialisedCallable(callable_=table_type), + imports=[ Import( module=Table.__module__, target=UniqueGlobalNames.TABLE, expect_conflict_with_global_name=UniqueGlobalNames.TABLE, - ) - ) - # also add missing primary key to extra_imports when creating a - # migration with a ForeignKey that uses a LazyTableReference - # https://github.com/piccolo-orm/piccolo/issues/865 - primary_key_class = table_type._meta.primary_key.__class__ - extra_imports.append( + ), Import( module=primary_key_class.__module__, target=primary_key_class.__name__, @@ -736,24 +762,31 @@ def serialise_params( f"COLUMN_{primary_key_class.__name__.upper()}", None, ), - ) - ) - continue + ), + ], + definitions=[ + SerialisedTableType(table_type=table_type) + ], + ) + return None + - # Replace any Table class values into class and table names - if inspect.isclass(value) and issubclass(value, Table): - params[key] = SerialisedCallable(callable_=value) - extra_definitions.append(SerialisedTableType(table_type=value)) - extra_imports.append( +def _handle_table_class( + key: str, value: Any, inline_enums: bool = True +) -> _SerialisationResult | None: + if inspect.isclass(value) and issubclass(value, Table): + primary_key_class = value._meta.primary_key.__class__ + pk_serialised_params: SerialisedParams = serialise_params( + params=value._meta.primary_key._meta.params + ) + return _SerialisationResult( + value=SerialisedCallable(callable_=value), + imports=[ Import( module=Table.__module__, target=UniqueGlobalNames.TABLE, expect_conflict_with_global_name=UniqueGlobalNames.TABLE, - ) - ) - - primary_key_class = value._meta.primary_key.__class__ - extra_imports.append( + ), Import( module=primary_key_class.__module__, target=primary_key_class.__name__, @@ -762,28 +795,81 @@ def serialise_params( f"COLUMN_{primary_key_class.__name__.upper()}", None, ), + ), + *pk_serialised_params.extra_imports, + ], + definitions=[ + SerialisedTableType(table_type=value), + *pk_serialised_params.extra_definitions, + ], + ) + return None + + +def _handle_plain_class( + key: str, value: Any, inline_enums: bool = True +) -> _SerialisationResult | None: + if inspect.isclass(value) and not issubclass(value, Enum): + return _SerialisationResult( + value=SerialisedCallable(callable_=value), + imports=[ + Import( + module=value.__module__, target=value.__name__ ) - ) - # Include the extra imports and definitions required for the - # primary column params. - pk_serialised_params: SerialisedParams = serialise_params( - params=value._meta.primary_key._meta.params - ) - extra_imports.extend(pk_serialised_params.extra_imports) - extra_definitions.extend(pk_serialised_params.extra_definitions) + ], + definitions=[], + ) + return None - continue - # Plain class type - if inspect.isclass(value) and not issubclass(value, Enum): - params[key] = SerialisedCallable(callable_=value) - extra_imports.append( - Import(module=value.__module__, target=value.__name__) - ) - continue +_HANDLERS: list[_Handler] = [ + _handle_builtin, + _handle_column, + _handle_default, + _handle_datetime, + _handle_uuid, + _handle_decimal, + _handle_enum_instance, + _handle_enum_type, + _handle_function, + _handle_lazy_table_ref, + _handle_table_class, + _handle_plain_class, +] + + +def serialise_params( + params: dict[str, Any], inline_enums: bool = True +) -> SerialisedParams: + """ + When writing column params to a migration file, or outputting to the + playground, we need to serialise some of the values. + + :param inline_enums: + If ``True``, enum value are inlined, for example:: + + value=Enum('MyEnum', {'some_value': 'some_value'})) + + Otherwise, it is reproduced as:: - # All other types can remain as is. + value=MyEnum + And the enum definition is added to + ``SerialisedParams.extra_definitions``. + + """ + params = deepcopy(params) + extra_imports: list[Import] = [] + extra_definitions: list[Definition] = [] + + for key, value in params.items(): + for handler in _HANDLERS: + result = handler(key, value, inline_enums=inline_enums) + if result is not None: + params[key] = result.value + extra_imports.extend(result.imports) + extra_definitions.extend(result.definitions) + break unique_extra_imports = list(set(extra_imports)) UniqueGlobalNames.warn_if_are_conflicting_objects(unique_extra_imports) From f0f2d12cab6c4ed9ae22f1ecdef759bc3e7390dd Mon Sep 17 00:00:00 2001 From: Graziella Lima Date: Tue, 30 Jun 2026 09:48:52 -0300 Subject: [PATCH 13/13] refactor: split MigrationManager._run_alter_columns operations into helpers (R0912) --- .../apps/migrations/auto/migration_manager.py | 449 ++++++++++-------- 1 file changed, 254 insertions(+), 195 deletions(-) diff --git a/piccolo/apps/migrations/auto/migration_manager.py b/piccolo/apps/migrations/auto/migration_manager.py index 57b97fa2d..3744e51f1 100644 --- a/piccolo/apps/migrations/auto/migration_manager.py +++ b/piccolo/apps/migrations/auto/migration_manager.py @@ -426,6 +426,26 @@ async def _run_query(self, query: Union[DDL, Query, SchemaDDLBase]): else: await query.run() + ########################################################################### + + @staticmethod + def _get_alter_params( + alter_column: AlterColumn, backwards: bool + ) -> tuple[dict[str, Any], dict[str, Any]]: + if backwards: + return alter_column.old_params, alter_column.params + return alter_column.params, alter_column.old_params + + @staticmethod + def _get_alter_classes( + alter_column: AlterColumn, backwards: bool + ) -> tuple[type[Column], type[Column]]: + if backwards: + return alter_column.old_column_class, alter_column.column_class + return alter_column.column_class, alter_column.old_column_class + + ########################################################################### + async def _run_alter_columns(self, backwards: bool = False): for table_class_name in self.alter_columns.table_class_names: alter_columns = self.alter_columns.for_table_class_name( @@ -444,225 +464,264 @@ async def _run_alter_columns(self, backwards: bool = False): ) for alter_column in alter_columns: - params = ( - alter_column.old_params - if backwards - else alter_column.params + params, old_params = self._get_alter_params( + alter_column, backwards ) - - old_params = ( - alter_column.params - if backwards - else alter_column.old_params + column_class, old_column_class = self._get_alter_classes( + alter_column, backwards ) - ############################################################### - - # Change the column type if possible - column_class = ( - alter_column.old_column_class - if backwards - else alter_column.column_class + await self._run_alter_column_type( + alter_column=alter_column, + params=params, + old_params=old_params, + column_class=column_class, + old_column_class=old_column_class, + _Table=_Table, ) - old_column_class = ( - alter_column.column_class - if backwards - else alter_column.old_column_class + await self._run_alter_fk_constraints( + alter_column=alter_column, + params=params, + table_class_name=table_class_name, + _Table=_Table, + ) + await self._run_alter_nullable( + alter_column=alter_column, params=params, _Table=_Table + ) + await self._run_alter_length( + alter_column=alter_column, params=params, _Table=_Table + ) + await self._run_alter_unique( + alter_column=alter_column, params=params, _Table=_Table + ) + await self._run_alter_index( + alter_column=alter_column, params=params, _Table=_Table + ) + await self._run_alter_default( + alter_column=alter_column, params=params, _Table=_Table + ) + await self._run_alter_digits( + alter_column=alter_column, params=params, _Table=_Table ) - if (old_column_class is not None) and ( - column_class is not None - ): - if old_column_class != column_class: - old_column = old_column_class(**old_params) - old_column._meta._table = _Table - old_column._meta._name = alter_column.column_name - old_column._meta.db_column_name = ( - alter_column.db_column_name - ) + ########################################################################### + # Helper methods for _run_alter_columns + ########################################################################### - new_column = column_class(**params) - new_column._meta._table = _Table - new_column._meta._name = alter_column.column_name - new_column._meta.db_column_name = ( - alter_column.db_column_name + async def _run_alter_column_type( + self, + alter_column: AlterColumn, + params: dict[str, Any], + old_params: dict[str, Any], + column_class: type[Column], + old_column_class: type[Column], + _Table: type[Table], + ): + if (old_column_class is not None) and (column_class is not None): + if old_column_class != column_class: + old_column = old_column_class(**old_params) + old_column._meta._table = _Table + old_column._meta._name = alter_column.column_name + old_column._meta.db_column_name = alter_column.db_column_name + + new_column = column_class(**params) + new_column._meta._table = _Table + new_column._meta._name = alter_column.column_name + new_column._meta.db_column_name = alter_column.db_column_name + + using_expression: Optional[str] = None + + if new_column.value_type != old_column.value_type: + if old_params.get("default", ...) is not None: + await self._run_query( + _Table.alter().drop_default(old_column) ) - using_expression: Optional[str] = None - - # Postgres won't automatically cast some types to - # others. We may as well try, as it will definitely - # fail otherwise. - if new_column.value_type != old_column.value_type: - if old_params.get("default", ...) is not None: - # Unless the column's default value is also - # something which can be cast to the new type, - # it will also fail. Drop the default value for - # now - the proper default is set later on. - await self._run_query( - _Table.alter().drop_default(old_column) - ) - - using_expression = "{}::{}".format( - alter_column.db_column_name, - new_column.column_type, - ) - - # We can't migrate a SERIAL to a BIGSERIAL or vice - # versa, as SERIAL isn't a true type, just an alias to - # other commands. - if issubclass(column_class, Serial) and issubclass( - old_column_class, Serial - ): - colored_warning( - "Unable to migrate Serial to BigSerial and " - "vice versa. This must be done manually." - ) - else: - await self._run_query( - _Table.alter().set_column_type( - old_column=old_column, - new_column=new_column, - using_expression=using_expression, - ) - ) - - ############################################################### - - on_delete = params.get("on_delete") - on_update = params.get("on_update") - if on_delete is not None or on_update is not None: - existing_table = await self.get_table_from_snapshot( - table_class_name=table_class_name, - app_name=self.app_name, + using_expression = "{}::{}".format( + alter_column.db_column_name, + new_column.column_type, ) - fk_column = existing_table._meta.get_column_by_name( - alter_column.column_name - ) - - assert isinstance(fk_column, ForeignKey) - - # First drop the existing foreign key constraint - constraint_name = await get_fk_constraint_name( - column=fk_column + if issubclass(column_class, Serial) and issubclass( + old_column_class, Serial + ): + colored_warning( + "Unable to migrate Serial to BigSerial and " + "vice versa. This must be done manually." ) - if constraint_name: - await self._run_query( - _Table.alter().drop_constraint( - constraint_name=constraint_name - ) - ) - - # Then add a new foreign key constraint + else: await self._run_query( - _Table.alter().add_foreign_key_constraint( - column=fk_column, - on_delete=on_delete, - on_update=on_update, + _Table.alter().set_column_type( + old_column=old_column, + new_column=new_column, + using_expression=using_expression, ) ) - null = params.get("null") - if null is not None: - await self._run_query( - _Table.alter().set_null( - column=alter_column.db_column_name, boolean=null - ) - ) + async def _run_alter_fk_constraints( + self, + alter_column: AlterColumn, + params: dict[str, Any], + table_class_name: str, + _Table: type[Table], + ): + on_delete = params.get("on_delete") + on_update = params.get("on_update") + if on_delete is not None or on_update is not None: + existing_table = await self.get_table_from_snapshot( + table_class_name=table_class_name, + app_name=self.app_name, + ) - length = params.get("length") - if length is not None: - await self._run_query( - _Table.alter().set_length( - column=alter_column.db_column_name, length=length - ) - ) + fk_column = existing_table._meta.get_column_by_name( + alter_column.column_name + ) - unique = params.get("unique") - if unique is not None: - # When modifying unique constraints, we need to pass in - # a column type, and not just the column name. - column = Column() - column._meta._table = _Table - column._meta._name = alter_column.column_name - column._meta.db_column_name = alter_column.db_column_name - await self._run_query( - _Table.alter().set_unique( - column=column, boolean=unique - ) + assert isinstance(fk_column, ForeignKey) + + constraint_name = await get_fk_constraint_name( + column=fk_column + ) + if constraint_name: + await self._run_query( + _Table.alter().drop_constraint( + constraint_name=constraint_name ) + ) - index = params.get("index") - index_method = params.get("index_method") - if index is None: - if index_method is not None: - # If the index value hasn't changed, but the - # index_method value has, this indicates we need - # to change the index type. - column = Column() - column._meta._table = _Table - column._meta._name = alter_column.column_name - column._meta.db_column_name = ( - alter_column.db_column_name - ) - await self._run_query(_Table.drop_index([column])) - await self._run_query( - _Table.create_index( - [column], - method=index_method, - if_not_exists=True, - ) - ) - else: - # If the index value has changed, then we are either - # dropping, or creating an index. - column = Column() - column._meta._table = _Table - column._meta._name = alter_column.column_name - column._meta.db_column_name = alter_column.db_column_name - - if index is True: - kwargs = ( - {"method": index_method} if index_method else {} - ) - await self._run_query( - _Table.create_index( - [column], if_not_exists=True, **kwargs - ) - ) - else: - await self._run_query(_Table.drop_index([column])) + await self._run_query( + _Table.alter().add_foreign_key_constraint( + column=fk_column, + on_delete=on_delete, + on_update=on_update, + ) + ) - # None is a valid value, so retrieve ellipsis if not found. - default = params.get("default", ...) - if default is not ...: - column = Column() - column._meta._table = _Table - column._meta._name = alter_column.column_name - column._meta.db_column_name = alter_column.db_column_name + async def _run_alter_nullable( + self, + alter_column: AlterColumn, + params: dict[str, Any], + _Table: type[Table], + ): + null = params.get("null") + if null is not None: + await self._run_query( + _Table.alter().set_null( + column=alter_column.db_column_name, boolean=null + ) + ) - if default is None: - await self._run_query( - _Table.alter().drop_default(column=column) - ) - else: - column.default = default - await self._run_query( - _Table.alter().set_default( - column=column, value=column.get_default_value() - ) - ) + async def _run_alter_length( + self, + alter_column: AlterColumn, + params: dict[str, Any], + _Table: type[Table], + ): + length = params.get("length") + if length is not None: + await self._run_query( + _Table.alter().set_length( + column=alter_column.db_column_name, length=length + ) + ) - # None is a valid value, so retrieve ellipsis if not found. - digits = params.get("digits", ...) - if digits is not ...: - await self._run_query( - _Table.alter().set_digits( - column=alter_column.db_column_name, - digits=digits, - ) + async def _run_alter_unique( + self, + alter_column: AlterColumn, + params: dict[str, Any], + _Table: type[Table], + ): + unique = params.get("unique") + if unique is not None: + column = Column() + column._meta._table = _Table + column._meta._name = alter_column.column_name + column._meta.db_column_name = alter_column.db_column_name + await self._run_query( + _Table.alter().set_unique( + column=column, boolean=unique + ) + ) + + async def _run_alter_index( + self, + alter_column: AlterColumn, + params: dict[str, Any], + _Table: type[Table], + ): + index = params.get("index") + index_method = params.get("index_method") + if index is None: + if index_method is not None: + column = Column() + column._meta._table = _Table + column._meta._name = alter_column.column_name + column._meta.db_column_name = alter_column.db_column_name + await self._run_query(_Table.drop_index([column])) + await self._run_query( + _Table.create_index( + [column], + method=index_method, + if_not_exists=True, + ) + ) + else: + column = Column() + column._meta._table = _Table + column._meta._name = alter_column.column_name + column._meta.db_column_name = alter_column.db_column_name + + if index is True: + kwargs = ( + {"method": index_method} if index_method else {} + ) + await self._run_query( + _Table.create_index( + [column], if_not_exists=True, **kwargs + ) + ) + else: + await self._run_query(_Table.drop_index([column])) + + async def _run_alter_default( + self, + alter_column: AlterColumn, + params: dict[str, Any], + _Table: type[Table], + ): + default = params.get("default", ...) + if default is not ...: + column = Column() + column._meta._table = _Table + column._meta._name = alter_column.column_name + column._meta.db_column_name = alter_column.db_column_name + + if default is None: + await self._run_query( + _Table.alter().drop_default(column=column) + ) + else: + column.default = default + await self._run_query( + _Table.alter().set_default( + column=column, value=column.get_default_value() ) + ) + + async def _run_alter_digits( + self, + alter_column: AlterColumn, + params: dict[str, Any], + _Table: type[Table], + ): + digits = params.get("digits", ...) + if digits is not ...: + await self._run_query( + _Table.alter().set_digits( + column=alter_column.db_column_name, + digits=digits, + ) + ) async def _run_drop_tables(self, backwards=False): for diffable_table in self.drop_tables: