From 3dc6f958e77e1d79e9a932edb09307308c542427 Mon Sep 17 00:00:00 2001 From: Sebastian Stehle Date: Thu, 22 May 2025 12:49:09 +0200 Subject: [PATCH 1/2] Async enumerable support. --- .../MySqlBulkInsertProvider.cs | 2 +- .../Abstractions/IBulkInsertProvider.cs | 2 +- .../BulkInsertProviderBase.cs | 34 ++--- .../Extensions/DbSetExtensions.cs | 142 ++++++++++++------ .../Tests/Basic/BasicTestsBase.cs | 26 ++++ 5 files changed, 136 insertions(+), 70 deletions(-) diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlBulkInsertProvider.cs b/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlBulkInsertProvider.cs index bc72001..c98b1ad 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlBulkInsertProvider.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlBulkInsertProvider.cs @@ -27,7 +27,7 @@ public MySqlBulkInsertProvider(ILogger? logger = null) protected override string GetTempTableName(string tableName) => $"#_temp_bulk_insert_{tableName}"; /// - public override Task> BulkInsertReturnEntities( + public override IAsyncEnumerable BulkInsertReturnEntities( bool sync, DbContext context, TableMetadata tableInfo, diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Abstractions/IBulkInsertProvider.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Abstractions/IBulkInsertProvider.cs index 2430fe3..226b52a 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/Abstractions/IBulkInsertProvider.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Abstractions/IBulkInsertProvider.cs @@ -14,7 +14,7 @@ internal interface IBulkInsertProvider /// /// Calls the provider to perform a bulk insert operation. /// - internal Task> BulkInsertReturnEntities( + internal IAsyncEnumerable BulkInsertReturnEntities( bool sync, DbContext context, TableMetadata tableInfo, diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs index 780d4e2..f52c915 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs @@ -1,4 +1,5 @@ using System.Data.Common; +using System.Runtime.CompilerServices; using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore.Storage; @@ -73,7 +74,7 @@ protected static async Task ExecuteAsync(bool sync, DbContext context, string qu } } - public async Task> CopyFromTempTableAsync( + public async Task?> CopyFromTempTableAsync( bool sync, DbContext context, TableMetadata tableInfo, @@ -94,7 +95,7 @@ public async Task> CopyFromTempTableAsync( cancellationToken: cancellationToken); } - private async Task> CopyFromTempTableWithoutKeysAsync( + private async Task?> CopyFromTempTableWithoutKeysAsync( bool sync, DbContext context, TableMetadata tableInfo, @@ -113,47 +114,46 @@ private async Task> CopyFromTempTableWithoutKeysAsync( if (returnData) { - return await QueryAsync(sync, context, query, cancellationToken); + return Query(context, query); } // If not returning data, just execute the command await ExecuteAsync(sync, context, query, cancellationToken); - return []; + return null; - static async Task> QueryAsync(bool sync, DbContext context, string query, CancellationToken cancellationToken) + static IAsyncEnumerable Query(DbContext context, string query) { // Use EF to execute the query and return the results IQueryable queryable = context .Set() .FromSqlRaw(query); - if (sync) - { - return queryable.ToList(); - } - - return await queryable.ToListAsync(cancellationToken: cancellationToken); + return queryable.AsAsyncEnumerable(); } } - public virtual async Task> BulkInsertReturnEntities( + public virtual async IAsyncEnumerable BulkInsertReturnEntities( bool sync, DbContext context, TableMetadata tableInfo, IEnumerable entities, BulkInsertOptions options, OnConflictOptions? onConflict = null, - CancellationToken ctk = default + [EnumeratorCancellation] CancellationToken ctk = default ) where T : class { - List result; - var connectionInfo = await context.GetConnection(sync, ctk); try { var (tableName, _) = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk); - result = await CopyFromTempTableAsync(sync, context, tableInfo, tableName, true, options, onConflict, cancellationToken: ctk); + var result = await CopyFromTempTableAsync(sync, context, tableInfo, tableName, true, options, onConflict, cancellationToken: ctk) + ?? throw new InvalidOperationException("Failed to get async enumerable."); + + await foreach (var item in result) + { + yield return item; + } // Commit the transaction if we own them. await Commit(sync, connectionInfo, ctk); @@ -162,8 +162,6 @@ public virtual async Task> BulkInsertReturnEntities( { await Finish(sync, connectionInfo, ctk); } - - return result; } private static async Task Commit(bool sync, ConnectionInfo connectionInfo, CancellationToken ctk) diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbSetExtensions.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbSetExtensions.cs index 570c085..fae1531 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbSetExtensions.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbSetExtensions.cs @@ -12,116 +12,147 @@ namespace PhenX.EntityFrameworkCore.BulkInsert.Extensions; public static class DbSetExtensions { /// - /// Executes a bulk insert operation returning the inserted/updated entities, from the DbSet. + /// Executes a bulk insert operation returning the inserted/updated entities, from the DbSet (synchronous variant). /// - public static async Task> ExecuteBulkInsertReturnEntitiesAsync( + public static List ExecuteBulkInsertReturnEntities( this DbSet dbSet, IEnumerable entities, Action? configure = null, - OnConflictOptions? onConflict = null, - CancellationToken ctk = default + OnConflictOptions? onConflict = null ) where T : class { - var provider = InitProvider(dbSet, configure, out var context, out var options); - var tableInfo = dbSet.GetDbContext().GetTableInfo(); + return dbSet.ExecuteBulkInsertReturnEntitiesCoreAsync(true, entities, configure, onConflict, default).GetAwaiter().GetResult(); + } + + /// + /// Executes a bulk insert operation returning the inserted/updated entities, from the DbContext (synchronous variant). + /// + public static List ExecuteBulkInsertReturnEntities( + this DbContext dbContext, + IEnumerable entities, + Action? configure = null, + OnConflictOptions? onConflict = null + ) where T : class + { + var dbSet = dbContext.Set() ?? throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext."); - return await provider.BulkInsertReturnEntities(false, context, tableInfo, entities, options, onConflict, ctk); + return dbSet.ExecuteBulkInsertReturnEntitiesCoreAsync(true, entities, configure, onConflict, default).GetAwaiter().GetResult(); } /// /// Executes a bulk insert operation returning the inserted/updated entities, from the DbContext. /// - public static async Task> ExecuteBulkInsertReturnEntitiesAsync(this DbContext dbContext, IEnumerable entities, Action? configure = null, OnConflictOptions? onConflict = null, CancellationToken cancellationToken = default) where T : class + public static Task> ExecuteBulkInsertReturnEntitiesAsync( + this DbContext dbContext, + IEnumerable entities, + Action? configure = null, + OnConflictOptions? onConflict = null, + CancellationToken ctk = default + ) where T : class { - var dbSet = dbContext.Set(); - if (dbSet == null) - { - throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext."); - } + var dbSet = dbContext.Set() ?? throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext."); - return await dbSet.ExecuteBulkInsertReturnEntitiesAsync(entities, configure, onConflict, cancellationToken); + return dbSet.ExecuteBulkInsertReturnEntitiesCoreAsync(false, entities, configure, onConflict, ctk); } /// - /// Executes a bulk insert operation without returning the inserted/updated entities, from the DbSet. + /// Executes a bulk insert operation returning the inserted/updated entities, from the DbSet. /// - public static async Task ExecuteBulkInsertAsync( + public static Task> ExecuteBulkInsertReturnEntitiesAsync( this DbSet dbSet, IEnumerable entities, Action? configure = null, OnConflictOptions? onConflict = null, CancellationToken ctk = default ) where T : class + { + return dbSet.ExecuteBulkInsertReturnEntitiesCoreAsync(false, entities, configure, onConflict, ctk); + } + + private static async Task> ExecuteBulkInsertReturnEntitiesCoreAsync( + this DbSet dbSet, + bool sync, + IEnumerable entities, + Action? configure, + OnConflictOptions? onConflict, + CancellationToken ctk + ) where T : class { var provider = InitProvider(dbSet, configure, out var context, out var options); - var tableInfo = dbSet.GetDbContext().GetTableInfo(); - await provider.BulkInsert(false, context, tableInfo, entities, options, onConflict, ctk); + var enumerable = provider.BulkInsertReturnEntities(sync, context, dbSet.GetDbContext().GetTableInfo(), entities, options, onConflict, ctk); + + var result = new List(); + await foreach (var item in enumerable.WithCancellation(ctk)) + { + result.Add(item); + } + + return result; } /// - /// Executes a bulk insert operation without returning the inserted/updated entities, from the DbContext. + /// Executes a bulk insert operation returning the inserted/updated entities, from the DbContext. /// - public static async Task ExecuteBulkInsertAsync(this DbContext dbContext, IEnumerable entities, Action? configure = null, OnConflictOptions? onConflict = null, CancellationToken cancellationToken = default) where T : class + public static IAsyncEnumerable ExecuteBulkInsertReturnEnumerableAsync( + this DbContext dbContext, + IEnumerable entities, + Action? configure = null, + OnConflictOptions? onConflict = null, + CancellationToken ctk = default + ) where T : class { - var dbSet = dbContext.Set(); - if (dbSet == null) - { - throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext."); - } + var dbSet = dbContext.Set() ?? throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext."); - await dbSet.ExecuteBulkInsertAsync(entities, configure, onConflict, cancellationToken); + return dbSet.ExecuteBulkInsertReturnEnumerableAsync(entities, configure, onConflict, ctk); } /// - /// Executes a bulk insert operation returning the inserted/updated entities, from the DbSet (synchronous variant). + /// Executes a bulk insert operation returning the inserted/updated entities, from the DbSet. /// - public static List ExecuteBulkInsertReturnEntities( + public static IAsyncEnumerable ExecuteBulkInsertReturnEnumerableAsync( this DbSet dbSet, IEnumerable entities, Action? configure = null, - OnConflictOptions? onConflict = null + OnConflictOptions? onConflict = null, + CancellationToken ctk = default ) where T : class { var provider = InitProvider(dbSet, configure, out var context, out var options); - var tableInfo = dbSet.GetDbContext().GetTableInfo(); - return provider.BulkInsertReturnEntities(true, context, tableInfo, entities, options, onConflict).GetAwaiter().GetResult(); + return provider.BulkInsertReturnEntities(false, context, dbSet.GetDbContext().GetTableInfo(), entities, options, onConflict, ctk); } /// - /// Executes a bulk insert operation returning the inserted/updated entities, from the DbContext (synchronous variant). + /// Executes a bulk insert operation without returning the inserted/updated entities, from the DbContext. /// - public static List ExecuteBulkInsertReturnEntities( + public static async Task ExecuteBulkInsertAsync( this DbContext dbContext, IEnumerable entities, Action? configure = null, - OnConflictOptions? onConflict = null + OnConflictOptions? onConflict = null, + CancellationToken ctk = default ) where T : class { - var dbSet = dbContext.Set(); - if (dbSet == null) - { - throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext."); - } + var dbSet = dbContext.Set() ?? throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext."); - return dbSet.ExecuteBulkInsertReturnEntities(entities, configure, onConflict); + await dbSet.ExecuteBulkInsertAsync(entities, configure, onConflict, ctk); } /// - /// Executes a bulk insert operation without returning the inserted/updated entities, from the DbSet (synchronous variant). + /// Executes a bulk insert operation without returning the inserted/updated entities, from the DbSet. /// - public static void ExecuteBulkInsert( + public static async Task ExecuteBulkInsertAsync( this DbSet dbSet, IEnumerable entities, Action? configure = null, - OnConflictOptions? onConflict = null + OnConflictOptions? onConflict = null, + CancellationToken ctk = default ) where T : class { var provider = InitProvider(dbSet, configure, out var context, out var options); - var tableInfo = dbSet.GetDbContext().GetTableInfo(); - provider.BulkInsert(true, context, tableInfo, entities, options, onConflict).GetAwaiter().GetResult(); + await provider.BulkInsert(false, context, dbSet.GetDbContext().GetTableInfo(), entities, options, onConflict, ctk); } /// @@ -134,15 +165,26 @@ public static void ExecuteBulkInsert( OnConflictOptions? onConflict = null ) where T : class { - var dbSet = dbContext.Set(); - if (dbSet == null) - { - throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext."); - } + var dbSet = dbContext.Set() ?? throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext."); dbSet.ExecuteBulkInsert(entities, configure, onConflict); } + /// + /// Executes a bulk insert operation without returning the inserted/updated entities, from the DbSet (synchronous variant). + /// + public static void ExecuteBulkInsert( + this DbSet dbSet, + IEnumerable entities, + Action? configure = null, + OnConflictOptions? onConflict = null + ) where T : class + { + var provider = InitProvider(dbSet, configure, out var context, out var options); + + provider.BulkInsert(true, context, dbSet.GetDbContext().GetTableInfo(), entities, options, onConflict).GetAwaiter().GetResult(); + } + private static DbContext GetDbContext(this DbSet dbSet) where T : class { IInfrastructure infrastructure = dbSet; diff --git a/tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Basic/BasicTestsBase.cs b/tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Basic/BasicTestsBase.cs index 62f4bc2..763332e 100644 --- a/tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Basic/BasicTestsBase.cs +++ b/tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Basic/BasicTestsBase.cs @@ -92,6 +92,32 @@ public async Task InsertsEntitiesAndReturn() Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity2"); } + [SkippableFact] + public async Task InsertsEntitiesAndReturnAsyncENumerable() + { + Skip.If(_context.Database.ProviderName!.Contains("Mysql", StringComparison.InvariantCultureIgnoreCase)); + + // Arrange + var entities = new List + { + new TestEntity { TestRun = _run, Name = $"{_run}_Entity1" }, + new TestEntity { TestRun = _run, Name = $"{_run}_Entity2" } + }; + + // Act + var enumerable = _context.ExecuteBulkInsertReturnEnumerableAsync(entities); + var insertedEntities = new List(); + await foreach (var item in enumerable) + { + insertedEntities.Add(item); + } + + // Assert + Assert.Equal(2, insertedEntities.Count); + Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity1"); + Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity2"); + } + [SkippableFact] public void InsertsEntitiesAndReturn_Sync() { From f1ef1b152c11908d37da8a91bee3f53fef45756a Mon Sep 17 00:00:00 2001 From: Sebastian Stehle Date: Thu, 22 May 2025 22:52:16 +0200 Subject: [PATCH 2/2] Fix naming. --- .../Tests/Basic/BasicTestsBase.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Basic/BasicTestsBase.cs b/tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Basic/BasicTestsBase.cs index a91636d..eaa1976 100644 --- a/tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Basic/BasicTestsBase.cs +++ b/tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Basic/BasicTestsBase.cs @@ -93,7 +93,7 @@ public async Task InsertsEntitiesAndReturn() } [SkippableFact] - public async Task InsertsEntitiesAndReturnAsyncENumerable() + public async Task InsertsEntitiesAndReturnAsyncEnumerable() { Skip.If(_context.Database.ProviderName!.Contains("Mysql", StringComparison.InvariantCultureIgnoreCase));