diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlBulkInsertProvider.cs b/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlBulkInsertProvider.cs index d6eae56..bc72001 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlBulkInsertProvider.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlBulkInsertProvider.cs @@ -4,6 +4,7 @@ using MySqlConnector; +using PhenX.EntityFrameworkCore.BulkInsert.Metadata; using PhenX.EntityFrameworkCore.BulkInsert.Options; namespace PhenX.EntityFrameworkCore.BulkInsert.MySql; @@ -29,6 +30,7 @@ public MySqlBulkInsertProvider(ILogger? logger = null) public override Task> BulkInsertReturnEntities( bool sync, DbContext context, + TableMetadata tableInfo, IEnumerable entities, BulkInsertOptions options, OnConflictOptions? onConflict = null, @@ -41,9 +43,10 @@ public override Task> BulkInsertReturnEntities( protected override async Task BulkInsert( bool sync, DbContext context, + TableMetadata tableInfo, IEnumerable entities, string tableName, - PropertyAccessor[] properties, + IReadOnlyList properties, BulkInsertOptions options, CancellationToken ctk ) diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlDbContextOptionsExtensions.cs b/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlDbContextOptionsExtensions.cs index 9a51ea9..cdc09fb 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlDbContextOptionsExtensions.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlDbContextOptionsExtensions.cs @@ -1,5 +1,6 @@ using Microsoft.EntityFrameworkCore; -using Microsoft.EntityFrameworkCore.Infrastructure; + +using PhenX.EntityFrameworkCore.BulkInsert.Extensions; namespace PhenX.EntityFrameworkCore.BulkInsert.MySql; @@ -13,10 +14,6 @@ public static class MySqlDbContextOptionsExtensions /// public static DbContextOptionsBuilder UseBulkInsertMySql(this DbContextOptionsBuilder optionsBuilder) { - var extension = optionsBuilder.Options.FindExtension>() ?? new BulkInsertOptionsExtension(); - - ((IDbContextOptionsBuilderInfrastructure)optionsBuilder).AddOrUpdateExtension(extension); - - return optionsBuilder; + return optionsBuilder.UseProvider(); } } diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlDialectBuilder.cs b/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlDialectBuilder.cs index b4c5fac..9ada13b 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlDialectBuilder.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlDialectBuilder.cs @@ -1,9 +1,7 @@ using System.Text; -using Microsoft.EntityFrameworkCore; -using Microsoft.EntityFrameworkCore.Metadata; - using PhenX.EntityFrameworkCore.BulkInsert.Dialect; +using PhenX.EntityFrameworkCore.BulkInsert.Metadata; using PhenX.EntityFrameworkCore.BulkInsert.Options; namespace PhenX.EntityFrameworkCore.BulkInsert.MySql; @@ -43,9 +41,9 @@ protected override void AppendOnConflictStatement(StringBuilder sql) sql.Append("ON DUPLICATE KEY"); } - protected override void AppendDoNothing(StringBuilder sql, IProperty[] insertedProperties) + protected override void AppendDoNothing(StringBuilder sql, IEnumerable insertedProperties) { - var columnName = insertedProperties[0].GetColumnName(); + var columnName = insertedProperties.First().ColumnName; sql.Append($"UPDATE {Quote(columnName)} = {GetExcludedColumnName(columnName)}"); } diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert.PostgreSql/PostgreSqlBulkInsertProvider.cs b/src/PhenX.EntityFrameworkCore.BulkInsert.PostgreSql/PostgreSqlBulkInsertProvider.cs index d0d4c24..239b967 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert.PostgreSql/PostgreSqlBulkInsertProvider.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert.PostgreSql/PostgreSqlBulkInsertProvider.cs @@ -5,6 +5,7 @@ using Npgsql; +using PhenX.EntityFrameworkCore.BulkInsert.Metadata; using PhenX.EntityFrameworkCore.BulkInsert.Options; namespace PhenX.EntityFrameworkCore.BulkInsert.PostgreSql; @@ -24,9 +25,9 @@ public PostgreSqlBulkInsertProvider(ILogger? logge /// protected override string AddTableCopyBulkInsertId => $"ALTER TABLE {{0}} ADD COLUMN {BulkInsertId} SERIAL PRIMARY KEY;"; - private string GetBinaryImportCommand(DbContext context, Type entityType, string tableName) + private static string GetBinaryImportCommand(TableMetadata tableInfo, string tableName) { - var columns = GetQuotedColumns(context, entityType, false); + var columns = tableInfo.GetProperties(false).Select(X => X.QuotedColumName); return $"COPY {tableName} ({string.Join(", ", columns)}) FROM STDIN (FORMAT BINARY)"; } @@ -35,15 +36,16 @@ private string GetBinaryImportCommand(DbContext context, Type entityType, string protected override async Task BulkInsert( bool sync, DbContext context, + TableMetadata tableInfo, IEnumerable entities, string tableName, - PropertyAccessor[] properties, + IReadOnlyList properties, BulkInsertOptions options, - CancellationToken ctk) where T : class + CancellationToken ctk) { var connection = (NpgsqlConnection)context.Database.GetDbConnection(); - var importCommand = GetBinaryImportCommand(context, typeof(T), tableName); + var importCommand = GetBinaryImportCommand(tableInfo, tableName); var writer = sync // ReSharper disable once MethodHasAsyncOverloadWithCancellation diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert.PostgreSql/PostgreSqlDbContextOptionsExtensions.cs b/src/PhenX.EntityFrameworkCore.BulkInsert.PostgreSql/PostgreSqlDbContextOptionsExtensions.cs index 288e687..a13d4df 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert.PostgreSql/PostgreSqlDbContextOptionsExtensions.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert.PostgreSql/PostgreSqlDbContextOptionsExtensions.cs @@ -1,5 +1,6 @@ -using Microsoft.EntityFrameworkCore; -using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore; + +using PhenX.EntityFrameworkCore.BulkInsert.Extensions; namespace PhenX.EntityFrameworkCore.BulkInsert.PostgreSql; @@ -13,10 +14,6 @@ public static class PostgreSqlDbContextOptionsExtensions /// public static DbContextOptionsBuilder UseBulkInsertPostgreSql(this DbContextOptionsBuilder optionsBuilder) { - var extension = optionsBuilder.Options.FindExtension>() ?? new BulkInsertOptionsExtension(); - - ((IDbContextOptionsBuilderInfrastructure)optionsBuilder).AddOrUpdateExtension(extension); - - return optionsBuilder; + return optionsBuilder.UseProvider(); } } diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerBulkInsertProvider.cs b/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerBulkInsertProvider.cs index 51425ff..42276d3 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerBulkInsertProvider.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerBulkInsertProvider.cs @@ -5,6 +5,7 @@ using Microsoft.EntityFrameworkCore.Storage; using Microsoft.Extensions.Logging; +using PhenX.EntityFrameworkCore.BulkInsert.Metadata; using PhenX.EntityFrameworkCore.BulkInsert.Options; namespace PhenX.EntityFrameworkCore.BulkInsert.SqlServer; @@ -31,12 +32,12 @@ public SqlServerBulkInsertProvider(ILogger? logger protected override async Task BulkInsert( bool sync, DbContext context, + TableMetadata tableInfo, IEnumerable entities, string tableName, - PropertyAccessor[] properties, + IReadOnlyList properties, BulkInsertOptions options, - CancellationToken ctk - ) + CancellationToken ctk) { var connection = (SqlConnection) context.Database.GetDbConnection(); var sqlTransaction = context.Database.CurrentTransaction!.GetDbTransaction() as SqlTransaction; diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerDbContextOptionsExtensions.cs b/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerDbContextOptionsExtensions.cs index 43e6337..d7222f2 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerDbContextOptionsExtensions.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerDbContextOptionsExtensions.cs @@ -1,5 +1,6 @@ -using Microsoft.EntityFrameworkCore; -using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore; + +using PhenX.EntityFrameworkCore.BulkInsert.Extensions; namespace PhenX.EntityFrameworkCore.BulkInsert.SqlServer; @@ -13,10 +14,6 @@ public static class SqlServerDbContextOptionsExtensions /// public static DbContextOptionsBuilder UseBulkInsertSqlServer(this DbContextOptionsBuilder optionsBuilder) { - var extension = optionsBuilder.Options.FindExtension>() ?? new BulkInsertOptionsExtension(); - - ((IDbContextOptionsBuilderInfrastructure)optionsBuilder).AddOrUpdateExtension(extension); - - return optionsBuilder; + return optionsBuilder.UseProvider(); } } diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerDialectBuilder.cs b/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerDialectBuilder.cs index 4d9f669..90569b2 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerDialectBuilder.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerDialectBuilder.cs @@ -1,10 +1,7 @@ -using System.Linq.Expressions; using System.Text; -using Microsoft.EntityFrameworkCore; -using Microsoft.EntityFrameworkCore.Metadata; - using PhenX.EntityFrameworkCore.BulkInsert.Dialect; +using PhenX.EntityFrameworkCore.BulkInsert.Metadata; using PhenX.EntityFrameworkCore.BulkInsert.Options; namespace PhenX.EntityFrameworkCore.BulkInsert.SqlServer; @@ -17,37 +14,39 @@ internal class SqlServerDialectBuilder : SqlDialectBuilder protected override bool SupportsMoveRows => false; - public override string BuildMoveDataSql(DbContext context, string source, - string target, - IProperty[] insertedProperties, - IProperty[] properties, - BulkInsertOptions options, OnConflictOptions? onConflict = null) + public override string BuildMoveDataSql( + TableMetadata target, + string source, + IReadOnlyList insertedProperties, + IReadOnlyList properties, + BulkInsertOptions options, + OnConflictOptions? onConflict = null) { - var insertedColumns = insertedProperties.Select(p => Quote(p.GetColumnName())).ToArray(); + var insertedColumns = insertedProperties.Select(x => x.QuotedColumName); var insertedColumnList = string.Join(", ", insertedColumns); - var returnedColumns = properties.Select(p => $"INSERTED.{p.GetColumnName()} AS {p.GetColumnName()}"); + var returnedColumns = properties.Select(p => $"INSERTED.{p.ColumnName} AS {p.ColumnName}"); var columnList = string.Join(", ", returnedColumns); var q = new StringBuilder(); if (options.CopyGeneratedColumns) { - q.AppendLine($"SET IDENTITY_INSERT {target} ON;"); + q.AppendLine($"SET IDENTITY_INSERT {target.QuotedTableName} ON;"); } // Merge handling if (onConflict is OnConflictOptions onConflictTyped && onConflictTyped.Match != null) { - var matchColumns = GetColumns(context, onConflictTyped.Match); + var matchColumns = GetColumns(target, onConflictTyped.Match); var matchOn = string.Join(" AND ", matchColumns.Select(col => $"TARGET.{col} = SOURCE.{col}")); var updateSet = onConflictTyped.Update != null - ? string.Join(", ", GetUpdates(context, insertedProperties, onConflictTyped.Update)) + ? string.Join(", ", GetUpdates(target, insertedProperties, onConflictTyped.Update)) : null; - q.AppendLine($"MERGE INTO {target} AS TARGET"); + q.AppendLine($"MERGE INTO {target.QuotedTableName} AS TARGET"); q.AppendLine( $"USING (SELECT {string.Join(", ", insertedColumns)} FROM {source}) AS SOURCE ({insertedColumnList})"); q.AppendLine($"ON {matchOn}"); @@ -69,7 +68,7 @@ public override string BuildMoveDataSql(DbContext context, string source, // No conflict handling else { - q.AppendLine($"INSERT INTO {target} ({insertedColumnList})"); + q.AppendLine($"INSERT INTO {target.QuotedTableName} ({insertedColumnList})"); if (columnList.Length != 0) { @@ -86,7 +85,7 @@ public override string BuildMoveDataSql(DbContext context, string source, if (options.CopyGeneratedColumns) { - q.AppendLine($"SET IDENTITY_INSERT {target} OFF;"); + q.AppendLine($"SET IDENTITY_INSERT {target.QuotedTableName} OFF;"); } return q.ToString(); diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteBulkInsertProvider.cs b/src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteBulkInsertProvider.cs index 0eea33f..808a8d8 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteBulkInsertProvider.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteBulkInsertProvider.cs @@ -6,7 +6,7 @@ using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.Logging; -using PhenX.EntityFrameworkCore.BulkInsert.Extensions; +using PhenX.EntityFrameworkCore.BulkInsert.Metadata; using PhenX.EntityFrameworkCore.BulkInsert.Options; namespace PhenX.EntityFrameworkCore.BulkInsert.Sqlite; @@ -81,18 +81,18 @@ private static SqliteType GetSqliteType(Type clrType) throw new InvalidOperationException("Unknown Sqlite type for " + clrType); } - private DbCommand GetInsertCommand(DbContext context, Type entityType, string tableName, + private DbCommand GetInsertCommand(DbContext context, TableMetadata tableInfo, string tableName, BulkInsertOptions options, int batchSize) { - var columns = context.GetProperties(entityType, options.CopyGeneratedColumns); + var columns = tableInfo.GetProperties(options.CopyGeneratedColumns); var cmd = context.Database.GetDbConnection().CreateCommand(); var sqliteColumns = columns .Select(c => new { - Name = c.GetColumnName(), - Type = GetSqliteType(c.GetProviderClrType() ?? c.ClrType) + Name = c.ColumnName, + Type = GetSqliteType(c.ProviderClrType ?? c.ClrType) }) .ToArray(); @@ -126,18 +126,19 @@ private DbCommand GetInsertCommand(DbContext context, Type entityType, string ta protected override async Task BulkInsert( bool sync, DbContext context, + TableMetadata tableInfo, IEnumerable entities, string tableName, - PropertyAccessor[] properties, + IReadOnlyList properties, BulkInsertOptions options, CancellationToken ctk ) where T : class { const int maxParams = 1000; var batchSize = options.BatchSize ?? 5; - batchSize = Math.Min(batchSize, maxParams / properties.Length); + batchSize = Math.Min(batchSize, maxParams / properties.Count); - await using var insertCommand = GetInsertCommand(context, typeof(T), tableName, options, batchSize); + await using var insertCommand = GetInsertCommand(context, tableInfo, tableName, options, batchSize); foreach (var chunk in entities.Chunk(batchSize)) { @@ -150,7 +151,7 @@ CancellationToken ctk // Last chunk else { - var partialInsertCommand = GetInsertCommand(context, typeof(T), tableName, options, chunk.Length); + var partialInsertCommand = GetInsertCommand(context, tableInfo, tableName, options, chunk.Length); FillValues(chunk, partialInsertCommand.Parameters, properties); await ExecuteCommand(sync, partialInsertCommand, ctk); @@ -171,7 +172,7 @@ private static async Task ExecuteCommand(bool sync, DbCommand insertCommand, Can } } - private static void FillValues(T[] chunk, DbParameterCollection parameters, PropertyAccessor[] properties) where T : class + private static void FillValues(T[] chunk, DbParameterCollection parameters, IReadOnlyList properties) where T : class { var index = 0; foreach (var entity in chunk) diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteDbContextOptionsExtensions.cs b/src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteDbContextOptionsExtensions.cs index 5550e47..87f3b74 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteDbContextOptionsExtensions.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteDbContextOptionsExtensions.cs @@ -1,5 +1,6 @@ using Microsoft.EntityFrameworkCore; -using Microsoft.EntityFrameworkCore.Infrastructure; + +using PhenX.EntityFrameworkCore.BulkInsert.Extensions; namespace PhenX.EntityFrameworkCore.BulkInsert.Sqlite; @@ -13,9 +14,7 @@ public static class SqliteDbContextOptionsExtensions /// public static DbContextOptionsBuilder UseBulkInsertSqlite(this DbContextOptionsBuilder optionsBuilder) { - var extension = optionsBuilder.Options.FindExtension>() ?? new BulkInsertOptionsExtension(); - ((IDbContextOptionsBuilderInfrastructure)optionsBuilder).AddOrUpdateExtension(extension); - return optionsBuilder; + return optionsBuilder.UseProvider(); } } diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Abstractions/IBulkInsertProvider.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Abstractions/IBulkInsertProvider.cs index 344cebd..2430fe3 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/Abstractions/IBulkInsertProvider.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Abstractions/IBulkInsertProvider.cs @@ -1,5 +1,7 @@ using Microsoft.EntityFrameworkCore; +using PhenX.EntityFrameworkCore.BulkInsert.Dialect; +using PhenX.EntityFrameworkCore.BulkInsert.Metadata; using PhenX.EntityFrameworkCore.BulkInsert.Options; namespace PhenX.EntityFrameworkCore.BulkInsert.Abstractions; @@ -15,6 +17,7 @@ internal interface IBulkInsertProvider internal Task> BulkInsertReturnEntities( bool sync, DbContext context, + TableMetadata tableInfo, IEnumerable entities, BulkInsertOptions options, OnConflictOptions? onConflict = null, @@ -27,9 +30,12 @@ internal Task> BulkInsertReturnEntities( internal Task BulkInsert( bool sync, DbContext context, + TableMetadata tableInfo, IEnumerable entities, BulkInsertOptions options, OnConflictOptions? onConflict = null, CancellationToken ctk = default ) where T : class; + + SqlDialectBuilder SqlDialect { get; } } diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs index 119e004..780d4e2 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs @@ -1,48 +1,44 @@ using System.Data.Common; using Microsoft.EntityFrameworkCore; -using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Storage; using Microsoft.Extensions.Logging; using PhenX.EntityFrameworkCore.BulkInsert.Abstractions; using PhenX.EntityFrameworkCore.BulkInsert.Dialect; using PhenX.EntityFrameworkCore.BulkInsert.Extensions; +using PhenX.EntityFrameworkCore.BulkInsert.Metadata; using PhenX.EntityFrameworkCore.BulkInsert.Options; namespace PhenX.EntityFrameworkCore.BulkInsert; -internal abstract class BulkInsertProviderBase : IBulkInsertProvider +#pragma warning disable CS9113 // Parameter is unread. +internal abstract class BulkInsertProviderBase(ILogger>? logger = null) : IBulkInsertProvider +#pragma warning restore CS9113 // Parameter is unread. where TDialect : SqlDialectBuilder, new() { protected readonly TDialect SqlDialect = new(); - private readonly ILogger>? Logger; protected virtual string BulkInsertId => "_bulk_insert_id"; protected abstract string CreateTableCopySql { get; } protected abstract string AddTableCopyBulkInsertId { get; } - protected BulkInsertProviderBase(ILogger>? logger = null) - { - Logger = logger; - } + SqlDialectBuilder IBulkInsertProvider.SqlDialect => SqlDialect; protected async Task CreateTableCopyAsync( bool sync, DbContext context, BulkInsertOptions options, + TableMetadata tableInfo, CancellationToken cancellationToken = default) where T : class { - var tableInfo = GetTableInfo(context, typeof(T)); - var tableName = QuoteTableName(tableInfo.SchemaName, tableInfo.TableName); - var tempTableName = QuoteTableName(null, GetTempTableName(tableInfo.TableName)); + var tempTableName = SqlDialect.QuoteTableName(null, GetTempTableName(tableInfo.TableName)); + var tempColumns = string.Join(", ", tableInfo.GetProperties(options.CopyGeneratedColumns).Select(x => x.QuotedColumName)); - var keptColumns = string.Join(", ", GetQuotedColumns(context, typeof(T), options.CopyGeneratedColumns)); - var query = string.Format(CreateTableCopySql, tempTableName, tableName, keptColumns); + var query = string.Format(CreateTableCopySql, tempTableName, tableInfo.QuotedTableName, tempColumns); await ExecuteAsync(sync, context, query, cancellationToken); - await AddBulkInsertIdColumn(sync, context, tempTableName, cancellationToken); return tempTableName; @@ -80,6 +76,7 @@ protected static async Task ExecuteAsync(bool sync, DbContext context, string qu public async Task> CopyFromTempTableAsync( bool sync, DbContext context, + TableMetadata tableInfo, string tempTableName, bool returnData, BulkInsertOptions options, @@ -89,6 +86,7 @@ public async Task> CopyFromTempTableAsync( return await CopyFromTempTableWithoutKeysAsync( sync, context, + tableInfo, tempTableName, returnData, options, @@ -99,6 +97,7 @@ public async Task> CopyFromTempTableAsync( private async Task> CopyFromTempTableWithoutKeysAsync( bool sync, DbContext context, + TableMetadata tableInfo, string tempTableName, bool returnData, BulkInsertOptions options, @@ -107,12 +106,10 @@ private async Task> CopyFromTempTableWithoutKeysAsync( where T : class where TResult : class { - var (schemaName, tableName, _) = GetTableInfo(context, typeof(T)); - var quotedTableName = QuoteTableName(schemaName, tableName); - var movedProperties = context.GetProperties(typeof(T), options.CopyGeneratedColumns); - var returnedProperties = returnData ? context.GetProperties(typeof(T)) : []; + var movedProperties = tableInfo.GetProperties(options.CopyGeneratedColumns); + var returnedProperties = returnData ? tableInfo.GetProperties() : []; - var query = SqlDialect.BuildMoveDataSql(context, tempTableName, quotedTableName, movedProperties, returnedProperties, options, onConflict); + var query = SqlDialect.BuildMoveDataSql(tableInfo, tempTableName, movedProperties, returnedProperties, options, onConflict); if (returnData) { @@ -142,37 +139,64 @@ static async Task> QueryAsync(bool sync, DbContext context, string public virtual async Task> BulkInsertReturnEntities( bool sync, DbContext context, + TableMetadata tableInfo, IEnumerable entities, BulkInsertOptions options, OnConflictOptions? onConflict = null, CancellationToken ctk = default ) where T : class { - var (connection, wasClosed, transaction, wasBegan) = await context.GetConnection(sync, ctk); + List result; - var (tableName, _) = await PerformBulkInsertAsync(sync, context, entities, options, tempTableRequired: true, ctk: ctk); + var connectionInfo = await context.GetConnection(sync, ctk); + try + { + var (tableName, _) = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk); - var result = await CopyFromTempTableAsync(sync, context, tableName, true, options, onConflict, cancellationToken: ctk); + result = await CopyFromTempTableAsync(sync, context, tableInfo, tableName, true, options, onConflict, cancellationToken: ctk); - await Finish(sync, connection, wasClosed, transaction, wasBegan, ctk); + // Commit the transaction if we own them. + await Commit(sync, connectionInfo, ctk); + } + finally + { + await Finish(sync, connectionInfo, ctk); + } return result; } - private static async Task Finish(bool sync, DbConnection connection, bool wasClosed, - IDbContextTransaction transaction, bool wasBegan, CancellationToken ctk) + private static async Task Commit(bool sync, ConnectionInfo connectionInfo, CancellationToken ctk) { + var (_, _, transaction, wasBegan) = connectionInfo; + if (!wasBegan) { if (sync) { // ReSharper disable once MethodHasAsyncOverloadWithCancellation transaction.Commit(); - transaction.Dispose(); } else { await transaction.CommitAsync(ctk); + } + } + } + + private static async Task Finish(bool sync, ConnectionInfo connectionInfo, CancellationToken ctk) + { + var (connection, wasClosed, transaction, wasBegan) = connectionInfo; + + if (!wasBegan) + { + if (sync) + { + // ReSharper disable once MethodHasAsyncOverloadWithCancellation + transaction.Dispose(); + } + else + { await transaction.DisposeAsync(); } } @@ -194,6 +218,7 @@ private static async Task Finish(bool sync, DbConnection connection, bool wasClo public virtual async Task BulkInsert( bool sync, DbContext context, + TableMetadata tableInfo, IEnumerable entities, BulkInsertOptions options, OnConflictOptions? onConflict = null, @@ -202,23 +227,31 @@ public virtual async Task BulkInsert( { if (onConflict != null) { - var (connection, wasClosed, transaction, wasBegan) = await context.GetConnection(sync, ctk); - - var (tableName, _) = await PerformBulkInsertAsync(sync, context, entities, options, tempTableRequired: true, ctk: ctk); + var connectionInfo = await context.GetConnection(sync, ctk); + try + { + var (tableName, _) = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk); - await CopyFromTempTableAsync(sync, context, tableName, false, options, onConflict, ctk); + await CopyFromTempTableAsync(sync, context, tableInfo, tableName, false, options, onConflict, ctk); - await Finish(sync, connection, wasClosed, transaction, wasBegan, ctk); + // Commit the transaction if we own them. + await Commit(sync, connectionInfo, ctk); + } + finally + { + await Finish(sync, connectionInfo, ctk); + } } else { - await PerformBulkInsertAsync(sync, context, entities, options, tempTableRequired: false, ctk: ctk); + await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: false, ctk: ctk); } } private async Task<(string TableName, DbConnection Connection)> PerformBulkInsertAsync( bool sync, DbContext context, + TableMetadata tableInfo, IEnumerable entities, BulkInsertOptions options, bool tempTableRequired, @@ -229,22 +262,27 @@ public virtual async Task BulkInsert( throw new InvalidOperationException("No entities to insert."); } - var (connection, wasClosed, transaction, wasBegan) = await context.GetConnection(sync, ctk); + var connectionInfo = await context.GetConnection(sync, ctk); var tableName = tempTableRequired - ? await CreateTableCopyAsync(sync, context, options, ctk) - : GetQuotedTableName(context, typeof(T)); + ? await CreateTableCopyAsync(sync, context, options, tableInfo, ctk) + : tableInfo.QuotedTableName; - var properties = context - .GetProperties(typeof(T), options.CopyGeneratedColumns) - .Select(p => new PropertyAccessor(p)) - .ToArray(); + var properties = tableInfo.GetProperties(options.CopyGeneratedColumns); - await BulkInsert(false, context, entities, tableName, properties, options, ctk); + try + { + await BulkInsert(false, context, tableInfo, entities, tableName, properties, options, ctk); - await Finish(sync, connection, wasClosed, transaction, wasBegan, ctk); + // Commit the transaction if we own them. + await Commit(sync, connectionInfo, ctk); + } + finally + { + await Finish(sync, connectionInfo, ctk); + } - return (tableName, connection); + return (tableName, connectionInfo.Connection); } /// @@ -253,43 +291,11 @@ public virtual async Task BulkInsert( protected abstract Task BulkInsert( bool sync, DbContext context, + TableMetadata tableInfo, IEnumerable entities, string tableName, - PropertyAccessor[] properties, + IReadOnlyList properties, BulkInsertOptions options, CancellationToken ctk ) where T : class; - - /// - /// Get table information for the given entity type : schema name, table name and primary key. - /// - public static (string? SchemaName, string TableName, IKey PrimaryKey) GetTableInfo(DbContext context, Type entityType) - { - var entityTypeInfo = context.Model.FindEntityType(entityType); - var schema = (entityTypeInfo ?? throw new InvalidOperationException($"Could not determine entity type for type {entityType.Name}")).GetSchema(); - var tableName = entityTypeInfo.GetTableName(); - - if (string.IsNullOrWhiteSpace(tableName)) - { - throw new InvalidOperationException($"Could not determine table name for type {entityType.Name}"); - } - - return (schema, tableName, entityTypeInfo.FindPrimaryKey()!); - } - - protected string GetQuotedTableName(DbContext context, Type entityType) - { - var (schema, tableName, _) = GetTableInfo(context, entityType); - - return QuoteTableName(schema, tableName); - } - - protected string QuoteTableName(string? schema, string table) => SqlDialect.QuoteTableName(schema, table); - - protected string[] GetQuotedColumns(DbContext context, Type entityType, bool includeGenerated = true) - { - return context.GetProperties(entityType, includeGenerated) - .Select(p => Quote(p.GetColumnName())) - .ToArray(); - } } diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Dialect/SqlDialectBuilder.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Dialect/SqlDialectBuilder.cs index 9da4b66..72ef576 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/Dialect/SqlDialectBuilder.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Dialect/SqlDialectBuilder.cs @@ -1,10 +1,7 @@ using System.Linq.Expressions; using System.Text; -using Microsoft.EntityFrameworkCore; -using Microsoft.EntityFrameworkCore.Metadata; -using Microsoft.EntityFrameworkCore.Metadata.Internal; - +using PhenX.EntityFrameworkCore.BulkInsert.Metadata; using PhenX.EntityFrameworkCore.BulkInsert.Options; namespace PhenX.EntityFrameworkCore.BulkInsert.Dialect; @@ -17,36 +14,10 @@ internal abstract class SqlDialectBuilder protected virtual string ConcatOperator => "||"; protected virtual bool SupportsMoveRows => true; - /// - /// Gets the name of the column for a property in a given entity type. - /// - /// The DbContext - /// The property name - /// The entity type - /// The column name - /// Thrown when the entity type or property is not found. - protected string GetColumnName(DbContext context, string propName) - { - var entityType = context.Model.FindEntityType(typeof(TEntity)); - if (entityType == null) - { - throw new InvalidOperationException($"Entity type {typeof(TEntity).Name} not found in the model."); - } - - var property = entityType.FindProperty(propName); - if (property == null) - { - throw new InvalidOperationException($"Property {propName} not found in entity type {typeof(TEntity).Name}."); - } - - return Quote(property.GetColumnName()); - } - /// /// Builds the SQL for moving data from one table to another. /// - /// The DbContext - /// Source table name + /// Source table /// Target table name /// Properties to be copied /// Properties to be returned @@ -54,16 +25,17 @@ protected string GetColumnName(DbContext context, string propName) /// On conflict options /// Entity type /// The SQL query - public virtual string BuildMoveDataSql(DbContext context, string source, - string target, - IProperty[] insertedProperties, - IProperty[] properties, + public virtual string BuildMoveDataSql( + TableMetadata target, + string source, + IReadOnlyList insertedProperties, + IReadOnlyList properties, BulkInsertOptions options, OnConflictOptions? onConflict = null) { - var insertedColumns = insertedProperties.Select(p => Quote(p.GetColumnName())); + var insertedColumns = insertedProperties.Select(p => p.QuotedColumName); var insertedColumnList = string.Join(", ", insertedColumns); - var returnedColumns = properties.Select(p => Quote(p.GetColumnName())); + var returnedColumns = properties.Select(p => p.QuotedColumName); var columnList = string.Join(", ", returnedColumns); var q = new StringBuilder(); @@ -80,7 +52,7 @@ DELETE FROM {source} } q.AppendLine($""" - INSERT INTO {target} ({insertedColumnList}) + INSERT INTO {target.QuotedTableName} ({insertedColumnList}) SELECT {insertedColumnList} FROM {source} WHERE TRUE @@ -95,13 +67,13 @@ WHERE TRUE if (onConflictTyped.Match != null) { q.Append(' '); - AppendConflictMatch(q, GetColumns(context, onConflictTyped.Match)); + AppendConflictMatch(q, GetColumns(target, onConflictTyped.Match)); } if (onConflictTyped.Update != null) { q.Append(' '); - AppendOnConflictUpdate(q, GetUpdates(context, insertedProperties, onConflictTyped.Update)); + AppendOnConflictUpdate(q, GetUpdates(target, insertedProperties, onConflictTyped.Update)); } if (onConflictTyped.Condition != null) @@ -127,7 +99,7 @@ WHERE TRUE return q.ToString(); } - protected virtual void AppendDoNothing(StringBuilder sql, IProperty[] insertedProperties) + protected virtual void AppendDoNothing(StringBuilder sql, IEnumerable insertedProperties) { sql.AppendLine("DO NOTHING"); } @@ -183,7 +155,7 @@ protected virtual void AppendConflictCondition(StringBuilder sql, OnConflictO /// protected virtual string GetExcludedColumnName(string columnName) { - return $"EXCLUDED.{columnName}"; + return $"EXCLUDED.{Quote(columnName)}"; } /// @@ -204,15 +176,15 @@ public string QuoteTableName(string? schema, string tableName) /// /// Gets column names for the insert statement, from an object initializer. /// - protected string[] GetColumns(DbContext context, Expression> columns) + protected string[] GetColumns(TableMetadata table, Expression> columns) { return columns.Body switch { NewExpression newExpression => newExpression.Arguments.OfType() - .Select(m => GetColumnName(context, m.Member.Name)) + .Select(m => table.GetQuotedColumnName(m.Member.Name)) .ToArray(), MemberExpression memberExpression => [ - GetColumnName(context, memberExpression.Member.Name) + table.GetQuotedColumnName(memberExpression.Member.Name) ], _ => throw new NotSupportedException("Unsupported expression type") }; @@ -229,7 +201,7 @@ protected string[] GetColumns(DbContext context, Expression> /// var updates = GetUpdates(context, e => e.Prop1); /// /// - protected IEnumerable GetUpdates(DbContext context, IProperty[] properties, Expression> update) + protected IEnumerable GetUpdates(TableMetadata table, IEnumerable properties, Expression> update) { switch (update.Body) { @@ -237,7 +209,7 @@ protected IEnumerable GetUpdates(DbContext context, IProperty[] prope { foreach (var arg in newExpr.Arguments.Zip(newExpr.Members, (expr, member) => (expr, member))) { - yield return $"{GetColumnName(context, arg.member.Name)} = {ToSqlExpression(context, arg.expr)}"; + yield return $"{table.GetColumnName(arg.member.Name)} = {ToSqlExpression(table, arg.expr)}"; } break; @@ -246,20 +218,18 @@ protected IEnumerable GetUpdates(DbContext context, IProperty[] prope { foreach (var binding in memberInit.Bindings.OfType()) { - yield return $"{GetColumnName(context, binding.Member.Name)} = {ToSqlExpression(context, binding.Expression)}"; + yield return $"{table.GetColumnName(binding.Member.Name)} = {ToSqlExpression(table, binding.Expression)}"; } break; } case MemberExpression memberExpr: - yield return $"{GetColumnName(context, memberExpr.Member.Name)} = {ToSqlExpression(context, memberExpr)}"; + yield return $"{table.GetColumnName(memberExpr.Member.Name)} = {ToSqlExpression(table, memberExpr)}"; break; case ParameterExpression parameterExpr when (parameterExpr.Type == typeof(T)): foreach (var property in properties) { - var columName = property.GetColumnName(); - - yield return $"{Quote(columName)} = {GetExcludedColumnName(columName)}"; + yield return $"{property.QuotedColumName} = {GetExcludedColumnName(property.ColumnName)}"; } break; @@ -272,21 +242,21 @@ protected IEnumerable GetUpdates(DbContext context, IProperty[] prope /// /// Converts an expression to an SQL string. /// - /// The DbContext + /// The DbContext /// The expression, with simple operations /// Entity type /// An SQL statement /// Thrown when an expression could not be translated. - private string ToSqlExpression(DbContext context, Expression expr) + private string ToSqlExpression(TableMetadata table, Expression expr) { switch (expr) { case MemberExpression m: - return GetExcludedColumnName(GetColumnName(context, m.Member.Name)); + return GetExcludedColumnName(table.GetColumnName(m.Member.Name)); case BinaryExpression b: - var left = ToSqlExpression(context, b.Left); - var right = ToSqlExpression(context, b.Right); + var left = ToSqlExpression(table, b.Left); + var right = ToSqlExpression(table, b.Right); var op = b.NodeType switch { ExpressionType.Add => b.Type == typeof(string) ? ConcatOperator : "+", @@ -328,18 +298,18 @@ private string ToSqlExpression(DbContext context, Expression expr) case UnaryExpression u: if (u.NodeType == ExpressionType.Convert) { - return ToSqlExpression(context, u.Operand); + return ToSqlExpression(table, u.Operand); } if (u.NodeType == ExpressionType.Not) { - return $"NOT ({ToSqlExpression(context, u.Operand)})"; + return $"NOT ({ToSqlExpression(table, u.Operand)})"; } throw new NotSupportedException($"Unary operator not supported: {u.NodeType}"); case MethodCallExpression mce: // Supporte quelques méthodes courantes (ToLower, ToUpper, Trim, etc.) - var objSql = mce.Object != null ? ToSqlExpression(context, mce.Object) : null; - var argsSql = mce.Arguments.Select(expr1 => ToSqlExpression(context, expr1)).ToArray(); + var objSql = mce.Object != null ? ToSqlExpression(table, mce.Object) : null; + var argsSql = mce.Arguments.Select(expr1 => ToSqlExpression(table, expr1)).ToArray(); switch (mce.Method.Name) { case "ToLower": diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/EnumerableDataReader.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/EnumerableDataReader.cs index 00f8241..ec07fe0 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/EnumerableDataReader.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/EnumerableDataReader.cs @@ -1,14 +1,16 @@ using System.Data; +using PhenX.EntityFrameworkCore.BulkInsert.Metadata; + namespace PhenX.EntityFrameworkCore.BulkInsert; internal class EnumerableDataReader : IDataReader { private readonly IEnumerator _enumerator; - private readonly PropertyAccessor[] _properties; + private readonly IReadOnlyList _properties; private readonly Dictionary _ordinalMap; - public EnumerableDataReader(IEnumerable rows, PropertyAccessor[] properties) + public EnumerableDataReader(IEnumerable rows, IReadOnlyList properties) { _enumerator = rows.GetEnumerator(); _properties = properties; @@ -32,7 +34,7 @@ public virtual object GetValue(int i) return DBNull.Value; } - return _properties[i].GetValue(current); + return _properties[i].GetValue(current)!; } public int GetValues(object[] values) @@ -43,18 +45,18 @@ public int GetValues(object[] values) return 0; } - for (var i = 0; i < _properties.Length; i++) + for (var i = 0; i < _properties.Count; i++) { - values[i] = _properties[i].GetValue(current); + values[i] = _properties[i].GetValue(current)!; } - return _properties.Length; + return _properties.Count; } public bool Read() => _enumerator.MoveNext(); - public int FieldCount => _properties.Length; - public Type GetFieldType(int i) => _properties[i].ProviderClrType; + public int FieldCount => _properties.Count; + public Type GetFieldType(int i) => _properties[i].ClrType; public int GetOrdinal(string name) => _ordinalMap.GetValueOrDefault(name, -1); diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/ConnectionInfo.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/ConnectionInfo.cs new file mode 100644 index 0000000..2d8df8b --- /dev/null +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/ConnectionInfo.cs @@ -0,0 +1,7 @@ +using System.Data.Common; + +using Microsoft.EntityFrameworkCore.Storage; + +namespace PhenX.EntityFrameworkCore.BulkInsert.Extensions; + +internal readonly record struct ConnectionInfo(DbConnection Connection, bool WasClosed, IDbContextTransaction Transaction, bool WasBegan); diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbContextExtensions.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbContextExtensions.cs index 666b120..2b4bd16 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbContextExtensions.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbContextExtensions.cs @@ -1,28 +1,36 @@ using System.Data; -using System.Data.Common; using Microsoft.EntityFrameworkCore; -using Microsoft.EntityFrameworkCore.Metadata; -using Microsoft.EntityFrameworkCore.Storage; +using Microsoft.EntityFrameworkCore.Infrastructure; + +using PhenX.EntityFrameworkCore.BulkInsert.Abstractions; +using PhenX.EntityFrameworkCore.BulkInsert.Metadata; namespace PhenX.EntityFrameworkCore.BulkInsert.Extensions; internal static class DbContextExtensions { - /// - /// Gets cached properties for an entity type, using reflection if not already cached. - /// - internal static IProperty[] GetProperties(this DbContext context, Type entityType, bool includeGenerated = true) + public static TableMetadata GetTableInfo(this DbContext context) { - var entityTypeInfo = context.Model.FindEntityType(entityType) ?? throw new InvalidOperationException($"Could not determine entity type for type {entityType.Name}"); + var provider = context.GetService(); + + return provider.GetTableInfo(context); + } + + public static DbContextOptionsBuilder UseProvider(this DbContextOptionsBuilder optionsBuilder) + where TProvider : class, IBulkInsertProvider + { + var extension = optionsBuilder.Options.FindExtension>() ?? new BulkInsertOptionsExtension(); + + ((IDbContextOptionsBuilderInfrastructure)optionsBuilder).AddOrUpdateExtension(extension); + + ((IDbContextOptionsBuilderInfrastructure)optionsBuilder).AddOrUpdateExtension( + optionsBuilder.Options.FindExtension() ?? new()); - return entityTypeInfo - .GetProperties() - .Where(p => !p.IsShadowProperty() && (includeGenerated || p.ValueGenerated != ValueGenerated.OnAdd)) - .ToArray(); + return optionsBuilder; } - internal static async Task<(DbConnection connection, bool wasClosed, IDbContextTransaction transaction, bool wasBegan)> GetConnection( + internal static async Task GetConnection( this DbContext context, bool sync, CancellationToken ctk = default) { var connection = context.Database.GetDbConnection(); @@ -59,6 +67,6 @@ internal static IProperty[] GetProperties(this DbContext context, Type entityTyp } } - return (connection, wasClosed, transaction, wasBegan); + return new ConnectionInfo(connection, wasClosed, transaction, wasBegan); } } diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbSetExtensions.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbSetExtensions.cs index 4d76cee..570c085 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbSetExtensions.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbSetExtensions.cs @@ -1,4 +1,4 @@ -using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore.Infrastructure; using PhenX.EntityFrameworkCore.BulkInsert.Abstractions; @@ -23,8 +23,9 @@ public static async Task> ExecuteBulkInsertReturnEntitiesAsync( ) where T : class { var provider = InitProvider(dbSet, configure, out var context, out var options); + var tableInfo = dbSet.GetDbContext().GetTableInfo(); - return await provider.BulkInsertReturnEntities(false, context, entities, options, onConflict, ctk); + return await provider.BulkInsertReturnEntities(false, context, tableInfo, entities, options, onConflict, ctk); } /// @@ -53,8 +54,9 @@ public static async Task ExecuteBulkInsertAsync( ) where T : class { var provider = InitProvider(dbSet, configure, out var context, out var options); + var tableInfo = dbSet.GetDbContext().GetTableInfo(); - await provider.BulkInsert(false, context, entities, options, onConflict, ctk); + await provider.BulkInsert(false, context, tableInfo, entities, options, onConflict, ctk); } /// @@ -82,8 +84,9 @@ public static List ExecuteBulkInsertReturnEntities( ) where T : class { var provider = InitProvider(dbSet, configure, out var context, out var options); + var tableInfo = dbSet.GetDbContext().GetTableInfo(); - return provider.BulkInsertReturnEntities(true, context, entities, options, onConflict).GetAwaiter().GetResult(); + return provider.BulkInsertReturnEntities(true, context, tableInfo, entities, options, onConflict).GetAwaiter().GetResult(); } /// @@ -116,8 +119,9 @@ public static void ExecuteBulkInsert( ) where T : class { var provider = InitProvider(dbSet, configure, out var context, out var options); + var tableInfo = dbSet.GetDbContext().GetTableInfo(); - provider.BulkInsert(true, context, entities, options, onConflict).GetAwaiter().GetResult(); + provider.BulkInsert(true, context, tableInfo, entities, options, onConflict).GetAwaiter().GetResult(); } /// diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/MetadataProvider.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/MetadataProvider.cs new file mode 100644 index 0000000..cfc94be --- /dev/null +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/MetadataProvider.cs @@ -0,0 +1,48 @@ +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; + +using PhenX.EntityFrameworkCore.BulkInsert.Abstractions; + +namespace PhenX.EntityFrameworkCore.BulkInsert.Metadata; + +internal sealed class MetadataProvider +{ + private Dictionary? _tables; + + public TableMetadata GetTableInfo(DbContext context) + { + var tables = GetTables(context); + + if (!tables.TryGetValue(typeof(T), out var table)) + { + throw new InvalidOperationException($"Cannot find metadata for type '{typeof(T)}'."); + } + + return table; + } + + private Dictionary GetTables(DbContext context) + { + if (_tables != null) + { + return _tables; + } + + lock (this) + { + if (_tables != null) + { + return _tables; + } + + var provider = context.GetService(); + + _tables = + context.Model.GetEntityTypes() + .ToDictionary( + x => x.ClrType, + x => new TableMetadata(x, provider.SqlDialect)); + return _tables; + } + } +} diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/MetadataProviderExtension.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/MetadataProviderExtension.cs new file mode 100644 index 0000000..3fe2a20 --- /dev/null +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/MetadataProviderExtension.cs @@ -0,0 +1,42 @@ +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.Extensions.DependencyInjection; + +namespace PhenX.EntityFrameworkCore.BulkInsert.Metadata; + +internal class MetadataProviderExtension : IDbContextOptionsExtension +{ + public DbContextOptionsExtensionInfo Info + => new MetadataProviderExtensionInfo(this); + + public void ApplyServices(IServiceCollection services) + { + services.AddSingleton(); + } + + public void Validate(IDbContextOptions options) + { + } + + private class MetadataProviderExtensionInfo : DbContextOptionsExtensionInfo + { + public MetadataProviderExtensionInfo(IDbContextOptionsExtension extension) + : base(extension) { } + + /// + public override int GetServiceProviderHashCode() => 0; + + /// + public override bool ShouldUseSameServiceProvider(DbContextOptionsExtensionInfo other) => true; + + /// + public override bool IsDatabaseProvider => false; + + /// + public override string LogFragment => "MetadataProviderExtension"; + + /// + public override void PopulateDebugInfo(IDictionary debugInfo) + { + } + } +} diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/PropertyAccessor.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/PropertyAccessor.cs new file mode 100644 index 0000000..b503899 --- /dev/null +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/PropertyAccessor.cs @@ -0,0 +1,66 @@ +using System.Reflection.Emit; +using System.Reflection; + +namespace PhenX.EntityFrameworkCore.BulkInsert.Metadata; + +internal static class PropertyAccessor +{ + public delegate TValue Getter(TSource source); + + public static Getter CreateUntypedGetter(PropertyInfo propertyInfo, Type sourceType, Type valueType) + { + var method = + typeof(PropertyAccessor).GetMethod(nameof(CreateInternalUntypedGetter), BindingFlags.NonPublic | BindingFlags.Static)! + .MakeGenericMethod(sourceType, valueType); + + return (Getter)method.Invoke(null, [propertyInfo])!; + } + + private static Getter CreateInternalUntypedGetter(PropertyInfo propertyInfo) + { + var getter = CreateGetter(propertyInfo); + + return source => getter((TSource)source!); + } + + public static Getter CreateGetter(PropertyInfo propertyInfo) + { + if (!propertyInfo.CanRead) + { + return x => throw new NotSupportedException(); + } + + var bakingField = + propertyInfo.DeclaringType!.GetField($"<{propertyInfo.Name}>k__BackingField", + BindingFlags.NonPublic | + BindingFlags.Instance); + + var propertyGetMethod = propertyInfo.GetGetMethod()!; + + var getMethod = new DynamicMethod(propertyGetMethod.Name, typeof(TValue), [typeof(TSource)], true); + var getGenerator = getMethod.GetILGenerator(); + + // Load this to stack. + getGenerator.Emit(OpCodes.Ldarg_0); + + if (bakingField != null && !propertyGetMethod.IsVirtual) + { + // Get field directly. + getGenerator.Emit(OpCodes.Ldfld, bakingField); + } + else if (propertyGetMethod.IsVirtual) + { + // Call the virtual property. + getGenerator.Emit(OpCodes.Callvirt, propertyGetMethod); + } + else + { + // Call the non virtual property. + getGenerator.Emit(OpCodes.Call, propertyGetMethod); + } + + getGenerator.Emit(OpCodes.Ret); + + return getMethod.CreateDelegate>(); + } +} diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/PropertyMetadata.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/PropertyMetadata.cs new file mode 100644 index 0000000..14d9786 --- /dev/null +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/PropertyMetadata.cs @@ -0,0 +1,61 @@ +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Metadata; + +using PhenX.EntityFrameworkCore.BulkInsert.Dialect; + +namespace PhenX.EntityFrameworkCore.BulkInsert.Metadata; + +internal sealed class PropertyMetadata(IProperty property, SqlDialectBuilder dialect) +{ + private readonly PropertyAccessor.Getter _getter = BuildGetter(property); + + public string Name { get; } = property.Name; + + public string ColumnName { get; } = property.GetColumnName(); + + public string QuotedColumName { get; } = dialect.Quote(property.GetColumnName()); + + public Type ClrType { get; } = property.ClrType; + + public Type? ProviderClrType { get; } = property.GetProviderClrType(); + + public bool IsGenerated { get; } = property.ValueGenerated == ValueGenerated.OnAdd; + + public object? GetValue(object entity) + { + return _getter(entity!); + } + + private static PropertyAccessor.Getter BuildGetter(IProperty property) + { + var valueConverter = + property.GetValueConverter() ?? + property.GetTypeMapping().Converter; + + var actualGetter = + PropertyAccessor.CreateUntypedGetter( + property.PropertyInfo!, + property.DeclaringType.ClrType, + property.ClrType); + + var result = actualGetter; + if (valueConverter != null) + { + var converter = valueConverter.ConvertToProvider; + + result = source => + { + var value = actualGetter(source); + + return converter(value); + }; + } + + return result; + } + + public override string ToString() + { + return $"Name: {Name}, Column: {ColumnName}"; + } +} diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/TableMetadata.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/TableMetadata.cs new file mode 100644 index 0000000..b733345 --- /dev/null +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/TableMetadata.cs @@ -0,0 +1,46 @@ +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Metadata; + +using PhenX.EntityFrameworkCore.BulkInsert.Dialect; + +namespace PhenX.EntityFrameworkCore.BulkInsert.Metadata; + +internal sealed class TableMetadata(IEntityType entityType, SqlDialectBuilder dialect) +{ + private IReadOnlyList? _notGeneratedProperties; + + public string TableName { get; } = + entityType.GetTableName() ?? throw new InvalidOperationException("Canot determine table name."); + + public string QuotedTableName { get; } = + dialect.QuoteTableName(entityType.GetSchema(), entityType.GetTableName()!); + + public IReadOnlyList Properties { get; } = + entityType.GetProperties().Where(p => !p.IsShadowProperty()).Select(x => new PropertyMetadata(x, dialect)).ToList(); + + public IReadOnlyList GetProperties(bool includeGenerated = true) + { + if (includeGenerated) + { + return Properties; + } + + return _notGeneratedProperties ??= Properties.Where(x => !x.IsGenerated).ToList(); + } + + public string GetQuotedColumnName(string propertyName) + { + var property = Properties.FirstOrDefault(x => x.Name == propertyName) + ?? throw new InvalidOperationException($"Property {propertyName} not found in entity type {entityType.Name}."); + + return property.QuotedColumName; + } + + public string GetColumnName(string propertyName) + { + var property = Properties.FirstOrDefault(x => x.Name == propertyName) + ?? throw new InvalidOperationException($"Property {propertyName} not found in entity type {entityType.Name}."); + + return property.ColumnName; + } +} diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/PhenX.EntityFrameworkCore.BulkInsert.csproj b/src/PhenX.EntityFrameworkCore.BulkInsert/PhenX.EntityFrameworkCore.BulkInsert.csproj index b59ec6a..17da02a 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/PhenX.EntityFrameworkCore.BulkInsert.csproj +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/PhenX.EntityFrameworkCore.BulkInsert.csproj @@ -6,10 +6,10 @@ - - - - + + + + diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/PropertyAccessor.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/PropertyAccessor.cs deleted file mode 100644 index 992dab1..0000000 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/PropertyAccessor.cs +++ /dev/null @@ -1,40 +0,0 @@ -using Microsoft.EntityFrameworkCore; -using Microsoft.EntityFrameworkCore.Metadata; - -namespace PhenX.EntityFrameworkCore.BulkInsert; - -internal readonly struct PropertyAccessor -{ - private Func ValueGetter { get; } - - private IProperty Property { get; } - - public Type ProviderClrType { get; } - - public string Name => Property.Name; - - public string ColumnName => Property.GetColumnName(); - - public PropertyAccessor(IProperty property) - { - Property = property; - - var propInfo = property.PropertyInfo!; - - var valueConverter = property.GetValueConverter()?? - property.GetTypeMapping().Converter; - - if (valueConverter != null) - { - var conv = valueConverter.ConvertToProvider; - ValueGetter = v => conv(propInfo.GetValue(v)); - ProviderClrType = valueConverter.ProviderClrType; - return; - } - - ValueGetter = propInfo.GetValue; - ProviderClrType = Nullable.GetUnderlyingType(propInfo.PropertyType) ?? propInfo.PropertyType; - } - - public object GetValue(object entity) => ValueGetter(entity) ?? DBNull.Value; -} diff --git a/tests/PhenX.EntityFrameworkCore.BulkInsert.Benchmark/LibComparatorSqlServer.cs b/tests/PhenX.EntityFrameworkCore.BulkInsert.Benchmark/LibComparatorSqlServer.cs index ad7e843..508fc70 100644 --- a/tests/PhenX.EntityFrameworkCore.BulkInsert.Benchmark/LibComparatorSqlServer.cs +++ b/tests/PhenX.EntityFrameworkCore.BulkInsert.Benchmark/LibComparatorSqlServer.cs @@ -5,7 +5,6 @@ using Microsoft.EntityFrameworkCore; -using PhenX.EntityFrameworkCore.BulkInsert.Sqlite; using PhenX.EntityFrameworkCore.BulkInsert.SqlServer; using Testcontainers.MsSql;