diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs index b223283..132390f 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs @@ -1,4 +1,3 @@ -using System.Data.Common; using System.Runtime.CompilerServices; using Microsoft.EntityFrameworkCore; @@ -13,10 +12,7 @@ namespace PhenX.EntityFrameworkCore.BulkInsert; -#pragma warning disable CS9113 // Parameter is unread. -internal abstract class BulkInsertProviderBase(ILogger>? logger = null) : IBulkInsertProvider -#pragma warning restore CS9113 // Parameter is unread. - where TDialect : SqlDialectBuilder, new() +internal abstract class BulkInsertProviderBase(ILogger>? logger = null) : IBulkInsertProvider where TDialect : SqlDialectBuilder, new() { protected readonly TDialect SqlDialect = new(); @@ -28,116 +24,6 @@ internal abstract class BulkInsertProviderBase(ILogger SqlDialect; - protected async Task CreateTableCopyAsync( - bool sync, - DbContext context, - BulkInsertOptions options, - TableMetadata tableInfo, - CancellationToken ctk) where T : class - { - var tempTableName = SqlDialect.QuoteTableName(null, GetTempTableName(tableInfo.TableName)); - var tempColumns = tableInfo.GetProperties(options.CopyGeneratedColumns); - - var query = SqlDialect.CreateTableCopySql(tempTableName, tableInfo, tempColumns); - - await ExecuteAsync(sync, context, query, ctk); - await AddBulkInsertIdColumn(sync, context, tempTableName, ctk); - - return tempTableName; - } - - protected virtual async Task AddBulkInsertIdColumn( - bool sync, - DbContext context, - string tempTableName, - CancellationToken ctk) where T : class - { - var alterQuery = string.Format(AddTableCopyBulkInsertId, tempTableName); - - await ExecuteAsync(sync, context, alterQuery, ctk); - } - - protected static async Task ExecuteAsync( - bool sync, - DbContext context, - string query, - CancellationToken ctk) - { - var command = context.Database.GetDbConnection().CreateCommand(); - command.Transaction = context.Database.CurrentTransaction!.GetDbTransaction(); - command.CommandText = query; - - if (sync) - { - // ReSharper disable once MethodHasAsyncOverloadWithCancellation - command.ExecuteNonQuery(); - } - else - { - await command.ExecuteNonQueryAsync(ctk); - } - } - - public async Task?> CopyFromTempTableAsync( - bool sync, - DbContext context, - TableMetadata tableInfo, - string tempTableName, - bool returnData, - BulkInsertOptions options, - OnConflictOptions? onConflict, - CancellationToken ctk) where T : class - { - return await CopyFromTempTableWithoutKeysAsync( - sync, - context, - tableInfo, - tempTableName, - returnData, - options, - onConflict, - ctk); - } - - private async Task?> CopyFromTempTableWithoutKeysAsync( - bool sync, - DbContext context, - TableMetadata tableInfo, - string tempTableName, - bool returnData, - BulkInsertOptions options, - OnConflictOptions? onConflict, - CancellationToken ctk) where T : class where TResult : class - { - var query = - SqlDialect.BuildMoveDataSql( - tableInfo, - tempTableName, - tableInfo.GetProperties(options.CopyGeneratedColumns), - returnData ? tableInfo.GetProperties() : [], - options, - onConflict); - - if (returnData) - { - return Query(context, query); - } - - // If not returning data, just execute the command - await ExecuteAsync(sync, context, query, ctk); - return null; - - static IAsyncEnumerable Query(DbContext context, string query) - { - // Use EF to execute the query and return the results - IQueryable queryable = context - .Set() - .FromSqlRaw(query); - - return queryable.AsAsyncEnumerable(); - } - } - public virtual async IAsyncEnumerable BulkInsertReturnEntities( bool sync, DbContext context, @@ -147,15 +33,25 @@ public virtual async IAsyncEnumerable BulkInsertReturnEntities( OnConflictOptions? onConflict, [EnumeratorCancellation] CancellationToken ctk) where T : class { + using var activity = Telemetry.ActivitySource.StartActivity("BulkInsertReturnEntities"); + activity?.AddTag("tableName", tableInfo.TableName); + activity?.AddTag("synchronous", sync); + var connection = await context.GetConnection(sync, ctk); try { - var (tableName, _) = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk); + if (logger != null) + { + Log.UsingTempTablToReturnData(logger); + } - var result = await CopyFromTempTableAsync(sync, context, tableInfo, tableName, true, options, onConflict, ctk) - ?? throw new InvalidOperationException("Failed to get async enumerable."); + var tableName = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk); - await foreach (var item in result) + var result = + await CopyFromTempTableAsync(sync, context, tableInfo, tableName, true, options, onConflict, ctk: ctk) + ?? throw new InvalidOperationException("Copy returns null enumerable."); + + await foreach (var item in result.WithCancellation(ctk)) { yield return item; } @@ -178,30 +74,44 @@ public virtual async Task BulkInsert( OnConflictOptions? onConflict, CancellationToken ctk) where T : class { - if (onConflict != null) + using var activity = Telemetry.ActivitySource.StartActivity("BulkInsert"); + activity?.AddTag("tableName", tableInfo.TableName); + activity?.AddTag("synchronous", sync); + + var connection = await context.GetConnection(sync, ctk); + try { - var connection = await context.GetConnection(sync, ctk); - try + if (onConflict != null) { - var (tableName, _) = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk); + if (logger != null) + { + Log.UsingTempTableToResolveConflicts(logger); + } - await CopyFromTempTableAsync(sync, context, tableInfo, tableName, false, options, onConflict, ctk); + var tableName = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk); - // Commit the transaction if we own them. - await connection.Commit(sync, ctk); + await CopyFromTempTableAsync(sync, context, tableInfo, tableName, false, options, onConflict, ctk); } - finally + else { - await connection.Close(sync, ctk); + if (logger != null) + { + Log.UsingDirectInsert(logger); + } + + await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: false, ctk: ctk); } + + // Commit the transaction if we own them. + await connection.Commit(sync, ctk); } - else + finally { - await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: false, ctk: ctk); + await connection.Close(sync, ctk); } } - private async Task<(string TableName, DbConnection Connection)> PerformBulkInsertAsync( + private async Task PerformBulkInsertAsync( bool sync, DbContext context, TableMetadata tableInfo, @@ -215,27 +125,18 @@ public virtual async Task BulkInsert( throw new InvalidOperationException("No entities to insert."); } - var connection = await context.GetConnection(sync, ctk); - var tableName = tempTableRequired ? await CreateTableCopyAsync(sync, context, options, tableInfo, ctk) : tableInfo.QuotedTableName; var properties = tableInfo.GetProperties(options.CopyGeneratedColumns); - try - { - await BulkInsert(false, context, tableInfo, entities, tableName, properties, options, ctk); + using var activity = Telemetry.ActivitySource.StartActivity("Insert"); + activity?.AddTag("tempTable", tempTableRequired); + activity?.AddTag("synchronous", sync); - // Commit the transaction if we own them. - await connection.Commit(sync, ctk); - } - finally - { - await connection.Close(sync, ctk); - } - - return (tableName, connection.Connection); + await BulkInsert(false, context, tableInfo, entities, tableName, properties, options, ctk); + return tableName; } /// @@ -250,4 +151,84 @@ protected abstract Task BulkInsert( IReadOnlyList properties, BulkInsertOptions options, CancellationToken ctk) where T : class; + + protected async Task CreateTableCopyAsync( + bool sync, + DbContext context, + BulkInsertOptions options, + TableMetadata tableInfo, + CancellationToken ctk) where T : class + { + var tempTableName = SqlDialect.QuoteTableName(null, GetTempTableName(tableInfo.TableName)); + var tempColumns = tableInfo.GetProperties(options.CopyGeneratedColumns); + + var query = SqlDialect.CreateTableCopySql(tempTableName, tableInfo, tempColumns); + + await ExecuteAsync(sync, context, query, ctk); + await AddBulkInsertIdColumn(sync, context, tempTableName, ctk); + + return tempTableName; + } + + protected virtual async Task AddBulkInsertIdColumn( + bool sync, + DbContext context, + string tempTableName, + CancellationToken ctk) where T : class + { + var alterQuery = string.Format(AddTableCopyBulkInsertId, tempTableName); + + await ExecuteAsync(sync, context, alterQuery, ctk); + } + + private async Task?> CopyFromTempTableAsync( + bool sync, + DbContext context, + TableMetadata tableInfo, + string tempTableName, + bool returnData, + BulkInsertOptions options, + OnConflictOptions? onConflict, + CancellationToken ctk) where T : class where TResult : class + { + var query = + SqlDialect.BuildMoveDataSql( + tableInfo, + tempTableName, + tableInfo.GetProperties(options.CopyGeneratedColumns), + returnData ? tableInfo.GetProperties() : [], + options, + onConflict); + + if (returnData) + { + // Use EF to execute the query and return the results + return context.Set().FromSqlRaw(query).AsAsyncEnumerable(); + } + + // If not returning data, just execute the command + await ExecuteAsync(sync, context, query, ctk); + return null; + } + + protected static async Task ExecuteAsync( + bool sync, + DbContext context, + string query, + CancellationToken ctk) + { + var command = context.Database.GetDbConnection().CreateCommand(); + command.Transaction = context.Database.CurrentTransaction!.GetDbTransaction(); + command.CommandText = query; + + if (sync) + { + // ReSharper disable once MethodHasAsyncOverloadWithCancellation + command.ExecuteNonQuery(); + } + else + { + await command.ExecuteNonQueryAsync(ctk); + } + } } diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/EnumerableDataReader.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/EnumerableDataReader.cs index ec07fe0..914b360 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/EnumerableDataReader.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/EnumerableDataReader.cs @@ -4,29 +4,18 @@ namespace PhenX.EntityFrameworkCore.BulkInsert; -internal class EnumerableDataReader : IDataReader +internal sealed class EnumerableDataReader(IEnumerable rows, IReadOnlyList properties) : IDataReader { - private readonly IEnumerator _enumerator; - private readonly IReadOnlyList _properties; - private readonly Dictionary _ordinalMap; - - public EnumerableDataReader(IEnumerable rows, IReadOnlyList properties) - { - _enumerator = rows.GetEnumerator(); - _properties = properties; - _ordinalMap = properties - .Select((p, i) => new - { - Property = p, - Index = i, - }) + private readonly IEnumerator _enumerator = rows.GetEnumerator(); + private readonly Dictionary _ordinalMap = + properties + .Select((p, i) => (Property: p, Index: i)) .ToDictionary( p => p.Property.Name, p => p.Index ); - } - public virtual object GetValue(int i) + public object GetValue(int i) { var current = _enumerator.Current; if (current == null) @@ -34,7 +23,7 @@ public virtual object GetValue(int i) return DBNull.Value; } - return _properties[i].GetValue(current)!; + return properties[i].GetValue(current)!; } public int GetValues(object[] values) @@ -45,25 +34,29 @@ public int GetValues(object[] values) return 0; } - for (var i = 0; i < _properties.Count; i++) + for (var i = 0; i < properties.Count; i++) { - values[i] = _properties[i].GetValue(current)!; + values[i] = properties[i].GetValue(current)!; } - return _properties.Count; + return properties.Count; } public bool Read() => _enumerator.MoveNext(); - public int FieldCount => _properties.Count; - public Type GetFieldType(int i) => _properties[i].ClrType; + public Type GetFieldType(int i) => properties[i].ClrType; public int GetOrdinal(string name) => _ordinalMap.GetValueOrDefault(name, -1); + public int FieldCount => properties.Count; + public int Depth => 0; - public bool IsClosed => false; + public int RecordsAffected => 0; + public bool IsClosed => false; + + public void Close() { } diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Log.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Log.cs new file mode 100644 index 0000000..c81b727 --- /dev/null +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Log.cs @@ -0,0 +1,24 @@ +using Microsoft.Extensions.Logging; + +namespace PhenX.EntityFrameworkCore.BulkInsert; + +internal static partial class Log +{ + [LoggerMessage( + EventId = 1000, + Level = LogLevel.Trace, + Message = "Using temporary table to return data")] + public static partial void UsingTempTablToReturnData(ILogger logger); + + [LoggerMessage( + EventId = 1001, + Level = LogLevel.Trace, + Message = "Using temporary table to resolve conflicts")] + public static partial void UsingTempTableToResolveConflicts(ILogger logger); + + [LoggerMessage( + EventId = 1002, + Level = LogLevel.Trace, + Message = "Insert to table directly")] + public static partial void UsingDirectInsert(ILogger logger); +} diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Telemetry.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Telemetry.cs new file mode 100644 index 0000000..61034c1 --- /dev/null +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Telemetry.cs @@ -0,0 +1,14 @@ +using System.Diagnostics; + +namespace PhenX.EntityFrameworkCore.BulkInsert; + +/// +/// Utility class for telemetry. +/// +public static class Telemetry +{ + /// + /// The activity source. + /// + public static readonly ActivitySource ActivitySource = new("PhenX.EntityFrameworkCore.BulkInsert"); +}