diff --git a/piccolo/apps/migrations/auto/migration_manager.py b/piccolo/apps/migrations/auto/migration_manager.py index 57b97fa2d..2aca667cd 100644 --- a/piccolo/apps/migrations/auto/migration_manager.py +++ b/piccolo/apps/migrations/auto/migration_manager.py @@ -456,213 +456,238 @@ async def _run_alter_columns(self, backwards: bool = False): else alter_column.old_params ) - ############################################################### - - # Change the column type if possible - column_class = ( - alter_column.old_column_class - if backwards - else alter_column.column_class + await self._run_alter_column_type( + _Table, alter_column, backwards, params, old_params, ) - old_column_class = ( - alter_column.column_class - if backwards - else alter_column.old_column_class + await self._run_alter_foreign_key( + _Table, alter_column, params, table_class_name ) + await self._run_alter_null(_Table, alter_column, params) + await self._run_alter_length(_Table, alter_column, params) + await self._run_alter_unique(_Table, alter_column, params) + await self._run_alter_index(_Table, alter_column, params) + await self._run_alter_default(_Table, alter_column, params) + await self._run_alter_digits(_Table, alter_column, params) - if (old_column_class is not None) and ( - column_class is not None - ): - if old_column_class != column_class: - old_column = old_column_class(**old_params) - old_column._meta._table = _Table - old_column._meta._name = alter_column.column_name - old_column._meta.db_column_name = ( - alter_column.db_column_name - ) - - new_column = column_class(**params) - new_column._meta._table = _Table - new_column._meta._name = alter_column.column_name - new_column._meta.db_column_name = ( - alter_column.db_column_name - ) - - using_expression: Optional[str] = None - - # Postgres won't automatically cast some types to - # others. We may as well try, as it will definitely - # fail otherwise. - if new_column.value_type != old_column.value_type: - if old_params.get("default", ...) is not None: - # Unless the column's default value is also - # something which can be cast to the new type, - # it will also fail. Drop the default value for - # now - the proper default is set later on. - await self._run_query( - _Table.alter().drop_default(old_column) - ) - - using_expression = "{}::{}".format( - alter_column.db_column_name, - new_column.column_type, - ) - - # We can't migrate a SERIAL to a BIGSERIAL or vice - # versa, as SERIAL isn't a true type, just an alias to - # other commands. - if issubclass(column_class, Serial) and issubclass( - old_column_class, Serial - ): - colored_warning( - "Unable to migrate Serial to BigSerial and " - "vice versa. This must be done manually." - ) - else: - await self._run_query( - _Table.alter().set_column_type( - old_column=old_column, - new_column=new_column, - using_expression=using_expression, - ) - ) + ########################################################################### - ############################################################### + @staticmethod + def _build_meta_column( + _Table: type[Table], + alter_column: AlterColumn, + ) -> Column: + column = Column() + column._meta._table = _Table + column._meta._name = alter_column.column_name + column._meta.db_column_name = alter_column.db_column_name + return column + + async def _run_alter_column_type( + self, + _Table: type[Table], + alter_column: AlterColumn, + backwards: bool, + params: dict, + old_params: dict, + ): + column_class = ( + alter_column.old_column_class + if backwards + else alter_column.column_class + ) + old_column_class = ( + alter_column.column_class + if backwards + else alter_column.old_column_class + ) - on_delete = params.get("on_delete") - on_update = params.get("on_update") - if on_delete is not None or on_update is not None: - existing_table = await self.get_table_from_snapshot( - table_class_name=table_class_name, - app_name=self.app_name, - ) + if (old_column_class is not None) and (column_class is not None): + if old_column_class != column_class: + old_column = old_column_class(**old_params) + old_column._meta._table = _Table + old_column._meta._name = alter_column.column_name + old_column._meta.db_column_name = alter_column.db_column_name - fk_column = existing_table._meta.get_column_by_name( - alter_column.column_name - ) + new_column = column_class(**params) + new_column._meta._table = _Table + new_column._meta._name = alter_column.column_name + new_column._meta.db_column_name = alter_column.db_column_name - assert isinstance(fk_column, ForeignKey) + using_expression: Optional[str] = None - # First drop the existing foreign key constraint - constraint_name = await get_fk_constraint_name( - column=fk_column - ) - if constraint_name: + if new_column.value_type != old_column.value_type: + if old_params.get("default", ...) is not None: await self._run_query( - _Table.alter().drop_constraint( - constraint_name=constraint_name - ) + _Table.alter().drop_default(old_column) ) - # Then add a new foreign key constraint - await self._run_query( - _Table.alter().add_foreign_key_constraint( - column=fk_column, - on_delete=on_delete, - on_update=on_update, - ) + using_expression = "{}::{}".format( + alter_column.db_column_name, + new_column.column_type, ) - null = params.get("null") - if null is not None: - await self._run_query( - _Table.alter().set_null( - column=alter_column.db_column_name, boolean=null - ) + if issubclass(column_class, Serial) and issubclass( + old_column_class, Serial + ): + colored_warning( + "Unable to migrate Serial to BigSerial and " + "vice versa. This must be done manually." ) - - length = params.get("length") - if length is not None: + else: await self._run_query( - _Table.alter().set_length( - column=alter_column.db_column_name, length=length + _Table.alter().set_column_type( + old_column=old_column, + new_column=new_column, + using_expression=using_expression, ) ) - unique = params.get("unique") - if unique is not None: - # When modifying unique constraints, we need to pass in - # a column type, and not just the column name. - column = Column() - column._meta._table = _Table - column._meta._name = alter_column.column_name - column._meta.db_column_name = alter_column.db_column_name - await self._run_query( - _Table.alter().set_unique( - column=column, boolean=unique - ) + async def _run_alter_foreign_key( + self, + _Table: type[Table], + alter_column: AlterColumn, + params: dict, + table_class_name: str, + ): + on_delete = params.get("on_delete") + on_update = params.get("on_update") + if on_delete is not None or on_update is not None: + existing_table = await self.get_table_from_snapshot( + table_class_name=table_class_name, + app_name=self.app_name, + ) + + fk_column = existing_table._meta.get_column_by_name( + alter_column.column_name + ) + + assert isinstance(fk_column, ForeignKey) + + constraint_name = await get_fk_constraint_name(column=fk_column) + if constraint_name: + await self._run_query( + _Table.alter().drop_constraint( + constraint_name=constraint_name ) + ) - index = params.get("index") - index_method = params.get("index_method") - if index is None: - if index_method is not None: - # If the index value hasn't changed, but the - # index_method value has, this indicates we need - # to change the index type. - column = Column() - column._meta._table = _Table - column._meta._name = alter_column.column_name - column._meta.db_column_name = ( - alter_column.db_column_name - ) - await self._run_query(_Table.drop_index([column])) - await self._run_query( - _Table.create_index( - [column], - method=index_method, - if_not_exists=True, - ) - ) - else: - # If the index value has changed, then we are either - # dropping, or creating an index. - column = Column() - column._meta._table = _Table - column._meta._name = alter_column.column_name - column._meta.db_column_name = alter_column.db_column_name - - if index is True: - kwargs = ( - {"method": index_method} if index_method else {} - ) - await self._run_query( - _Table.create_index( - [column], if_not_exists=True, **kwargs - ) - ) - else: - await self._run_query(_Table.drop_index([column])) + await self._run_query( + _Table.alter().add_foreign_key_constraint( + column=fk_column, + on_delete=on_delete, + on_update=on_update, + ) + ) - # None is a valid value, so retrieve ellipsis if not found. - default = params.get("default", ...) - if default is not ...: - column = Column() - column._meta._table = _Table - column._meta._name = alter_column.column_name - column._meta.db_column_name = alter_column.db_column_name + async def _run_alter_null( + self, + _Table: type[Table], + alter_column: AlterColumn, + params: dict, + ): + null = params.get("null") + if null is not None: + await self._run_query( + _Table.alter().set_null( + column=alter_column.db_column_name, boolean=null + ) + ) - if default is None: - await self._run_query( - _Table.alter().drop_default(column=column) - ) - else: - column.default = default - await self._run_query( - _Table.alter().set_default( - column=column, value=column.get_default_value() - ) - ) + async def _run_alter_length( + self, + _Table: type[Table], + alter_column: AlterColumn, + params: dict, + ): + length = params.get("length") + if length is not None: + await self._run_query( + _Table.alter().set_length( + column=alter_column.db_column_name, length=length + ) + ) - # None is a valid value, so retrieve ellipsis if not found. - digits = params.get("digits", ...) - if digits is not ...: - await self._run_query( - _Table.alter().set_digits( - column=alter_column.db_column_name, - digits=digits, - ) + async def _run_alter_unique( + self, + _Table: type[Table], + alter_column: AlterColumn, + params: dict, + ): + unique = params.get("unique") + if unique is not None: + column = self._build_meta_column(_Table, alter_column) + await self._run_query( + _Table.alter().set_unique(column=column, boolean=unique) + ) + + async def _run_alter_index( + self, + _Table: type[Table], + alter_column: AlterColumn, + params: dict, + ): + index = params.get("index") + index_method = params.get("index_method") + if index is None: + if index_method is not None: + column = self._build_meta_column(_Table, alter_column) + await self._run_query(_Table.drop_index([column])) + await self._run_query( + _Table.create_index( + [column], + method=index_method, + if_not_exists=True, + ) + ) + else: + column = self._build_meta_column(_Table, alter_column) + if index is True: + kwargs = ( + {"method": index_method} if index_method else {} + ) + await self._run_query( + _Table.create_index( + [column], if_not_exists=True, **kwargs ) + ) + else: + await self._run_query(_Table.drop_index([column])) + + async def _run_alter_default( + self, + _Table: type[Table], + alter_column: AlterColumn, + params: dict, + ): + default = params.get("default", ...) + if default is not ...: + column = self._build_meta_column(_Table, alter_column) + if default is None: + await self._run_query( + _Table.alter().drop_default(column=column) + ) + else: + column.default = default + await self._run_query( + _Table.alter().set_default( + column=column, value=column.get_default_value() + ) + ) + + async def _run_alter_digits( + self, + _Table: type[Table], + alter_column: AlterColumn, + params: dict, + ): + digits = params.get("digits", ...) + if digits is not ...: + await self._run_query( + _Table.alter().set_digits( + column=alter_column.db_column_name, + digits=digits, + ) + ) async def _run_drop_tables(self, backwards=False): for diffable_table in self.drop_tables: diff --git a/piccolo/apps/migrations/auto/serialisation.py b/piccolo/apps/migrations/auto/serialisation.py index e3d166353..585cc3051 100644 --- a/piccolo/apps/migrations/auto/serialisation.py +++ b/piccolo/apps/migrations/auto/serialisation.py @@ -522,267 +522,286 @@ def __repr__(self): ############################################################################### -def serialise_params( - params: dict[str, Any], inline_enums: bool = True -) -> SerialisedParams: - """ - When writing column params to a migration file, or outputting to the - playground, we need to serialise some of the values. - - :param inline_enums: - If ``True``, enum value are inlined, for example:: - - value=Enum('MyEnum', {'some_value': 'some_value'})) - - Otherwise, it is reproduced as:: - - value=MyEnum - - And the enum definition is added to - ``SerialisedParams.extra_definitions``. - - """ - params = deepcopy(params) - extra_imports: list[Import] = [] - extra_definitions: list[Definition] = [] - - for key, value in params.items(): - # Builtins, such as str, list and dict. - if inspect.getmodule(value) == builtins: - params[key] = SerialisedBuiltin(builtin=value) - continue - - # Column instances - if isinstance(value, Column): - # For target_column (which is used by ForeignKey), we can just - # serialise it as the column name: - if key == "target_column": - params[key] = value._meta.name - continue - - ################################################################### - - # For Array definitions, we want to serialise the full column - # definition: - - column: Column = value - serialised_params: SerialisedParams = serialise_params( - params=column._meta.params +def _handle_builtin(key, value, params, extra_imports, extra_definitions, inline_enums): + if inspect.getmodule(value) == builtins: + params[key] = SerialisedBuiltin(builtin=value) + return True + return False + + +def _handle_column(key, value, params, extra_imports, extra_definitions, inline_enums): + if isinstance(value, Column): + if key == "target_column": + params[key] = value._meta.name + return True + + column: Column = value + serialised_params: SerialisedParams = serialise_params( + params=column._meta.params + ) + extra_imports.extend(serialised_params.extra_imports) + extra_definitions.extend(serialised_params.extra_definitions) + + column_class_name = column.__class__.__name__ + extra_imports.append( + Import( + module=column.__class__.__module__, + target=column_class_name, + expect_conflict_with_global_name=getattr( + UniqueGlobalNames, + f"COLUMN_{column_class_name.upper()}", + None, + ), + ) + ) + params[key] = SerialisedColumnInstance( + instance=value, serialised_params=serialised_params + ) + return True + return False + + +def _handle_default(key, value, params, extra_imports, extra_definitions, inline_enums): + if isinstance(value, Default): + params[key] = SerialisedClassInstance(instance=value) + extra_imports.append( + Import( + module=value.__class__.__module__, + target=value.__class__.__name__, + expect_conflict_with_global_name=UniqueGlobalNames.DEFAULT, ) + ) + return True + return False - # Include the extra imports and definitions required for the - # column params. - extra_imports.extend(serialised_params.extra_imports) - extra_definitions.extend(serialised_params.extra_definitions) - column_class_name = column.__class__.__name__ - extra_imports.append( - Import( - module=column.__class__.__module__, - target=column_class_name, - expect_conflict_with_global_name=getattr( - UniqueGlobalNames, - f"COLUMN_{column_class_name.upper()}", - None, - ), - ) +def _handle_datetime(key, value, params, extra_imports, extra_definitions, inline_enums): + if isinstance(value, (datetime.time, datetime.datetime, datetime.date)): + extra_imports.append( + Import( + module=value.__class__.__module__, + target=value.__class__.__name__, ) - params[key] = SerialisedColumnInstance( - instance=value, serialised_params=serialised_params + ) + return True + return False + + +def _handle_uuid(key, value, params, extra_imports, extra_definitions, inline_enums): + if isinstance(value, uuid.UUID): + params[key] = SerialisedUUID(instance=value) + extra_imports.append( + Import( + module=UniqueGlobalNames.EXTERNAL_MODULE_UUID, + expect_conflict_with_global_name=( + UniqueGlobalNames.EXTERNAL_MODULE_UUID + ), ) - continue - - # Class instances - if isinstance(value, Default): - params[key] = SerialisedClassInstance(instance=value) - extra_imports.append( - Import( - module=value.__class__.__module__, - target=value.__class__.__name__, - expect_conflict_with_global_name=UniqueGlobalNames.DEFAULT, - ) + ) + return True + return False + + +def _handle_decimal(key, value, params, extra_imports, extra_definitions, inline_enums): + if isinstance(value, decimal.Decimal): + params[key] = SerialisedDecimal(instance=value) + extra_imports.append( + Import( + module=UniqueGlobalNames.STD_LIB_MODULE_DECIMAL, + expect_conflict_with_global_name=( + UniqueGlobalNames.STD_LIB_MODULE_DECIMAL + ), ) - continue + ) + return True + return False - # Dates and times - if isinstance( - value, (datetime.time, datetime.datetime, datetime.date) - ): - # Already has a good __repr__. + +def _handle_enum_instance(key, value, params, extra_imports, extra_definitions, inline_enums): + if isinstance(value, Enum): + if value.__module__.startswith("piccolo"): + params[key] = SerialisedEnumInstance(instance=value) extra_imports.append( Import( - module=value.__class__.__module__, + module=value.__module__, target=value.__class__.__name__, ) ) - continue - - # UUIDs - if isinstance(value, uuid.UUID): - params[key] = SerialisedUUID(instance=value) - extra_imports.append( - Import( - module=UniqueGlobalNames.EXTERNAL_MODULE_UUID, - expect_conflict_with_global_name=( - UniqueGlobalNames.EXTERNAL_MODULE_UUID - ), - ) + else: + enum_serialised_params: SerialisedParams = serialise_params( + params={key: value.value} ) - continue - - # Decimals - if isinstance(value, decimal.Decimal): - params[key] = SerialisedDecimal(instance=value) - extra_imports.append( - Import( - module=UniqueGlobalNames.STD_LIB_MODULE_DECIMAL, - expect_conflict_with_global_name=( - UniqueGlobalNames.STD_LIB_MODULE_DECIMAL - ), - ) + params[key] = enum_serialised_params.params[key] + extra_imports.extend(enum_serialised_params.extra_imports) + extra_definitions.extend( + enum_serialised_params.extra_definitions + ) + return True + return False + + +def _handle_enum_type(key, value, params, extra_imports, extra_definitions, inline_enums): + if inspect.isclass(value) and issubclass(value, Enum): + extra_imports.append( + Import( + module="enum", + target=UniqueGlobalNames.STD_LIB_ENUM, + expect_conflict_with_global_name=( + UniqueGlobalNames.STD_LIB_ENUM + ), ) - continue - - # Enum instances - if isinstance(value, Enum): - if value.__module__.startswith("piccolo"): - # It's an Enum defined within Piccolo, so we can safely import - # it. - params[key] = SerialisedEnumInstance(instance=value) + ) + for member in value: + type_ = type(member.value) + module = inspect.getmodule(type_) + + if module and module != builtins: + module_name = module.__name__ extra_imports.append( - Import( - module=value.__module__, - target=value.__class__.__name__, - ) - ) - else: - # It's a user defined Enum, so we'll insert the raw value. - enum_serialised_params: SerialisedParams = serialise_params( - params={key: value.value} - ) - params[key] = enum_serialised_params.params[key] - extra_imports.extend(enum_serialised_params.extra_imports) - extra_definitions.extend( - enum_serialised_params.extra_definitions + Import(module=module_name, target=type_.__name__) ) - continue - - # Enum types - if inspect.isclass(value) and issubclass(value, Enum): - extra_imports.append( - Import( - module="enum", - target=UniqueGlobalNames.STD_LIB_ENUM, - expect_conflict_with_global_name=( - UniqueGlobalNames.STD_LIB_ENUM - ), - ) + if inline_enums: + params[key] = InlineSerialisedEnumType(enum_type=value) + else: + params[key] = SerialisedReference(name=value.__name__) + extra_definitions.append( + SerialisedEnumTypeDefinition(enum_type=value) ) - for member in value: - type_ = type(member.value) - module = inspect.getmodule(type_) - - if module and module != builtins: - module_name = module.__name__ - extra_imports.append( - Import(module=module_name, target=type_.__name__) - ) + return True + return False - if inline_enums: - params[key] = InlineSerialisedEnumType(enum_type=value) - else: - params[key] = SerialisedReference(name=value.__name__) - extra_definitions.append( - SerialisedEnumTypeDefinition(enum_type=value) - ) - # Functions - if inspect.isfunction(value): - if value.__name__ == "": - raise ValueError("Lambdas can't be serialised") +def _handle_function(key, value, params, extra_imports, extra_definitions, inline_enums): + if inspect.isfunction(value): + if value.__name__ == "": + raise ValueError("Lambdas can't be serialised") - params[key] = SerialisedCallable(callable_=value) - extra_imports.append( - Import(module=value.__module__, target=value.__name__) - ) - continue + params[key] = SerialisedCallable(callable_=value) + extra_imports.append( + Import(module=value.__module__, target=value.__name__) + ) + return True + return False - # Lazy imports - we need to resolve these now, in case the target - # table class gets deleted in the future. - if isinstance(value, LazyTableReference): - table_type = value.resolve() - params[key] = SerialisedCallable(callable_=table_type) - extra_definitions.append( - SerialisedTableType(table_type=table_type) + +def _handle_lazy_table_reference(key, value, params, extra_imports, extra_definitions, inline_enums): + if isinstance(value, LazyTableReference): + table_type = value.resolve() + params[key] = SerialisedCallable(callable_=table_type) + extra_definitions.append( + SerialisedTableType(table_type=table_type) + ) + extra_imports.append( + Import( + module=Table.__module__, + target=UniqueGlobalNames.TABLE, + expect_conflict_with_global_name=UniqueGlobalNames.TABLE, ) - extra_imports.append( - Import( - module=Table.__module__, - target=UniqueGlobalNames.TABLE, - expect_conflict_with_global_name=UniqueGlobalNames.TABLE, - ) + ) + primary_key_class = table_type._meta.primary_key.__class__ + extra_imports.append( + Import( + module=primary_key_class.__module__, + target=primary_key_class.__name__, + expect_conflict_with_global_name=getattr( + UniqueGlobalNames, + f"COLUMN_{primary_key_class.__name__.upper()}", + None, + ), ) - # also add missing primary key to extra_imports when creating a - # migration with a ForeignKey that uses a LazyTableReference - # https://github.com/piccolo-orm/piccolo/issues/865 - primary_key_class = table_type._meta.primary_key.__class__ - extra_imports.append( - Import( - module=primary_key_class.__module__, - target=primary_key_class.__name__, - expect_conflict_with_global_name=getattr( - UniqueGlobalNames, - f"COLUMN_{primary_key_class.__name__.upper()}", - None, - ), - ) + ) + return True + return False + + +def _handle_table(key, value, params, extra_imports, extra_definitions, inline_enums): + if inspect.isclass(value) and issubclass(value, Table): + params[key] = SerialisedCallable(callable_=value) + extra_definitions.append(SerialisedTableType(table_type=value)) + extra_imports.append( + Import( + module=Table.__module__, + target=UniqueGlobalNames.TABLE, + expect_conflict_with_global_name=UniqueGlobalNames.TABLE, ) - continue + ) - # Replace any Table class values into class and table names - if inspect.isclass(value) and issubclass(value, Table): - params[key] = SerialisedCallable(callable_=value) - extra_definitions.append(SerialisedTableType(table_type=value)) - extra_imports.append( - Import( - module=Table.__module__, - target=UniqueGlobalNames.TABLE, - expect_conflict_with_global_name=UniqueGlobalNames.TABLE, - ) + primary_key_class = value._meta.primary_key.__class__ + extra_imports.append( + Import( + module=primary_key_class.__module__, + target=primary_key_class.__name__, + expect_conflict_with_global_name=getattr( + UniqueGlobalNames, + f"COLUMN_{primary_key_class.__name__.upper()}", + None, + ), ) + ) + pk_serialised_params: SerialisedParams = serialise_params( + params=value._meta.primary_key._meta.params + ) + extra_imports.extend(pk_serialised_params.extra_imports) + extra_definitions.extend(pk_serialised_params.extra_definitions) + return True + return False - primary_key_class = value._meta.primary_key.__class__ - extra_imports.append( - Import( - module=primary_key_class.__module__, - target=primary_key_class.__name__, - expect_conflict_with_global_name=getattr( - UniqueGlobalNames, - f"COLUMN_{primary_key_class.__name__.upper()}", - None, - ), - ) - ) - # Include the extra imports and definitions required for the - # primary column params. - pk_serialised_params: SerialisedParams = serialise_params( - params=value._meta.primary_key._meta.params - ) - extra_imports.extend(pk_serialised_params.extra_imports) - extra_definitions.extend(pk_serialised_params.extra_definitions) - continue +def _handle_plain_class(key, value, params, extra_imports, extra_definitions, inline_enums): + if inspect.isclass(value) and not issubclass(value, Enum): + params[key] = SerialisedCallable(callable_=value) + extra_imports.append( + Import(module=value.__module__, target=value.__name__) + ) + return True + return False + + +HANDLERS = [ + _handle_builtin, + _handle_column, + _handle_default, + _handle_datetime, + _handle_uuid, + _handle_decimal, + _handle_enum_instance, + _handle_enum_type, + _handle_function, + _handle_lazy_table_reference, + _handle_table, + _handle_plain_class, +] - # Plain class type - if inspect.isclass(value) and not issubclass(value, Enum): - params[key] = SerialisedCallable(callable_=value) - extra_imports.append( - Import(module=value.__module__, target=value.__name__) - ) - continue - # All other types can remain as is. +def serialise_params( + params: dict[str, Any], inline_enums: bool = True +) -> SerialisedParams: + """ + When writing column params to a migration file, or outputting to the + playground, we need to serialise some of the values. + + :param inline_enums: + If ``True``, enum value are inlined, for example:: + + value=Enum('MyEnum', {'some_value': 'some_value'})) + + Otherwise, it is reproduced as:: + + value=MyEnum + + And the enum definition is added to + ``SerialisedParams.extra_definitions``. + + """ + params = deepcopy(params) + extra_imports: list[Import] = [] + extra_definitions: list[Definition] = [] + + for key, value in params.items(): + for handler in HANDLERS: + if handler(key, value, params, extra_imports, extra_definitions, inline_enums): + break unique_extra_imports = list(set(extra_imports)) UniqueGlobalNames.warn_if_are_conflicting_objects(unique_extra_imports) diff --git a/piccolo/apps/schema/commands/generate.py b/piccolo/apps/schema/commands/generate.py index d4e94ef16..969dfe1d4 100644 --- a/piccolo/apps/schema/commands/generate.py +++ b/piccolo/apps/schema/commands/generate.py @@ -714,80 +714,31 @@ async def create_table_class_from_db( # column_type = BigSerial if constraints.is_foreign_key(column_name=column_name): - fk_constraint_table = constraints.get_foreign_key_constraint_name( - column_name=column_name - ) - column_type = ForeignKey - constraint_table = await get_foreign_key_reference( - table_class=table_class, - constraint_name=fk_constraint_table.name, - constraint_schema=fk_constraint_table.schema, - ) - if constraint_table.name: - referenced_table: Union[str, Optional[type[Table]]] - - if constraint_table.name == tablename: - referenced_output_schema = output_schema - referenced_table = "self" - else: - referenced_output_schema = ( - await create_table_class_from_db( - table_class=table_class, - tablename=constraint_table.name, - schema_name=constraint_table.schema, - engine_type=engine_type, - ) - ) - referenced_table = ( - referenced_output_schema.get_table_with_name( - tablename=constraint_table.name - ) - ) - kwargs["references"] = ( - referenced_table - if referenced_table is not None - else ForeignKeyPlaceholder + column_type, kwargs, output_schema = ( + await _process_foreign_key_column( + table_class=table_class, + column_name=column_name, + tablename=tablename, + engine_type=engine_type, + output_schema=output_schema, + constraints=constraints, + triggers=triggers, + kwargs=kwargs, ) - - trigger = triggers.get_column_ref_trigger( - column_name, constraint_table.name - ) - if trigger: - kwargs["on_update"] = OnUpdate(trigger.on_update) - kwargs["on_delete"] = OnDelete(trigger.on_delete) - else: - output_schema.trigger_warnings.append( - f"{tablename}.{column_name}" - ) - - output_schema = sum( # type: ignore - [output_schema, referenced_output_schema] # type: ignore - ) # type: ignore - else: - kwargs["references"] = ForeignKeyPlaceholder + ) output_schema.imports.append( "from piccolo.columns.column_types import " + column_type.__name__ # type: ignore ) - if column_type is Varchar: - kwargs["length"] = pg_row_meta.character_maximum_length - elif isinstance(column_type, Numeric): - radix = pg_row_meta.numeric_precision_radix - if radix: - precision = int(str(pg_row_meta.numeric_precision), radix) - scale = int(str(pg_row_meta.numeric_scale), radix) - kwargs["digits"] = (precision, scale) - else: - kwargs["digits"] = None - - if column_default: - default_value = get_column_default( - column_type, column_default, engine_type - ) - if default_value: - kwargs["default"] = default_value + kwargs = _process_column_type_kwargs( + column_type=column_type, + pg_row_meta=pg_row_meta, + column_default=column_default, + engine_type=engine_type, + kwargs=kwargs, + ) column = column_type(**kwargs) # type: ignore @@ -806,6 +757,99 @@ async def create_table_class_from_db( return output_schema +async def _process_foreign_key_column( # pylint: disable=too-many-arguments,too-many-positional-arguments + table_class: type[Table], + column_name: str, + tablename: str, + engine_type: str, + output_schema: OutputSchema, + constraints: TableConstraints, + triggers: TableTriggers, + kwargs: dict[str, Any], +) -> tuple[type[Column], dict[str, Any], OutputSchema]: + fk_constraint_table = constraints.get_foreign_key_constraint_name( + column_name=column_name + ) + column_type: type[Column] = ForeignKey + constraint_table = await get_foreign_key_reference( + table_class=table_class, + constraint_name=fk_constraint_table.name, + constraint_schema=fk_constraint_table.schema, + ) + if constraint_table.name: + referenced_table: Union[str, Optional[type[Table]]] + + if constraint_table.name == tablename: + referenced_output_schema = output_schema + referenced_table = "self" + else: + referenced_output_schema = ( + await create_table_class_from_db( + table_class=table_class, + tablename=constraint_table.name, + schema_name=constraint_table.schema, + engine_type=engine_type, + ) + ) + referenced_table = ( + referenced_output_schema.get_table_with_name( + tablename=constraint_table.name + ) + ) + kwargs["references"] = ( + referenced_table + if referenced_table is not None + else ForeignKeyPlaceholder + ) + + trigger = triggers.get_column_ref_trigger( + column_name, constraint_table.name + ) + if trigger: + kwargs["on_update"] = OnUpdate(trigger.on_update) + kwargs["on_delete"] = OnDelete(trigger.on_delete) + else: + output_schema.trigger_warnings.append( + f"{tablename}.{column_name}" + ) + + output_schema = sum( # type: ignore + [output_schema, referenced_output_schema] # type: ignore + ) # type: ignore + else: + kwargs["references"] = ForeignKeyPlaceholder + + return column_type, kwargs, output_schema + + +def _process_column_type_kwargs( + column_type: type[Column], + pg_row_meta: RowMeta, + column_default: Optional[str], + engine_type: str, + kwargs: dict[str, Any], +) -> dict[str, Any]: + if column_type is Varchar: + kwargs["length"] = pg_row_meta.character_maximum_length + elif isinstance(column_type, Numeric): + radix = pg_row_meta.numeric_precision_radix + if radix: + precision = int(str(pg_row_meta.numeric_precision), radix) + scale = int(str(pg_row_meta.numeric_scale), radix) + kwargs["digits"] = (precision, scale) + else: + kwargs["digits"] = None + + if column_default: + default_value = get_column_default( + column_type, column_default, engine_type + ) + if default_value: + kwargs["default"] = default_value + + return kwargs + + async def get_output_schema( schema_name: str = "public", include: Optional[list[str]] = None, diff --git a/piccolo/query/methods/select.py b/piccolo/query/methods/select.py index 4266c3c7d..b82b466a3 100644 --- a/piccolo/query/methods/select.py +++ b/piccolo/query/methods/select.py @@ -568,9 +568,9 @@ def _check_valid_call_chain(self, keys: Sequence[Selectable]) -> bool: ) return True - @property - def default_querystrings(self) -> Sequence[QueryString]: - # JOIN + def _resolve_joined_columns( + self, + ) -> tuple[list[str], list[QueryString], str]: self._check_valid_call_chain(self.columns_delegate.selected_columns) select_joins = self._get_joins(self.columns_delegate.selected_columns) @@ -582,21 +582,15 @@ def default_querystrings(self) -> Sequence[QueryString]: self.order_by_delegate.get_order_by_columns() ) - # Combine all joins, and remove duplicates joins: list[str] = list( OrderedDict.fromkeys( select_joins + where_joins + having_joins + order_by_joins ) ) - ####################################################################### - - # If no columns have been specified for selection, select all columns - # on the table: if len(self.columns_delegate.selected_columns) == 0: self.columns_delegate.selected_columns = self.table._meta.columns - # If secret fields need to be omitted, remove them from the list. if self.exclude_secrets: self.columns_delegate.remove_secret_columns() @@ -607,8 +601,14 @@ def default_querystrings(self) -> Sequence[QueryString]: for c in self.columns_delegate.selected_columns ] - ####################################################################### + return joins, select_strings, engine_type + def _build_querystring( + self, + joins: list[str], + select_strings: list[QueryString], + engine_type: str, + ) -> QueryString: args: list[Any] = [] query = "SELECT" @@ -620,7 +620,9 @@ def default_querystrings(self) -> Sequence[QueryString]: args.append(distinct.querystring) columns_str = ", ".join("{}" for _ in select_strings) - query += f" {columns_str} FROM {self.table._meta.get_formatted_tablename()}" # noqa: E501 + query += ( + f" {columns_str} FROM {self.table._meta.get_formatted_tablename()}" + ) args.extend(select_strings) for join in joins: @@ -674,8 +676,12 @@ def default_querystrings(self) -> Sequence[QueryString]: query += "{}" args.append(self.lock_rows_delegate._lock_rows.querystring) - querystring = QueryString(query, *args) + return QueryString(query, *args) + @property + def default_querystrings(self) -> Sequence[QueryString]: + joins, select_strings, engine_type = self._resolve_joined_columns() + querystring = self._build_querystring(joins, select_strings, engine_type) return [querystring] async def run( diff --git a/piccolo/table.py b/piccolo/table.py index 8ff955d59..d6a3528a7 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -206,6 +206,21 @@ def get_auto_update_values(self) -> dict[Column, Any]: return output +@dataclass +class _ColumnCollection: + columns: list[Column] = field(default_factory=list) + default_columns: list[Column] = field(default_factory=list) + non_default_columns: list[Column] = field(default_factory=list) + array_columns: list[Array] = field(default_factory=list) + foreign_key_columns: list[ForeignKey] = field(default_factory=list) + secret_columns: list[Column] = field(default_factory=list) + json_columns: list[Union[JSON, JSONB]] = field(default_factory=list) + email_columns: list[Email] = field(default_factory=list) + auto_update_columns: list[Column] = field(default_factory=list) + primary_key: Optional[Column] = None + m2m_relationships: list[M2M] = field(default_factory=list) + + class TableMetaclass(type): def __str__(cls) -> str: return cls._table_str() # type: ignore @@ -245,49 +260,8 @@ class Table(metaclass=TableMetaclass): # actual values are set in __init_subclass__. _meta = TableMeta() - def __init_subclass__( - cls, - tablename: Optional[str] = None, - db: Optional[Engine] = None, - tags: Optional[list[str]] = None, - help_text: Optional[str] = None, - schema: Optional[str] = None, - ): # sourcery no-metrics - """ - Automatically populate the _meta, which includes the tablename, and - columns. - - :param tablename: - Specify a custom tablename. By default the classname is converted - to snakecase. - :param db: - Manually specify an engine to use for connecting to the database. - Useful when writing simple scripts. If not set, the engine is - imported from piccolo_conf.py using ``engine_finder``. - :param tags: - Used for filtering, for example by ``table_finder``. - :param help_text: - A user friendly description of what the table is used for. It isn't - used in the database, but will be used by tools such a Piccolo - Admin for tooltips. - :param schema: - The Postgres schema to use for this table. - - """ - if tags is None: - tags = [] - tablename = tablename or _camel_to_snake(cls.__name__) - - if "." in tablename: - warnings.warn( - "There's a '.' in the tablename - please use the `schema` " - "argument instead." - ) - schema, tablename = tablename.split(".", maxsplit=1) - - if tablename in PROTECTED_TABLENAMES: - warnings.warn(TABLENAME_WARNING.format(tablename=tablename)) - + @classmethod + def _collect_columns(cls) -> _ColumnCollection: columns: list[Column] = [] default_columns: list[Column] = [] non_default_columns: list[Column] = [] @@ -311,10 +285,6 @@ def __init_subclass__( attribute = getattr(cls, attribute_name) if isinstance(attribute, Column): - # We have to copy, then override the existing column - # definition, in case this column is inheritted from a mixin. - # Otherwise, when we set attributes on that column, it will - # effect all other users of that mixin. column = attribute.copy() setattr(cls, attribute_name, column) @@ -351,35 +321,92 @@ def __init_subclass__( attribute._meta._table = cls m2m_relationships.append(attribute) - if not primary_key: - primary_key = cls._create_serial_primary_key() - setattr(cls, "id", primary_key) - - columns.insert(0, primary_key) # PK should be added first - default_columns.append(primary_key) - - cls._meta = TableMeta( - tablename=tablename, + return _ColumnCollection( columns=columns, default_columns=default_columns, non_default_columns=non_default_columns, array_columns=array_columns, - email_columns=email_columns, - primary_key=primary_key, foreign_key_columns=foreign_key_columns, - json_columns=json_columns, secret_columns=secret_columns, + json_columns=json_columns, + email_columns=email_columns, auto_update_columns=auto_update_columns, + primary_key=primary_key, + m2m_relationships=m2m_relationships, + ) + + def __init_subclass__( + cls, + tablename: Optional[str] = None, + db: Optional[Engine] = None, + tags: Optional[list[str]] = None, + help_text: Optional[str] = None, + schema: Optional[str] = None, + ): + """ + Automatically populate the _meta, which includes the tablename, and + columns. + + :param tablename: + Specify a custom tablename. By default the classname is converted + to snakecase. + :param db: + Manually specify an engine to use for connecting to the database. + Useful when writing simple scripts. If not set, the engine is + imported from piccolo_conf.py using ``engine_finder``. + :param tags: + Used for filtering, for example by ``table_finder``. + :param help_text: + A user friendly description of what the table is used for. It isn't + used in the database, but will be used by tools such a Piccolo + Admin for tooltips. + :param schema: + The Postgres schema to use for this table. + + """ + if tags is None: + tags = [] + tablename = tablename or _camel_to_snake(cls.__name__) + + if "." in tablename: + warnings.warn( + "There's a '.' in the tablename - please use the `schema` " + "argument instead." + ) + schema, tablename = tablename.split(".", maxsplit=1) + + if tablename in PROTECTED_TABLENAMES: + warnings.warn(TABLENAME_WARNING.format(tablename=tablename)) + + collected = cls._collect_columns() + + if not collected.primary_key: + primary_key = cls._create_serial_primary_key() + setattr(cls, "id", primary_key) + collected.columns.insert(0, primary_key) + collected.default_columns.append(primary_key) + collected.primary_key = primary_key + + cls._meta = TableMeta( + tablename=tablename, + columns=collected.columns, + default_columns=collected.default_columns, + non_default_columns=collected.non_default_columns, + array_columns=collected.array_columns, + email_columns=collected.email_columns, + primary_key=collected.primary_key, + foreign_key_columns=collected.foreign_key_columns, + json_columns=collected.json_columns, + secret_columns=collected.secret_columns, + auto_update_columns=collected.auto_update_columns, tags=tags, help_text=help_text, _db=db, - m2m_relationships=m2m_relationships, + m2m_relationships=collected.m2m_relationships, schema=schema, ) - for foreign_key_column in foreign_key_columns: - # ForeignKey columns require additional setup based on their - # parent Table. + for foreign_key_column in collected.foreign_key_columns: foreign_key_setup_response = foreign_key_column._setup( table_class=cls ) diff --git a/piccolo/utils/pydantic.py b/piccolo/utils/pydantic.py index 794406f17..902a7586d 100644 --- a/piccolo/utils/pydantic.py +++ b/piccolo/utils/pydantic.py @@ -108,6 +108,177 @@ def get_pydantic_value_type(column: Column) -> type: return value_type +def _filter_piccolo_columns( + table: type[Table], + include_default_columns: bool, + include_columns: tuple[Column, ...], + exclude_columns: tuple[Column, ...], +) -> tuple[Column, ...]: + piccolo_columns = tuple( + table._meta.columns + if include_default_columns + else table._meta.non_default_columns + ) + + if include_columns: + include_columns_plus_ancestors = list( + itertools.chain( + include_columns, *[i._meta.call_chain for i in include_columns] + ) + ) + piccolo_columns = tuple( + i + for i in piccolo_columns + if any( + i._equals(include_column) + for include_column in include_columns_plus_ancestors + ) + ) + + if exclude_columns: + piccolo_columns = tuple( + i + for i in piccolo_columns + if not any( + i._equals(exclude_column) for exclude_column in exclude_columns + ) + ) + + return piccolo_columns + + +def _process_column( + column: Column, + model_name: str, + all_optional: bool, + deserialize_json: bool, + nested: Union[bool, tuple[ForeignKey, ...]], + include_readable: bool, + recursion_depth: int, + max_recursion_depth: int, + include_columns: tuple[Column, ...], + exclude_columns: tuple[Column, ...], + include_default_columns: bool, +) -> tuple[list[tuple[str, Any]], dict[str, Callable]]: + column_name = column._meta.name + is_optional = True if all_optional else not column._meta.required + validators: dict[str, Callable] = {} + + if isinstance(column, (JSON, JSONB)): + if deserialize_json: + value_type = pydantic.Json + else: + value_type = column.value_type + validator = partial( + pydantic_json_validator, required=not is_optional + ) + validators[ + f"{column_name}_is_json" + ] = pydantic.field_validator(column_name)( + validator # type: ignore + ) + else: + value_type = get_pydantic_value_type(column=column) + + _type = Optional[value_type] if is_optional else value_type + + params: dict[str, Any] = {} + if is_optional: + params["default"] = None + + if column._meta.db_column_name != column._meta.name: + params["alias"] = column._meta.db_column_name + + extra: JsonDict = { + "help_text": column._meta.help_text, + "choices": column._meta.get_choices_dict(), + "secret": column._meta.secret, + "nullable": column._meta.null, + "unique": column._meta.unique, + } + + if isinstance(column, ForeignKey): + if recursion_depth < max_recursion_depth and ( + (nested is True) + or ( + isinstance(nested, tuple) + and any( + column._equals(i) + for i in itertools.chain( + nested, *[i._meta.call_chain for i in nested] + ) + ) + ) + ): + nested_model_name = f"{model_name}.{column._meta.name}" + _type = create_pydantic_model( + table=column._foreign_key_meta.resolved_references, + nested=nested, + include_columns=include_columns, + exclude_columns=exclude_columns, + include_default_columns=include_default_columns, + include_readable=include_readable, + all_optional=all_optional, + deserialize_json=deserialize_json, + recursion_depth=recursion_depth + 1, + max_recursion_depth=max_recursion_depth, + model_name=nested_model_name, + ) + + tablename = ( + column._foreign_key_meta.resolved_references._meta.tablename + ) + target_column = ( + column._foreign_key_meta.resolved_target_column._meta.name + ) + extra["foreign_key"] = { + "to": tablename, + "target_column": target_column, + } + else: + if isinstance(column, Text): + extra["widget"] = "text-area" + elif isinstance(column, (JSON, JSONB)): + extra["widget"] = "json" + elif isinstance(column, Timestamptz): + extra["widget"] = "timestamptz" + + if isinstance(column, Array): + extra["dimensions"] = column._get_dimensions() + + field = pydantic.Field( + json_schema_extra={"extra": extra}, + **params, + ) + + entries: list[tuple[str, Any]] = [(column_name, (_type, field))] + + if isinstance(column, ForeignKey) and include_readable: + entries.append((f"{column_name}_readable", (str, None))) + + return entries, validators + + +def _build_pydantic_config( + pydantic_config: Optional[pydantic.config.ConfigDict], + json_schema_extra: Optional[dict[str, Any]], + table: type[Table], +) -> pydantic.config.ConfigDict: + pydantic_config = ( + pydantic_config.copy() + if pydantic_config + else pydantic.config.ConfigDict() + ) + pydantic_config["arbitrary_types_allowed"] = True + + json_schema_extra_ = defaultdict(dict, **(json_schema_extra or {})) + json_schema_extra_["extra"]["help_text"] = table._meta.help_text + + pydantic_config["json_schema_extra"] = dict(json_schema_extra_) + + return pydantic_config + + def create_pydantic_model( table: type[Table], nested: Union[bool, tuple[ForeignKey, ...]] = False, @@ -205,155 +376,38 @@ def create_pydantic_model( columns: dict[str, Any] = {} validators: dict[str, Callable] = {} - piccolo_columns = tuple( - table._meta.columns - if include_default_columns - else table._meta.non_default_columns + piccolo_columns = _filter_piccolo_columns( + table=table, + include_default_columns=include_default_columns, + include_columns=include_columns, + exclude_columns=exclude_columns, ) - if include_columns: - include_columns_plus_ancestors = list( - itertools.chain( - include_columns, *[i._meta.call_chain for i in include_columns] - ) - ) - piccolo_columns = tuple( - i - for i in piccolo_columns - if any( - i._equals(include_column) - for include_column in include_columns_plus_ancestors - ) - ) - - if exclude_columns: - piccolo_columns = tuple( - i - for i in piccolo_columns - if not any( - i._equals(exclude_column) for exclude_column in exclude_columns - ) - ) - model_name = model_name or table.__name__ for column in piccolo_columns: - column_name = column._meta.name - - is_optional = True if all_optional else not column._meta.required - - ####################################################################### - # Work out the column type - - if isinstance(column, (JSON, JSONB)): - if deserialize_json: - value_type = pydantic.Json - else: - value_type = column.value_type - validator = partial( - pydantic_json_validator, required=not is_optional - ) - validators[ - f"{column_name}_is_json" - ] = pydantic.field_validator(column_name)( - validator # type: ignore - ) - else: - value_type = get_pydantic_value_type(column=column) - - _type = Optional[value_type] if is_optional else value_type - - ####################################################################### - - params: dict[str, Any] = {} - if is_optional: - params["default"] = None - - if column._meta.db_column_name != column._meta.name: - params["alias"] = column._meta.db_column_name - - extra: JsonDict = { - "help_text": column._meta.help_text, - "choices": column._meta.get_choices_dict(), - "secret": column._meta.secret, - "nullable": column._meta.null, - "unique": column._meta.unique, - } - - if isinstance(column, ForeignKey): - if recursion_depth < max_recursion_depth and ( - (nested is True) - or ( - isinstance(nested, tuple) - and any( - column._equals(i) - for i in itertools.chain( - nested, *[i._meta.call_chain for i in nested] - ) - ) - ) - ): - nested_model_name = f"{model_name}.{column._meta.name}" - _type = create_pydantic_model( - table=column._foreign_key_meta.resolved_references, - nested=nested, - include_columns=include_columns, - exclude_columns=exclude_columns, - include_default_columns=include_default_columns, - include_readable=include_readable, - all_optional=all_optional, - deserialize_json=deserialize_json, - recursion_depth=recursion_depth + 1, - max_recursion_depth=max_recursion_depth, - model_name=nested_model_name, - ) - - tablename = ( - column._foreign_key_meta.resolved_references._meta.tablename - ) - target_column = ( - column._foreign_key_meta.resolved_target_column._meta.name - ) - extra["foreign_key"] = { - "to": tablename, - "target_column": target_column, - } - - if include_readable: - columns[f"{column_name}_readable"] = (str, None) - else: - # This is used to tell Piccolo Admin that we want to display these - # values using a specific widget. - if isinstance(column, Text): - extra["widget"] = "text-area" - elif isinstance(column, (JSON, JSONB)): - extra["widget"] = "json" - elif isinstance(column, Timestamptz): - extra["widget"] = "timestamptz" - - # It is useful for Piccolo API and Piccolo Admin to easily know - # how many dimensions the array has. - if isinstance(column, Array): - extra["dimensions"] = column._get_dimensions() - - field = pydantic.Field( - json_schema_extra={"extra": extra}, - **params, + entries, extra_validators = _process_column( + column=column, + model_name=model_name, + all_optional=all_optional, + deserialize_json=deserialize_json, + nested=nested, + include_readable=include_readable, + recursion_depth=recursion_depth, + max_recursion_depth=max_recursion_depth, + include_columns=include_columns, + exclude_columns=exclude_columns, + include_default_columns=include_default_columns, ) - - columns[column_name] = (_type, field) - - pydantic_config = ( - pydantic_config.copy() - if pydantic_config - else pydantic.config.ConfigDict() + for name, value in entries: + columns[name] = value + validators.update(extra_validators) + + pydantic_config = _build_pydantic_config( + pydantic_config=pydantic_config, + json_schema_extra=json_schema_extra, + table=table, ) - pydantic_config["arbitrary_types_allowed"] = True - - json_schema_extra_ = defaultdict(dict, **(json_schema_extra or {})) - json_schema_extra_["extra"]["help_text"] = table._meta.help_text - - pydantic_config["json_schema_extra"] = dict(json_schema_extra_) model = pydantic.create_model( model_name,