From c83bf72729372ae9804095166ec37654b341fbab Mon Sep 17 00:00:00 2001 From: Sebastian Stehle Date: Wed, 21 May 2025 17:23:18 +0200 Subject: [PATCH 1/3] Metadata --- .../MySqlBulkInsertProvider.cs | 4 +- .../MySqlDialectBuilder.cs | 8 +- .../PostgreSqlBulkInsertProvider.cs | 12 +- .../SqlServerBulkInsertProvider.cs | 7 +- .../SqlServerDialectBuilder.cs | 23 ++-- .../SqliteBulkInsertProvider.cs | 21 +-- .../Abstractions/IBulkInsertProvider.cs | 6 + .../BulkInsertProviderBase.cs | 125 +++++++----------- .../Dialect/SqlDialectBuilder.cs | 91 +++++-------- .../EnumerableDataReader.cs | 18 +-- .../Extensions/ConnectionInfo.cs | 7 + .../Extensions/DbContextExtensions.cs | 22 ++- .../Extensions/DbSetExtensions.cs | 14 +- .../Metadata/MetadataProvider.cs | 53 ++++++++ .../Metadata/PropertyAccessor.cs | 66 +++++++++ .../Metadata/PropertyMetadata.cs | 55 ++++++++ .../Metadata/TableMetadata.cs | 46 +++++++ ...henX.EntityFrameworkCore.BulkInsert.csproj | 8 +- .../PropertyAccessor.cs | 40 ------ .../LibComparatorSqlServer.cs | 1 - 20 files changed, 385 insertions(+), 242 deletions(-) create mode 100644 src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/ConnectionInfo.cs create mode 100644 src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/MetadataProvider.cs create mode 100644 src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/PropertyAccessor.cs create mode 100644 src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/PropertyMetadata.cs create mode 100644 src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/TableMetadata.cs delete mode 100644 src/PhenX.EntityFrameworkCore.BulkInsert/PropertyAccessor.cs diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlBulkInsertProvider.cs b/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlBulkInsertProvider.cs index a5431d1..38a95b3 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlBulkInsertProvider.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlBulkInsertProvider.cs @@ -4,6 +4,7 @@ using MySqlConnector; +using PhenX.EntityFrameworkCore.BulkInsert.Metadata; using PhenX.EntityFrameworkCore.BulkInsert.Options; namespace PhenX.EntityFrameworkCore.BulkInsert.MySql; @@ -29,9 +30,10 @@ public MySqlBulkInsertProvider(ILogger? logger = null) protected override async Task BulkInsert( bool sync, DbContext context, + TableMetadata tableInfo, IEnumerable entities, string tableName, - PropertyAccessor[] properties, + IReadOnlyList properties, BulkInsertOptions options, CancellationToken ctk ) diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlDialectBuilder.cs b/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlDialectBuilder.cs index 8a2301c..c00b4a9 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlDialectBuilder.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlDialectBuilder.cs @@ -1,9 +1,7 @@ using System.Text; -using Microsoft.EntityFrameworkCore; -using Microsoft.EntityFrameworkCore.Metadata; - using PhenX.EntityFrameworkCore.BulkInsert.Dialect; +using PhenX.EntityFrameworkCore.BulkInsert.Metadata; using PhenX.EntityFrameworkCore.BulkInsert.Options; namespace PhenX.EntityFrameworkCore.BulkInsert.MySql; @@ -45,9 +43,9 @@ protected override void AppendOnConflictStatement(StringBuilder sql) sql.Append("ON DUPLICATE KEY"); } - protected override void AppendDoNothing(StringBuilder sql, IProperty[] insertedProperties) + protected override void AppendDoNothing(StringBuilder sql, IEnumerable insertedProperties) { - var columnName = insertedProperties[0].GetColumnName(); + var columnName = insertedProperties.First().ColumnName; sql.Append($"UPDATE {Quote(columnName)} = {GetExcludedColumnName(columnName)}"); } diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert.PostgreSql/PostgreSqlBulkInsertProvider.cs b/src/PhenX.EntityFrameworkCore.BulkInsert.PostgreSql/PostgreSqlBulkInsertProvider.cs index d0d4c24..239b967 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert.PostgreSql/PostgreSqlBulkInsertProvider.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert.PostgreSql/PostgreSqlBulkInsertProvider.cs @@ -5,6 +5,7 @@ using Npgsql; +using PhenX.EntityFrameworkCore.BulkInsert.Metadata; using PhenX.EntityFrameworkCore.BulkInsert.Options; namespace PhenX.EntityFrameworkCore.BulkInsert.PostgreSql; @@ -24,9 +25,9 @@ public PostgreSqlBulkInsertProvider(ILogger? logge /// protected override string AddTableCopyBulkInsertId => $"ALTER TABLE {{0}} ADD COLUMN {BulkInsertId} SERIAL PRIMARY KEY;"; - private string GetBinaryImportCommand(DbContext context, Type entityType, string tableName) + private static string GetBinaryImportCommand(TableMetadata tableInfo, string tableName) { - var columns = GetQuotedColumns(context, entityType, false); + var columns = tableInfo.GetProperties(false).Select(X => X.QuotedColumName); return $"COPY {tableName} ({string.Join(", ", columns)}) FROM STDIN (FORMAT BINARY)"; } @@ -35,15 +36,16 @@ private string GetBinaryImportCommand(DbContext context, Type entityType, string protected override async Task BulkInsert( bool sync, DbContext context, + TableMetadata tableInfo, IEnumerable entities, string tableName, - PropertyAccessor[] properties, + IReadOnlyList properties, BulkInsertOptions options, - CancellationToken ctk) where T : class + CancellationToken ctk) { var connection = (NpgsqlConnection)context.Database.GetDbConnection(); - var importCommand = GetBinaryImportCommand(context, typeof(T), tableName); + var importCommand = GetBinaryImportCommand(tableInfo, tableName); var writer = sync // ReSharper disable once MethodHasAsyncOverloadWithCancellation diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerBulkInsertProvider.cs b/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerBulkInsertProvider.cs index 3f9956b..7e6a0d1 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerBulkInsertProvider.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerBulkInsertProvider.cs @@ -5,6 +5,7 @@ using Microsoft.EntityFrameworkCore.Storage; using Microsoft.Extensions.Logging; +using PhenX.EntityFrameworkCore.BulkInsert.Metadata; using PhenX.EntityFrameworkCore.BulkInsert.Options; namespace PhenX.EntityFrameworkCore.BulkInsert.SqlServer; @@ -31,12 +32,12 @@ public SqlServerBulkInsertProvider(ILogger? logger protected override async Task BulkInsert( bool sync, DbContext context, + TableMetadata tableInfo, IEnumerable entities, string tableName, - PropertyAccessor[] properties, + IReadOnlyList properties, BulkInsertOptions options, - CancellationToken ctk - ) + CancellationToken ctk) { var connection = (SqlConnection) context.Database.GetDbConnection(); var sqlTransaction = context.Database.CurrentTransaction!.GetDbTransaction() as SqlTransaction; diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerDialectBuilder.cs b/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerDialectBuilder.cs index 8e2df52..1a1e74a 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerDialectBuilder.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerDialectBuilder.cs @@ -1,10 +1,7 @@ -using System.Linq.Expressions; using System.Text; -using Microsoft.EntityFrameworkCore; -using Microsoft.EntityFrameworkCore.Metadata; - using PhenX.EntityFrameworkCore.BulkInsert.Dialect; +using PhenX.EntityFrameworkCore.BulkInsert.Metadata; using PhenX.EntityFrameworkCore.BulkInsert.Options; namespace PhenX.EntityFrameworkCore.BulkInsert.SqlServer; @@ -17,16 +14,18 @@ internal class SqlServerDialectBuilder : SqlDialectBuilder protected override bool SupportsMoveRows => false; - public override string BuildMoveDataSql(DbContext context, string source, + public override string BuildMoveDataSql( + TableMetadata source, string target, - IProperty[] insertedProperties, - IProperty[] properties, - BulkInsertOptions options, OnConflictOptions? onConflict = null) + IReadOnlyList insertedProperties, + IReadOnlyList properties, + BulkInsertOptions options, + OnConflictOptions? onConflict = null) { - var insertedColumns = insertedProperties.Select(p => Quote(p.GetColumnName())).ToArray(); + var insertedColumns = insertedProperties.Select(x => x.QuotedColumName); var insertedColumnList = string.Join(", ", insertedColumns); - var returnedColumns = properties.Select(p => $"INSERTED.{p.GetColumnName()} AS {p.GetColumnName()}"); + var returnedColumns = properties.Select(p => $"INSERTED.{p.ColumnName} AS {p.ColumnName}"); var columnList = string.Join(", ", returnedColumns); var q = new StringBuilder(); @@ -34,12 +33,12 @@ public override string BuildMoveDataSql(DbContext context, string source, // Merge handling if (onConflict is OnConflictOptions onConflictTyped && onConflictTyped.Match != null) { - var matchColumns = GetColumns(context, onConflictTyped.Match); + var matchColumns = GetColumns(source, onConflictTyped.Match); var matchOn = string.Join(" AND ", matchColumns.Select(col => $"TARGET.{col} = SOURCE.{col}")); var updateSet = onConflictTyped.Update != null - ? string.Join(", ", GetUpdates(context, insertedProperties, onConflictTyped.Update)) + ? string.Join(", ", GetUpdates(source, insertedProperties, onConflictTyped.Update)) : null; q.AppendLine($"MERGE INTO {target} AS TARGET"); diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteBulkInsertProvider.cs b/src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteBulkInsertProvider.cs index 2e3c0f7..3674f82 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteBulkInsertProvider.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteBulkInsertProvider.cs @@ -6,7 +6,7 @@ using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.Logging; -using PhenX.EntityFrameworkCore.BulkInsert.Extensions; +using PhenX.EntityFrameworkCore.BulkInsert.Metadata; using PhenX.EntityFrameworkCore.BulkInsert.Options; namespace PhenX.EntityFrameworkCore.BulkInsert.Sqlite; @@ -81,17 +81,17 @@ private static SqliteType GetSqliteType(Type clrType) throw new InvalidOperationException("Unknown Sqlite type for " + clrType); } - private DbCommand GetInsertCommand(DbContext context, Type entityType, string tableName, + private DbCommand GetInsertCommand(DbContext context, TableMetadata tableInfo, string tableName, int batchSize) { - var columns = context.GetProperties(entityType, false); + var columns = tableInfo.GetProperties(false); var cmd = context.Database.GetDbConnection().CreateCommand(); var sqliteColumns = columns .Select(c => new { - Name = c.GetColumnName(), - Type = GetSqliteType(c.GetProviderClrType() ?? c.ClrType) + Name = c.ColumnName, + Type = GetSqliteType(c.ProviderClrType ?? c.ClrType) }) .ToArray(); @@ -125,18 +125,19 @@ private DbCommand GetInsertCommand(DbContext context, Type entityType, string ta protected override async Task BulkInsert( bool sync, DbContext context, + TableMetadata tableInfo, IEnumerable entities, string tableName, - PropertyAccessor[] properties, + IReadOnlyList properties, BulkInsertOptions options, CancellationToken ctk ) where T : class { const int maxParams = 1000; var batchSize = options.BatchSize ?? 5; - batchSize = Math.Min(batchSize, maxParams / properties.Length); + batchSize = Math.Min(batchSize, maxParams / properties.Count); - await using var insertCommand = GetInsertCommand(context, typeof(T), tableName, batchSize); + await using var insertCommand = GetInsertCommand(context, tableInfo, tableName, batchSize); foreach (var chunk in entities.Chunk(batchSize)) { @@ -149,7 +150,7 @@ CancellationToken ctk // Last chunk else { - var partialInsertCommand = GetInsertCommand(context, typeof(T), tableName, chunk.Length); + var partialInsertCommand = GetInsertCommand(context, tableInfo, tableName, chunk.Length); FillValues(chunk, partialInsertCommand.Parameters, properties); await ExecuteCommand(sync, partialInsertCommand, ctk); @@ -170,7 +171,7 @@ private static async Task ExecuteCommand(bool sync, DbCommand insertCommand, Can } } - private static void FillValues(T[] chunk, DbParameterCollection parameters, PropertyAccessor[] properties) where T : class + private static void FillValues(T[] chunk, DbParameterCollection parameters, IReadOnlyList properties) where T : class { var index = 0; foreach (var entity in chunk) diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Abstractions/IBulkInsertProvider.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Abstractions/IBulkInsertProvider.cs index 344cebd..2430fe3 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/Abstractions/IBulkInsertProvider.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Abstractions/IBulkInsertProvider.cs @@ -1,5 +1,7 @@ using Microsoft.EntityFrameworkCore; +using PhenX.EntityFrameworkCore.BulkInsert.Dialect; +using PhenX.EntityFrameworkCore.BulkInsert.Metadata; using PhenX.EntityFrameworkCore.BulkInsert.Options; namespace PhenX.EntityFrameworkCore.BulkInsert.Abstractions; @@ -15,6 +17,7 @@ internal interface IBulkInsertProvider internal Task> BulkInsertReturnEntities( bool sync, DbContext context, + TableMetadata tableInfo, IEnumerable entities, BulkInsertOptions options, OnConflictOptions? onConflict = null, @@ -27,9 +30,12 @@ internal Task> BulkInsertReturnEntities( internal Task BulkInsert( bool sync, DbContext context, + TableMetadata tableInfo, IEnumerable entities, BulkInsertOptions options, OnConflictOptions? onConflict = null, CancellationToken ctk = default ) where T : class; + + SqlDialectBuilder SqlDialect { get; } } diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs index 675a7a1..bc913e7 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs @@ -1,49 +1,56 @@ using System.Data.Common; using Microsoft.EntityFrameworkCore; -using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Storage; using Microsoft.Extensions.Logging; using PhenX.EntityFrameworkCore.BulkInsert.Abstractions; using PhenX.EntityFrameworkCore.BulkInsert.Dialect; using PhenX.EntityFrameworkCore.BulkInsert.Extensions; +using PhenX.EntityFrameworkCore.BulkInsert.Metadata; using PhenX.EntityFrameworkCore.BulkInsert.Options; namespace PhenX.EntityFrameworkCore.BulkInsert; -internal abstract class BulkInsertProviderBase : IBulkInsertProvider +#pragma warning disable CS9113 // Parameter is unread. +internal abstract class BulkInsertProviderBase(ILogger>? logger = null) : IBulkInsertProvider +#pragma warning restore CS9113 // Parameter is unread. where TDialect : SqlDialectBuilder, new() { protected readonly TDialect SqlDialect = new(); - private readonly ILogger>? Logger; protected virtual string BulkInsertId => "_bulk_insert_id"; protected abstract string CreateTableCopySql { get; } protected abstract string AddTableCopyBulkInsertId { get; } - protected BulkInsertProviderBase(ILogger>? logger = null) - { - Logger = logger; - } + SqlDialectBuilder IBulkInsertProvider.SqlDialect => SqlDialect; protected async Task CreateTableCopyAsync( bool sync, DbContext context, + TableMetadata tableInfo, CancellationToken cancellationToken = default) where T : class { - var tableInfo = GetTableInfo(context, typeof(T)); - var tableName = QuoteTableName(tableInfo.SchemaName, tableInfo.TableName); - var tempTableName = QuoteTableName(null, GetTempTableName(tableInfo.TableName)); + var tempTableName = await CreateTemporaryTableAsync(sync, context, tableInfo, cancellationToken); - var keptColumns = string.Join(", ", GetQuotedColumns(context, typeof(T), false)); - var query = string.Format(CreateTableCopySql, tempTableName, tableName, keptColumns); + await AddBulkInsertIdColumn(sync, context, tempTableName, cancellationToken); - await ExecuteAsync(sync, context, query, cancellationToken); + return tempTableName; + } - await AddBulkInsertIdColumn(sync, context, tempTableName, cancellationToken); + private async Task CreateTemporaryTableAsync( + bool sync, + DbContext context, + TableMetadata tableInfo, + CancellationToken cancellationToken) + { + var tempTableName = SqlDialect.QuoteTableName(null, GetTempTableName(tableInfo.TableName)); + var tempColumns = string.Join(", ", tableInfo.GetProperties(false).Select(x => x.QuotedColumName)); + var query = string.Format(CreateTableCopySql, tempTableName, tableInfo.QuotedTableName, tempColumns); + + await ExecuteAsync(sync, context, query, cancellationToken); return tempTableName; } @@ -79,6 +86,7 @@ protected static async Task ExecuteAsync(bool sync, DbContext context, string qu public async Task> CopyFromTempTableAsync( bool sync, DbContext context, + TableMetadata tableInfo, string tempTableName, bool returnData, BulkInsertOptions options, @@ -88,6 +96,7 @@ public async Task> CopyFromTempTableAsync( return await CopyFromTempTableWithoutKeysAsync( sync, context, + tableInfo, tempTableName, returnData, options, @@ -98,6 +107,7 @@ public async Task> CopyFromTempTableAsync( private async Task> CopyFromTempTableWithoutKeysAsync( bool sync, DbContext context, + TableMetadata tableInfo, string tempTableName, bool returnData, BulkInsertOptions options, @@ -106,17 +116,15 @@ private async Task> CopyFromTempTableWithoutKeysAsync( where T : class where TResult : class { - var (schemaName, tableName, _) = GetTableInfo(context, typeof(T)); - var quotedTableName = QuoteTableName(schemaName, tableName); - var movedProperties = context.GetProperties(typeof(T), options.CopyGeneratedColumns); - var returnedProperties = returnData ? context.GetProperties(typeof(T)) : []; + var movedProperties = tableInfo.GetProperties(options.CopyGeneratedColumns); + var returnedProperties = returnData ? tableInfo.GetProperties() : []; if (returnData && !SqlDialect.SupportsReturning) { throw new NotSupportedException("Provider does not support returning entities."); } - var query = SqlDialect.BuildMoveDataSql(context, tempTableName, quotedTableName, movedProperties, returnedProperties, options, onConflict); + var query = SqlDialect.BuildMoveDataSql(tableInfo, tempTableName, movedProperties, returnedProperties, options, onConflict); if (returnData) { @@ -146,26 +154,28 @@ static async Task> QueryAsync(bool sync, DbContext context, string public async Task> BulkInsertReturnEntities( bool sync, DbContext context, + TableMetadata tableInfo, IEnumerable entities, BulkInsertOptions options, OnConflictOptions? onConflict = null, CancellationToken ctk = default ) where T : class { - var (connection, wasClosed, transaction, wasBegan) = await context.GetConnection(sync, ctk); + var connectionInfo = await context.GetConnection(sync, ctk); - var (tableName, _) = await PerformBulkInsertAsync(sync, context, entities, options, tempTableRequired: true, ctk: ctk); + var (tableName, _) = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk); - var result = await CopyFromTempTableAsync(sync, context, tableName, true, options, onConflict, cancellationToken: ctk); + var result = await CopyFromTempTableAsync(sync, context, tableInfo, tableName, true, options, onConflict, cancellationToken: ctk); - await Finish(sync, connection, wasClosed, transaction, wasBegan, ctk); + await Finish(sync, connectionInfo, ctk); return result; } - private static async Task Finish(bool sync, DbConnection connection, bool wasClosed, - IDbContextTransaction transaction, bool wasBegan, CancellationToken ctk) + private static async Task Finish(bool sync, ConnectionInfo connectionInfo, CancellationToken ctk) { + var (connection, wasClosed, transaction, wasBegan) = connectionInfo; + if (!wasBegan) { if (sync) @@ -196,6 +206,7 @@ private static async Task Finish(bool sync, DbConnection connection, bool wasClo public async Task BulkInsert( bool sync, DbContext context, + TableMetadata tableInfo, IEnumerable entities, BulkInsertOptions options, OnConflictOptions? onConflict = null, @@ -204,23 +215,24 @@ public async Task BulkInsert( { if (onConflict != null) { - var (connection, wasClosed, transaction, wasBegan) = await context.GetConnection(sync, ctk); + var connectionInfo = await context.GetConnection(sync, ctk); - var (tableName, _) = await PerformBulkInsertAsync(sync, context, entities, options, tempTableRequired: true, ctk: ctk); + var (tableName, _) = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk); - await CopyFromTempTableAsync(sync, context, tableName, false, options, onConflict, ctk); + await CopyFromTempTableAsync(sync, context, tableInfo, tableName, false, options, onConflict, ctk); - await Finish(sync, connection, wasClosed, transaction, wasBegan, ctk); + await Finish(sync, connectionInfo, ctk); } else { - await PerformBulkInsertAsync(sync, context, entities, options, tempTableRequired: false, ctk: ctk); + await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: false, ctk: ctk); } } private async Task<(string TableName, DbConnection Connection)> PerformBulkInsertAsync( bool sync, DbContext context, + TableMetadata tableInfo, IEnumerable entities, BulkInsertOptions options, bool tempTableRequired, @@ -231,22 +243,19 @@ public async Task BulkInsert( throw new InvalidOperationException("No entities to insert."); } - var (connection, wasClosed, transaction, wasBegan) = await context.GetConnection(sync, ctk); + var connectionInfo = await context.GetConnection(sync, ctk); var tableName = tempTableRequired - ? await CreateTableCopyAsync(sync, context, ctk) - : GetQuotedTableName(context, typeof(T)); + ? await CreateTableCopyAsync(sync, context, tableInfo, ctk) + : tableInfo.QuotedTableName; - var properties = context - .GetProperties(typeof(T), options.CopyGeneratedColumns) - .Select(p => new PropertyAccessor(p)) - .ToArray(); + var properties = tableInfo.GetProperties(options.CopyGeneratedColumns); - await BulkInsert(false, context, entities, tableName, properties, options, ctk); + await BulkInsert(false, context, tableInfo, entities, tableName, properties, options, ctk); - await Finish(sync, connection, wasClosed, transaction, wasBegan, ctk); + await Finish(sync, connectionInfo, ctk); - return (tableName, connection); + return (tableName, connectionInfo.Connection); } /// @@ -255,43 +264,11 @@ public async Task BulkInsert( protected abstract Task BulkInsert( bool sync, DbContext context, + TableMetadata tableInfo, IEnumerable entities, string tableName, - PropertyAccessor[] properties, + IReadOnlyList properties, BulkInsertOptions options, CancellationToken ctk ) where T : class; - - /// - /// Get table information for the given entity type : schema name, table name and primary key. - /// - public static (string? SchemaName, string TableName, IKey PrimaryKey) GetTableInfo(DbContext context, Type entityType) - { - var entityTypeInfo = context.Model.FindEntityType(entityType); - var schema = (entityTypeInfo ?? throw new InvalidOperationException($"Could not determine entity type for type {entityType.Name}")).GetSchema(); - var tableName = entityTypeInfo.GetTableName(); - - if (string.IsNullOrWhiteSpace(tableName)) - { - throw new InvalidOperationException($"Could not determine table name for type {entityType.Name}"); - } - - return (schema, tableName, entityTypeInfo.FindPrimaryKey()!); - } - - protected string GetQuotedTableName(DbContext context, Type entityType) - { - var (schema, tableName, _) = GetTableInfo(context, entityType); - - return QuoteTableName(schema, tableName); - } - - protected string QuoteTableName(string? schema, string table) => SqlDialect.QuoteTableName(schema, table); - - protected string[] GetQuotedColumns(DbContext context, Type entityType, bool includeGenerated = true) - { - return context.GetProperties(entityType, includeGenerated) - .Select(p => Quote(p.GetColumnName())) - .ToArray(); - } } diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Dialect/SqlDialectBuilder.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Dialect/SqlDialectBuilder.cs index c43dc1e..2a0a28e 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/Dialect/SqlDialectBuilder.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Dialect/SqlDialectBuilder.cs @@ -1,10 +1,7 @@ using System.Linq.Expressions; using System.Text; -using Microsoft.EntityFrameworkCore; -using Microsoft.EntityFrameworkCore.Metadata; -using Microsoft.EntityFrameworkCore.Metadata.Internal; - +using PhenX.EntityFrameworkCore.BulkInsert.Metadata; using PhenX.EntityFrameworkCore.BulkInsert.Options; namespace PhenX.EntityFrameworkCore.BulkInsert.Dialect; @@ -18,36 +15,10 @@ internal abstract class SqlDialectBuilder protected virtual bool SupportsMoveRows => true; public virtual bool SupportsReturning => true; - /// - /// Gets the name of the column for a property in a given entity type. - /// - /// The DbContext - /// The property name - /// The entity type - /// The column name - /// Thrown when the entity type or property is not found. - protected string GetColumnName(DbContext context, string propName) - { - var entityType = context.Model.FindEntityType(typeof(TEntity)); - if (entityType == null) - { - throw new InvalidOperationException($"Entity type {typeof(TEntity).Name} not found in the model."); - } - - var property = entityType.FindProperty(propName); - if (property == null) - { - throw new InvalidOperationException($"Property {propName} not found in entity type {typeof(TEntity).Name}."); - } - - return Quote(property.GetColumnName()); - } - /// /// Builds the SQL for moving data from one table to another. /// - /// The DbContext - /// Source table name + /// Source table /// Target table name /// Properties to be copied /// Properties to be returned @@ -55,20 +26,22 @@ protected string GetColumnName(DbContext context, string propName) /// On conflict options /// Entity type /// The SQL query - public virtual string BuildMoveDataSql(DbContext context, string source, + public virtual string BuildMoveDataSql( + TableMetadata source, string target, - IProperty[] insertedProperties, - IProperty[] properties, + IReadOnlyList insertedProperties, + IReadOnlyList properties, BulkInsertOptions options, OnConflictOptions? onConflict = null) { - var insertedColumns = insertedProperties.Select(p => Quote(p.GetColumnName())); + var insertedColumns = insertedProperties.Select(p => p.QuotedColumName); var insertedColumnList = string.Join(", ", insertedColumns); - var returnedColumns = properties.Select(p => Quote(p.GetColumnName())); + var returnedColumns = properties.Select(p => p.QuotedColumName); var columnList = string.Join(", ", returnedColumns); var q = new StringBuilder(); + var sourceName = source.QuotedTableName; if (SupportsMoveRows && options.MoveRows) { q.AppendLine($""" @@ -77,13 +50,13 @@ DELETE FROM {source} RETURNING {insertedColumnList} ) """); - source = "moved_rows"; + sourceName = "moved_rows"; } q.AppendLine($""" INSERT INTO {target} ({insertedColumnList}) SELECT {insertedColumnList} - FROM {source} + FROM {sourceName} WHERE TRUE """); @@ -96,13 +69,13 @@ WHERE TRUE if (onConflictTyped.Match != null) { q.Append(' '); - AppendConflictMatch(q, GetColumns(context, onConflictTyped.Match)); + AppendConflictMatch(q, GetColumns(source, onConflictTyped.Match)); } if (onConflictTyped.Update != null) { q.Append(' '); - AppendOnConflictUpdate(q, GetUpdates(context, insertedProperties, onConflictTyped.Update)); + AppendOnConflictUpdate(q, GetUpdates(source, insertedProperties, onConflictTyped.Update)); } if (onConflictTyped.Condition != null) @@ -128,7 +101,7 @@ WHERE TRUE return q.ToString(); } - protected virtual void AppendDoNothing(StringBuilder sql, IProperty[] insertedProperties) + protected virtual void AppendDoNothing(StringBuilder sql, IEnumerable insertedProperties) { sql.AppendLine("DO NOTHING"); } @@ -205,15 +178,15 @@ public string QuoteTableName(string? schema, string tableName) /// /// Gets column names for the insert statement, from an object initializer. /// - protected string[] GetColumns(DbContext context, Expression> columns) + protected string[] GetColumns(TableMetadata table, Expression> columns) { return columns.Body switch { NewExpression newExpression => newExpression.Arguments.OfType() - .Select(m => GetColumnName(context, m.Member.Name)) + .Select(m => table.GetQuotedColumnName(m.Member.Name)) .ToArray(), MemberExpression memberExpression => [ - GetColumnName(context, memberExpression.Member.Name) + table.GetQuotedColumnName(memberExpression.Member.Name) ], _ => throw new NotSupportedException("Unsupported expression type") }; @@ -230,7 +203,7 @@ protected string[] GetColumns(DbContext context, Expression> /// var updates = GetUpdates(context, e => e.Prop1); /// /// - protected IEnumerable GetUpdates(DbContext context, IProperty[] properties, Expression> update) + protected IEnumerable GetUpdates(TableMetadata table, IEnumerable properties, Expression> update) { switch (update.Body) { @@ -238,7 +211,7 @@ protected IEnumerable GetUpdates(DbContext context, IProperty[] prope { foreach (var arg in newExpr.Arguments.Zip(newExpr.Members, (expr, member) => (expr, member))) { - yield return $"{GetColumnName(context, arg.member.Name)} = {ToSqlExpression(context, arg.expr)}"; + yield return $"{table.GetColumnName(arg.member.Name)} = {ToSqlExpression(table, arg.expr)}"; } break; @@ -247,20 +220,18 @@ protected IEnumerable GetUpdates(DbContext context, IProperty[] prope { foreach (var binding in memberInit.Bindings.OfType()) { - yield return $"{GetColumnName(context, binding.Member.Name)} = {ToSqlExpression(context, binding.Expression)}"; + yield return $"{table.GetColumnName(binding.Member.Name)} = {ToSqlExpression(table, binding.Expression)}"; } break; } case MemberExpression memberExpr: - yield return $"{GetColumnName(context, memberExpr.Member.Name)} = {ToSqlExpression(context, memberExpr)}"; + yield return $"{table.GetColumnName(memberExpr.Member.Name)} = {ToSqlExpression(table, memberExpr)}"; break; case ParameterExpression parameterExpr when (parameterExpr.Type == typeof(T)): foreach (var property in properties) { - var columName = property.GetColumnName(); - - yield return $"{Quote(columName)} = {GetExcludedColumnName(columName)}"; + yield return $"{property.QuotedColumName} = {GetExcludedColumnName(property.ColumnName)}"; } break; @@ -273,21 +244,21 @@ protected IEnumerable GetUpdates(DbContext context, IProperty[] prope /// /// Converts an expression to an SQL string. /// - /// The DbContext + /// The DbContext /// The expression, with simple operations /// Entity type /// An SQL statement /// Thrown when an expression could not be translated. - private string ToSqlExpression(DbContext context, Expression expr) + private string ToSqlExpression(TableMetadata table, Expression expr) { switch (expr) { case MemberExpression m: - return GetExcludedColumnName(GetColumnName(context, m.Member.Name)); + return GetExcludedColumnName(table.GetColumnName(m.Member.Name)); case BinaryExpression b: - var left = ToSqlExpression(context, b.Left); - var right = ToSqlExpression(context, b.Right); + var left = ToSqlExpression(table, b.Left); + var right = ToSqlExpression(table, b.Right); var op = b.NodeType switch { ExpressionType.Add => b.Type == typeof(string) ? ConcatOperator : "+", @@ -329,18 +300,18 @@ private string ToSqlExpression(DbContext context, Expression expr) case UnaryExpression u: if (u.NodeType == ExpressionType.Convert) { - return ToSqlExpression(context, u.Operand); + return ToSqlExpression(table, u.Operand); } if (u.NodeType == ExpressionType.Not) { - return $"NOT ({ToSqlExpression(context, u.Operand)})"; + return $"NOT ({ToSqlExpression(table, u.Operand)})"; } throw new NotSupportedException($"Unary operator not supported: {u.NodeType}"); case MethodCallExpression mce: // Supporte quelques méthodes courantes (ToLower, ToUpper, Trim, etc.) - var objSql = mce.Object != null ? ToSqlExpression(context, mce.Object) : null; - var argsSql = mce.Arguments.Select(expr1 => ToSqlExpression(context, expr1)).ToArray(); + var objSql = mce.Object != null ? ToSqlExpression(table, mce.Object) : null; + var argsSql = mce.Arguments.Select(expr1 => ToSqlExpression(table, expr1)).ToArray(); switch (mce.Method.Name) { case "ToLower": diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/EnumerableDataReader.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/EnumerableDataReader.cs index 00f8241..ec07fe0 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/EnumerableDataReader.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/EnumerableDataReader.cs @@ -1,14 +1,16 @@ using System.Data; +using PhenX.EntityFrameworkCore.BulkInsert.Metadata; + namespace PhenX.EntityFrameworkCore.BulkInsert; internal class EnumerableDataReader : IDataReader { private readonly IEnumerator _enumerator; - private readonly PropertyAccessor[] _properties; + private readonly IReadOnlyList _properties; private readonly Dictionary _ordinalMap; - public EnumerableDataReader(IEnumerable rows, PropertyAccessor[] properties) + public EnumerableDataReader(IEnumerable rows, IReadOnlyList properties) { _enumerator = rows.GetEnumerator(); _properties = properties; @@ -32,7 +34,7 @@ public virtual object GetValue(int i) return DBNull.Value; } - return _properties[i].GetValue(current); + return _properties[i].GetValue(current)!; } public int GetValues(object[] values) @@ -43,18 +45,18 @@ public int GetValues(object[] values) return 0; } - for (var i = 0; i < _properties.Length; i++) + for (var i = 0; i < _properties.Count; i++) { - values[i] = _properties[i].GetValue(current); + values[i] = _properties[i].GetValue(current)!; } - return _properties.Length; + return _properties.Count; } public bool Read() => _enumerator.MoveNext(); - public int FieldCount => _properties.Length; - public Type GetFieldType(int i) => _properties[i].ProviderClrType; + public int FieldCount => _properties.Count; + public Type GetFieldType(int i) => _properties[i].ClrType; public int GetOrdinal(string name) => _ordinalMap.GetValueOrDefault(name, -1); diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/ConnectionInfo.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/ConnectionInfo.cs new file mode 100644 index 0000000..2d8df8b --- /dev/null +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/ConnectionInfo.cs @@ -0,0 +1,7 @@ +using System.Data.Common; + +using Microsoft.EntityFrameworkCore.Storage; + +namespace PhenX.EntityFrameworkCore.BulkInsert.Extensions; + +internal readonly record struct ConnectionInfo(DbConnection Connection, bool WasClosed, IDbContextTransaction Transaction, bool WasBegan); diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbContextExtensions.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbContextExtensions.cs index 666b120..b019949 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbContextExtensions.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbContextExtensions.cs @@ -1,28 +1,22 @@ using System.Data; -using System.Data.Common; using Microsoft.EntityFrameworkCore; -using Microsoft.EntityFrameworkCore.Metadata; -using Microsoft.EntityFrameworkCore.Storage; +using Microsoft.EntityFrameworkCore.Infrastructure; + +using PhenX.EntityFrameworkCore.BulkInsert.Metadata; namespace PhenX.EntityFrameworkCore.BulkInsert.Extensions; internal static class DbContextExtensions { - /// - /// Gets cached properties for an entity type, using reflection if not already cached. - /// - internal static IProperty[] GetProperties(this DbContext context, Type entityType, bool includeGenerated = true) + public static TableMetadata GetTableInfo(this DbContext context) { - var entityTypeInfo = context.Model.FindEntityType(entityType) ?? throw new InvalidOperationException($"Could not determine entity type for type {entityType.Name}"); + var provider = context.GetService(); - return entityTypeInfo - .GetProperties() - .Where(p => !p.IsShadowProperty() && (includeGenerated || p.ValueGenerated != ValueGenerated.OnAdd)) - .ToArray(); + return provider.GetTableInfo(context); } - internal static async Task<(DbConnection connection, bool wasClosed, IDbContextTransaction transaction, bool wasBegan)> GetConnection( + internal static async Task GetConnection( this DbContext context, bool sync, CancellationToken ctk = default) { var connection = context.Database.GetDbConnection(); @@ -59,6 +53,6 @@ internal static IProperty[] GetProperties(this DbContext context, Type entityTyp } } - return (connection, wasClosed, transaction, wasBegan); + return new ConnectionInfo(connection, wasClosed, transaction, wasBegan); } } diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbSetExtensions.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbSetExtensions.cs index 4d76cee..570c085 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbSetExtensions.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbSetExtensions.cs @@ -1,4 +1,4 @@ -using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore.Infrastructure; using PhenX.EntityFrameworkCore.BulkInsert.Abstractions; @@ -23,8 +23,9 @@ public static async Task> ExecuteBulkInsertReturnEntitiesAsync( ) where T : class { var provider = InitProvider(dbSet, configure, out var context, out var options); + var tableInfo = dbSet.GetDbContext().GetTableInfo(); - return await provider.BulkInsertReturnEntities(false, context, entities, options, onConflict, ctk); + return await provider.BulkInsertReturnEntities(false, context, tableInfo, entities, options, onConflict, ctk); } /// @@ -53,8 +54,9 @@ public static async Task ExecuteBulkInsertAsync( ) where T : class { var provider = InitProvider(dbSet, configure, out var context, out var options); + var tableInfo = dbSet.GetDbContext().GetTableInfo(); - await provider.BulkInsert(false, context, entities, options, onConflict, ctk); + await provider.BulkInsert(false, context, tableInfo, entities, options, onConflict, ctk); } /// @@ -82,8 +84,9 @@ public static List ExecuteBulkInsertReturnEntities( ) where T : class { var provider = InitProvider(dbSet, configure, out var context, out var options); + var tableInfo = dbSet.GetDbContext().GetTableInfo(); - return provider.BulkInsertReturnEntities(true, context, entities, options, onConflict).GetAwaiter().GetResult(); + return provider.BulkInsertReturnEntities(true, context, tableInfo, entities, options, onConflict).GetAwaiter().GetResult(); } /// @@ -116,8 +119,9 @@ public static void ExecuteBulkInsert( ) where T : class { var provider = InitProvider(dbSet, configure, out var context, out var options); + var tableInfo = dbSet.GetDbContext().GetTableInfo(); - provider.BulkInsert(true, context, entities, options, onConflict).GetAwaiter().GetResult(); + provider.BulkInsert(true, context, tableInfo, entities, options, onConflict).GetAwaiter().GetResult(); } /// diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/MetadataProvider.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/MetadataProvider.cs new file mode 100644 index 0000000..fcac7dd --- /dev/null +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/MetadataProvider.cs @@ -0,0 +1,53 @@ +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; + +using PhenX.EntityFrameworkCore.BulkInsert.Abstractions; + +namespace PhenX.EntityFrameworkCore.BulkInsert.Metadata; + +internal static class MetadataProvider where T : DbContext +{ + public static readonly MetadataProvider Instance = new(); +} + +internal sealed class MetadataProvider +{ + private Dictionary? _tables; + + public TableMetadata GetTableInfo(DbContext context) + { + var tables = GetTables(context); + + if (!tables.TryGetValue(typeof(T), out var table)) + { + throw new InvalidOperationException($"Cannot find metadata for type '{typeof(T)}'."); + } + + return table; + } + + private Dictionary GetTables(DbContext context) + { + if (_tables != null) + { + return _tables; + } + + lock (this) + { + if (_tables != null) + { + return _tables; + } + + var provider = context.GetService(); + + _tables = + context.Model.GetEntityTypes() + .ToDictionary( + x => x.ClrType, + x => new TableMetadata(x, provider.SqlDialect)); + return _tables; + } + } +} diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/PropertyAccessor.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/PropertyAccessor.cs new file mode 100644 index 0000000..b731b5f --- /dev/null +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/PropertyAccessor.cs @@ -0,0 +1,66 @@ +using System.Reflection.Emit; +using System.Reflection; + +namespace PhenX.EntityFrameworkCore.BulkInsert.Metadata; + +internal static class PropertyAccessor +{ + public delegate TValue Getter(TSource source); + + public static Getter CreateUntypedGetter(PropertyInfo propertyInfo, Type sourceType, Type valueType) + { + var method = + typeof(PropertyAccessor).GetMethod(nameof(CreateInternalUntypedGetter), BindingFlags.NonPublic)! + .MakeGenericMethod(sourceType, valueType); + + return (Getter)method.Invoke(null, [propertyInfo])!; + } + + private static Getter CreateInternalUntypedGetter(PropertyInfo propertyInfo) + { + var getter = CreateGetter(propertyInfo); + + return source => getter((TSource)source!); + } + + public static Getter CreateGetter(PropertyInfo propertyInfo) + { + if (!propertyInfo.CanRead) + { + return x => throw new NotSupportedException(); + } + + var bakingField = + propertyInfo.DeclaringType!.GetField($"<{propertyInfo.Name}>k__BackingField", + BindingFlags.NonPublic | + BindingFlags.Instance); + + var propertyGetMethod = propertyInfo.GetGetMethod()!; + + var getMethod = new DynamicMethod(propertyGetMethod.Name, typeof(TValue), [typeof(TSource)], true); + var getGenerator = getMethod.GetILGenerator(); + + // Load this to stack. + getGenerator.Emit(OpCodes.Ldarg_0); + + if (bakingField != null && !propertyGetMethod.IsVirtual) + { + // Get field directly. + getGenerator.Emit(OpCodes.Ldfld, bakingField); + } + else if (propertyGetMethod.IsVirtual) + { + // Call the virtual property. + getGenerator.Emit(OpCodes.Callvirt, propertyGetMethod); + } + else + { + // Call the non virtual property. + getGenerator.Emit(OpCodes.Call, propertyGetMethod); + } + + getGenerator.Emit(OpCodes.Ret); + + return getMethod.CreateDelegate>(); + } +} diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/PropertyMetadata.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/PropertyMetadata.cs new file mode 100644 index 0000000..b50993d --- /dev/null +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/PropertyMetadata.cs @@ -0,0 +1,55 @@ +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Metadata; + +using PhenX.EntityFrameworkCore.BulkInsert.Dialect; + +namespace PhenX.EntityFrameworkCore.BulkInsert.Metadata; + +internal sealed class PropertyMetadata(IProperty property, SqlDialectBuilder dialect) +{ + private readonly PropertyAccessor.Getter _getter = BuildGetter(property); + + public string Name { get; } = property.Name; + + public string ColumnName { get; } = property.GetColumnName(); + + public string QuotedColumName { get; } = dialect.Quote(property.GetColumnName()); + + public Type ClrType { get; } = property.ClrType; + + public Type? ProviderClrType { get; } = property.GetProviderClrType(); + + public bool IsGenerated { get; } = property.ValueGenerated == ValueGenerated.OnAdd; + + public object? GetValue(object entity) + { + return _getter(entity!); + } + + private static PropertyAccessor.Getter BuildGetter(IProperty property) + { + var valueConverter = + property.GetValueConverter() ?? + property.GetTypeMapping().Converter; + + var actualGetter = + PropertyAccessor.CreateUntypedGetter( + property.PropertyInfo!, + property.DeclaringType.ClrType, + property.ClrType); + + if (valueConverter != null) + { + var converter = valueConverter.ConvertToProvider; + var original = actualGetter; + actualGetter = source => + { + var value = original(source); + + return converter(value); + }; + } + + return actualGetter; + } +} diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/TableMetadata.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/TableMetadata.cs new file mode 100644 index 0000000..57b80a4 --- /dev/null +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/TableMetadata.cs @@ -0,0 +1,46 @@ +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Metadata; + +using PhenX.EntityFrameworkCore.BulkInsert.Dialect; + +namespace PhenX.EntityFrameworkCore.BulkInsert.Metadata; + +internal sealed class TableMetadata(IEntityType entityType, SqlDialectBuilder dialect) +{ + private IReadOnlyList? _notGeneratedProperties; + + public string TableName { get; } = + entityType.GetTableName() ?? throw new InvalidOperationException("Canot determine table name."); + + public string QuotedTableName { get; } = + dialect.QuoteTableName(entityType.GetSchema(), entityType.GetTableName()!); + + public IReadOnlyList Properties { get; } = + entityType.GetProperties().Where(p => !p.IsShadowProperty()).Select(x => new PropertyMetadata(x, dialect)).ToList(); + + public IReadOnlyList GetProperties(bool includeGenerated = false) + { + if (includeGenerated) + { + return Properties; + } + + return _notGeneratedProperties ??= Properties.Where(x => !x.IsGenerated).ToList(); + } + + public string GetQuotedColumnName(string propertyName) + { + var property = Properties.FirstOrDefault(x => x.Name == propertyName) + ?? throw new InvalidOperationException($"Property {propertyName} not found in entity type {entityType.Name}."); + + return property.QuotedColumName; + } + + public string GetColumnName(string propertyName) + { + var property = Properties.FirstOrDefault(x => x.Name == propertyName) + ?? throw new InvalidOperationException($"Property {propertyName} not found in entity type {entityType.Name}."); + + return property.ColumnName; + } +} diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/PhenX.EntityFrameworkCore.BulkInsert.csproj b/src/PhenX.EntityFrameworkCore.BulkInsert/PhenX.EntityFrameworkCore.BulkInsert.csproj index b59ec6a..17da02a 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/PhenX.EntityFrameworkCore.BulkInsert.csproj +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/PhenX.EntityFrameworkCore.BulkInsert.csproj @@ -6,10 +6,10 @@ - - - - + + + + diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/PropertyAccessor.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/PropertyAccessor.cs deleted file mode 100644 index 992dab1..0000000 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/PropertyAccessor.cs +++ /dev/null @@ -1,40 +0,0 @@ -using Microsoft.EntityFrameworkCore; -using Microsoft.EntityFrameworkCore.Metadata; - -namespace PhenX.EntityFrameworkCore.BulkInsert; - -internal readonly struct PropertyAccessor -{ - private Func ValueGetter { get; } - - private IProperty Property { get; } - - public Type ProviderClrType { get; } - - public string Name => Property.Name; - - public string ColumnName => Property.GetColumnName(); - - public PropertyAccessor(IProperty property) - { - Property = property; - - var propInfo = property.PropertyInfo!; - - var valueConverter = property.GetValueConverter()?? - property.GetTypeMapping().Converter; - - if (valueConverter != null) - { - var conv = valueConverter.ConvertToProvider; - ValueGetter = v => conv(propInfo.GetValue(v)); - ProviderClrType = valueConverter.ProviderClrType; - return; - } - - ValueGetter = propInfo.GetValue; - ProviderClrType = Nullable.GetUnderlyingType(propInfo.PropertyType) ?? propInfo.PropertyType; - } - - public object GetValue(object entity) => ValueGetter(entity) ?? DBNull.Value; -} diff --git a/tests/PhenX.EntityFrameworkCore.BulkInsert.Benchmark/LibComparatorSqlServer.cs b/tests/PhenX.EntityFrameworkCore.BulkInsert.Benchmark/LibComparatorSqlServer.cs index ad7e843..508fc70 100644 --- a/tests/PhenX.EntityFrameworkCore.BulkInsert.Benchmark/LibComparatorSqlServer.cs +++ b/tests/PhenX.EntityFrameworkCore.BulkInsert.Benchmark/LibComparatorSqlServer.cs @@ -5,7 +5,6 @@ using Microsoft.EntityFrameworkCore; -using PhenX.EntityFrameworkCore.BulkInsert.Sqlite; using PhenX.EntityFrameworkCore.BulkInsert.SqlServer; using Testcontainers.MsSql; From 7485d81aeef0d83ad30bdd1bd293f8ae5185dabc Mon Sep 17 00:00:00 2001 From: Sebastian Stehle Date: Wed, 21 May 2025 23:49:16 +0200 Subject: [PATCH 2/3] Metadata. --- .../PostgreSqlDbContextOptionsExtensions.cs | 1 - .../SqlServerDbContextOptionsExtensions.cs | 1 - .../SqlServerDialectBuilder.cs | 16 ++--- .../SqliteDbContextOptionsExtensions.cs | 1 - .../BulkInsertProviderBase.cs | 71 ++++++++++++------- .../Dialect/SqlDialectBuilder.cs | 17 +++-- 6 files changed, 61 insertions(+), 46 deletions(-) diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert.PostgreSql/PostgreSqlDbContextOptionsExtensions.cs b/src/PhenX.EntityFrameworkCore.BulkInsert.PostgreSql/PostgreSqlDbContextOptionsExtensions.cs index 480d6df..a13d4df 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert.PostgreSql/PostgreSqlDbContextOptionsExtensions.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert.PostgreSql/PostgreSqlDbContextOptionsExtensions.cs @@ -1,5 +1,4 @@ using Microsoft.EntityFrameworkCore; -using Microsoft.EntityFrameworkCore.Infrastructure; using PhenX.EntityFrameworkCore.BulkInsert.Extensions; diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerDbContextOptionsExtensions.cs b/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerDbContextOptionsExtensions.cs index 55e264f..d7222f2 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerDbContextOptionsExtensions.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerDbContextOptionsExtensions.cs @@ -1,5 +1,4 @@ using Microsoft.EntityFrameworkCore; -using Microsoft.EntityFrameworkCore.Infrastructure; using PhenX.EntityFrameworkCore.BulkInsert.Extensions; diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerDialectBuilder.cs b/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerDialectBuilder.cs index 329a47a..90569b2 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerDialectBuilder.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerDialectBuilder.cs @@ -15,8 +15,8 @@ internal class SqlServerDialectBuilder : SqlDialectBuilder protected override bool SupportsMoveRows => false; public override string BuildMoveDataSql( - TableMetadata source, - string target, + TableMetadata target, + string source, IReadOnlyList insertedProperties, IReadOnlyList properties, BulkInsertOptions options, @@ -32,21 +32,21 @@ public override string BuildMoveDataSql( if (options.CopyGeneratedColumns) { - q.AppendLine($"SET IDENTITY_INSERT {target} ON;"); + q.AppendLine($"SET IDENTITY_INSERT {target.QuotedTableName} ON;"); } // Merge handling if (onConflict is OnConflictOptions onConflictTyped && onConflictTyped.Match != null) { - var matchColumns = GetColumns(source, onConflictTyped.Match); + var matchColumns = GetColumns(target, onConflictTyped.Match); var matchOn = string.Join(" AND ", matchColumns.Select(col => $"TARGET.{col} = SOURCE.{col}")); var updateSet = onConflictTyped.Update != null - ? string.Join(", ", GetUpdates(source, insertedProperties, onConflictTyped.Update)) + ? string.Join(", ", GetUpdates(target, insertedProperties, onConflictTyped.Update)) : null; - q.AppendLine($"MERGE INTO {target} AS TARGET"); + q.AppendLine($"MERGE INTO {target.QuotedTableName} AS TARGET"); q.AppendLine( $"USING (SELECT {string.Join(", ", insertedColumns)} FROM {source}) AS SOURCE ({insertedColumnList})"); q.AppendLine($"ON {matchOn}"); @@ -68,7 +68,7 @@ public override string BuildMoveDataSql( // No conflict handling else { - q.AppendLine($"INSERT INTO {target} ({insertedColumnList})"); + q.AppendLine($"INSERT INTO {target.QuotedTableName} ({insertedColumnList})"); if (columnList.Length != 0) { @@ -85,7 +85,7 @@ public override string BuildMoveDataSql( if (options.CopyGeneratedColumns) { - q.AppendLine($"SET IDENTITY_INSERT {target} OFF;"); + q.AppendLine($"SET IDENTITY_INSERT {target.QuotedTableName} OFF;"); } return q.ToString(); diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteDbContextOptionsExtensions.cs b/src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteDbContextOptionsExtensions.cs index 28df65f..87f3b74 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteDbContextOptionsExtensions.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteDbContextOptionsExtensions.cs @@ -1,5 +1,4 @@ using Microsoft.EntityFrameworkCore; -using Microsoft.EntityFrameworkCore.Infrastructure; using PhenX.EntityFrameworkCore.BulkInsert.Extensions; diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs index 9816029..28a6433 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs @@ -29,28 +29,18 @@ internal abstract class BulkInsertProviderBase(ILogger CreateTableCopyAsync( bool sync, DbContext context, + BulkInsertOptions options, TableMetadata tableInfo, CancellationToken cancellationToken = default) where T : class - { - var tempTableName = await CreateTemporaryTableAsync(sync, context, tableInfo, cancellationToken); - - await AddBulkInsertIdColumn(sync, context, tempTableName, cancellationToken); - - return tempTableName; - } - - private async Task CreateTemporaryTableAsync( - bool sync, - DbContext context, - TableMetadata tableInfo, - CancellationToken cancellationToken) { var tempTableName = SqlDialect.QuoteTableName(null, GetTempTableName(tableInfo.TableName)); - var tempColumns = string.Join(", ", tableInfo.GetProperties(false).Select(x => x.QuotedColumName)); + var tempColumns = string.Join(", ", tableInfo.GetProperties(options.CopyGeneratedColumns).Select(x => x.QuotedColumName)); var query = string.Format(CreateTableCopySql, tempTableName, tableInfo.QuotedTableName, tempColumns); await ExecuteAsync(sync, context, query, cancellationToken); + await AddBulkInsertIdColumn(sync, context, tempTableName, cancellationToken); + return tempTableName; } @@ -156,17 +146,43 @@ public virtual async Task> BulkInsertReturnEntities( CancellationToken ctk = default ) where T : class { - var connectionInfo = await context.GetConnection(sync, ctk); + List result; - var (tableName, _) = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk); + var connectionInfo = await context.GetConnection(sync, ctk); + try + { + var (tableName, _) = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk); - var result = await CopyFromTempTableAsync(sync, context, tableInfo, tableName, true, options, onConflict, cancellationToken: ctk); + result = await CopyFromTempTableAsync(sync, context, tableInfo, tableName, true, options, onConflict, cancellationToken: ctk); - await Finish(sync, connectionInfo, ctk); + await Commit(sync, connectionInfo, ctk); + } + finally + { + await Finish(sync, connectionInfo, ctk); + } return result; } + private static async Task Commit(bool sync, ConnectionInfo connectionInfo, CancellationToken ctk) + { + var (_, _, transaction, wasBegan) = connectionInfo; + + if (!wasBegan) + { + if (sync) + { + // ReSharper disable once MethodHasAsyncOverloadWithCancellation + transaction.Commit(); + } + else + { + await transaction.CommitAsync(ctk); + } + } + } + private static async Task Finish(bool sync, ConnectionInfo connectionInfo, CancellationToken ctk) { var (connection, wasClosed, transaction, wasBegan) = connectionInfo; @@ -176,12 +192,10 @@ private static async Task Finish(bool sync, ConnectionInfo connectionInfo, Cance if (sync) { // ReSharper disable once MethodHasAsyncOverloadWithCancellation - transaction.Commit(); transaction.Dispose(); } else { - await transaction.CommitAsync(ctk); await transaction.DisposeAsync(); } } @@ -213,12 +227,17 @@ public virtual async Task BulkInsert( if (onConflict != null) { var connectionInfo = await context.GetConnection(sync, ctk); + try + { + var (tableName, _) = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk); - var (tableName, _) = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk); - - await CopyFromTempTableAsync(sync, context, tableInfo, tableName, false, options, onConflict, ctk); - - await Finish(sync, connectionInfo, ctk); + await CopyFromTempTableAsync(sync, context, tableInfo, tableName, false, options, onConflict, ctk); + await Commit(sync, connectionInfo, ctk); + } + finally + { + await Finish(sync, connectionInfo, ctk); + } } else { @@ -243,7 +262,7 @@ public virtual async Task BulkInsert( var connectionInfo = await context.GetConnection(sync, ctk); var tableName = tempTableRequired - ? await CreateTableCopyAsync(sync, context, tableInfo, ctk) + ? await CreateTableCopyAsync(sync, context, options, tableInfo, ctk) : tableInfo.QuotedTableName; var properties = tableInfo.GetProperties(options.CopyGeneratedColumns); diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Dialect/SqlDialectBuilder.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Dialect/SqlDialectBuilder.cs index 21250c2..72ef576 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/Dialect/SqlDialectBuilder.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Dialect/SqlDialectBuilder.cs @@ -26,8 +26,8 @@ internal abstract class SqlDialectBuilder /// Entity type /// The SQL query public virtual string BuildMoveDataSql( - TableMetadata source, - string target, + TableMetadata target, + string source, IReadOnlyList insertedProperties, IReadOnlyList properties, BulkInsertOptions options, OnConflictOptions? onConflict = null) @@ -40,22 +40,21 @@ public virtual string BuildMoveDataSql( var q = new StringBuilder(); - var sourceName = source.QuotedTableName; if (SupportsMoveRows && options.MoveRows) { q.AppendLine($""" WITH moved_rows AS ( - DELETE FROM {source.QuotedTableName} + DELETE FROM {source} RETURNING {insertedColumnList} ) """); - sourceName = "moved_rows"; + source = "moved_rows"; } q.AppendLine($""" - INSERT INTO {target} ({insertedColumnList}) + INSERT INTO {target.QuotedTableName} ({insertedColumnList}) SELECT {insertedColumnList} - FROM {sourceName} + FROM {source} WHERE TRUE """); @@ -68,13 +67,13 @@ WHERE TRUE if (onConflictTyped.Match != null) { q.Append(' '); - AppendConflictMatch(q, GetColumns(source, onConflictTyped.Match)); + AppendConflictMatch(q, GetColumns(target, onConflictTyped.Match)); } if (onConflictTyped.Update != null) { q.Append(' '); - AppendOnConflictUpdate(q, GetUpdates(source, insertedProperties, onConflictTyped.Update)); + AppendOnConflictUpdate(q, GetUpdates(target, insertedProperties, onConflictTyped.Update)); } if (onConflictTyped.Condition != null) From 656459aaf9d2f66b4d243ebc362ea26767dbc8a3 Mon Sep 17 00:00:00 2001 From: Sebastian Stehle Date: Wed, 21 May 2025 23:53:22 +0200 Subject: [PATCH 3/3] Fix commit. --- .../BulkInsertProviderBase.cs | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs index 28a6433..780d4e2 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs @@ -155,6 +155,7 @@ public virtual async Task> BulkInsertReturnEntities( result = await CopyFromTempTableAsync(sync, context, tableInfo, tableName, true, options, onConflict, cancellationToken: ctk); + // Commit the transaction if we own them. await Commit(sync, connectionInfo, ctk); } finally @@ -232,6 +233,8 @@ public virtual async Task BulkInsert( var (tableName, _) = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk); await CopyFromTempTableAsync(sync, context, tableInfo, tableName, false, options, onConflict, ctk); + + // Commit the transaction if we own them. await Commit(sync, connectionInfo, ctk); } finally @@ -267,9 +270,17 @@ public virtual async Task BulkInsert( var properties = tableInfo.GetProperties(options.CopyGeneratedColumns); - await BulkInsert(false, context, tableInfo, entities, tableName, properties, options, ctk); + try + { + await BulkInsert(false, context, tableInfo, entities, tableName, properties, options, ctk); - await Finish(sync, connectionInfo, ctk); + // Commit the transaction if we own them. + await Commit(sync, connectionInfo, ctk); + } + finally + { + await Finish(sync, connectionInfo, ctk); + } return (tableName, connectionInfo.Connection); }