diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlBulkInsertProvider.cs b/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlBulkInsertProvider.cs index f4b3190..b737cb6 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlBulkInsertProvider.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlBulkInsertProvider.cs @@ -23,7 +23,7 @@ public MySqlBulkInsertProvider(ILogger? logger = null) protected override string GetTempTableName(string tableName) => $"#_temp_bulk_insert_{tableName}"; /// - public override Task> BulkInsertReturnEntities( + public override IAsyncEnumerable BulkInsertReturnEntities( bool sync, DbContext context, TableMetadata tableInfo, diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Abstractions/IBulkInsertProvider.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Abstractions/IBulkInsertProvider.cs index 2430fe3..226b52a 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/Abstractions/IBulkInsertProvider.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Abstractions/IBulkInsertProvider.cs @@ -14,7 +14,7 @@ internal interface IBulkInsertProvider /// /// Calls the provider to perform a bulk insert operation. /// - internal Task> BulkInsertReturnEntities( + internal IAsyncEnumerable BulkInsertReturnEntities( bool sync, DbContext context, TableMetadata tableInfo, diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs index d731913..b223283 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs @@ -1,4 +1,5 @@ using System.Data.Common; +using System.Runtime.CompilerServices; using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore.Storage; @@ -77,7 +78,7 @@ protected static async Task ExecuteAsync( } } - public async Task> CopyFromTempTableAsync( + public async Task?> CopyFromTempTableAsync( bool sync, DbContext context, TableMetadata tableInfo, @@ -98,7 +99,7 @@ public async Task> CopyFromTempTableAsync( ctk); } - private async Task> CopyFromTempTableWithoutKeysAsync( + private async Task?> CopyFromTempTableWithoutKeysAsync( bool sync, DbContext context, TableMetadata tableInfo, @@ -119,46 +120,45 @@ private async Task> CopyFromTempTableWithoutKeysAsync( if (returnData) { - return await QueryAsync(sync, context, query, ctk); + return Query(context, query); } // If not returning data, just execute the command await ExecuteAsync(sync, context, query, ctk); - return []; + return null; - static async Task> QueryAsync(bool sync, DbContext context, string query, CancellationToken cancellationToken) + static IAsyncEnumerable Query(DbContext context, string query) { // Use EF to execute the query and return the results IQueryable queryable = context .Set() .FromSqlRaw(query); - if (sync) - { - return queryable.ToList(); - } - - return await queryable.ToListAsync(cancellationToken: cancellationToken); + return queryable.AsAsyncEnumerable(); } } - public virtual async Task> BulkInsertReturnEntities( + public virtual async IAsyncEnumerable BulkInsertReturnEntities( bool sync, DbContext context, TableMetadata tableInfo, IEnumerable entities, BulkInsertOptions options, OnConflictOptions? onConflict, - CancellationToken ctk) where T : class + [EnumeratorCancellation] CancellationToken ctk) where T : class { - List result; - var connection = await context.GetConnection(sync, ctk); try { var (tableName, _) = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk); - result = await CopyFromTempTableAsync(sync, context, tableInfo, tableName, true, options, onConflict, ctk: ctk); + var result = await CopyFromTempTableAsync(sync, context, tableInfo, tableName, true, options, onConflict, ctk) + ?? throw new InvalidOperationException("Failed to get async enumerable."); + + await foreach (var item in result) + { + yield return item; + } // Commit the transaction if we own them. await connection.Commit(sync, ctk); @@ -167,8 +167,6 @@ public virtual async Task> BulkInsertReturnEntities( { await connection.Close(sync, ctk); } - - return result; } public virtual async Task BulkInsert( diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbSetExtensions.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbSetExtensions.cs index 6c17bf2..7e773a9 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbSetExtensions.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbSetExtensions.cs @@ -12,116 +12,147 @@ namespace PhenX.EntityFrameworkCore.BulkInsert.Extensions; public static class DbSetExtensions { /// - /// Executes a bulk insert operation returning the inserted/updated entities, from the DbSet. + /// Executes a bulk insert operation returning the inserted/updated entities, from the DbSet (synchronous variant). /// - public static async Task> ExecuteBulkInsertReturnEntitiesAsync( + public static List ExecuteBulkInsertReturnEntities( this DbSet dbSet, IEnumerable entities, Action? configure = null, - OnConflictOptions? onConflict = null, - CancellationToken ctk = default + OnConflictOptions? onConflict = null ) where T : class { - var provider = InitProvider(dbSet, configure, out var context, out var options); - var tableInfo = dbSet.GetDbContext().GetTableInfo(); + return dbSet.ExecuteBulkInsertReturnEntitiesCoreAsync(true, entities, configure, onConflict, default).GetAwaiter().GetResult(); + } + + /// + /// Executes a bulk insert operation returning the inserted/updated entities, from the DbContext (synchronous variant). + /// + public static List ExecuteBulkInsertReturnEntities( + this DbContext dbContext, + IEnumerable entities, + Action? configure = null, + OnConflictOptions? onConflict = null + ) where T : class + { + var dbSet = dbContext.Set() ?? throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext."); - return await provider.BulkInsertReturnEntities(false, context, tableInfo, entities, options, onConflict, ctk); + return dbSet.ExecuteBulkInsertReturnEntitiesCoreAsync(true, entities, configure, onConflict, default).GetAwaiter().GetResult(); } /// /// Executes a bulk insert operation returning the inserted/updated entities, from the DbContext. /// - public static async Task> ExecuteBulkInsertReturnEntitiesAsync(this DbContext dbContext, IEnumerable entities, Action? configure = null, OnConflictOptions? onConflict = null, CancellationToken cancellationToken = default) where T : class + public static Task> ExecuteBulkInsertReturnEntitiesAsync( + this DbContext dbContext, + IEnumerable entities, + Action? configure = null, + OnConflictOptions? onConflict = null, + CancellationToken ctk = default + ) where T : class { - var dbSet = dbContext.Set(); - if (dbSet == null) - { - throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext."); - } + var dbSet = dbContext.Set() ?? throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext."); - return await dbSet.ExecuteBulkInsertReturnEntitiesAsync(entities, configure, onConflict, cancellationToken); + return dbSet.ExecuteBulkInsertReturnEntitiesCoreAsync(false, entities, configure, onConflict, ctk); } /// - /// Executes a bulk insert operation without returning the inserted/updated entities, from the DbSet. + /// Executes a bulk insert operation returning the inserted/updated entities, from the DbSet. /// - public static async Task ExecuteBulkInsertAsync( + public static Task> ExecuteBulkInsertReturnEntitiesAsync( this DbSet dbSet, IEnumerable entities, Action? configure = null, OnConflictOptions? onConflict = null, CancellationToken ctk = default ) where T : class + { + return dbSet.ExecuteBulkInsertReturnEntitiesCoreAsync(false, entities, configure, onConflict, ctk); + } + + private static async Task> ExecuteBulkInsertReturnEntitiesCoreAsync( + this DbSet dbSet, + bool sync, + IEnumerable entities, + Action? configure, + OnConflictOptions? onConflict, + CancellationToken ctk + ) where T : class { var provider = InitProvider(dbSet, configure, out var context, out var options); - var tableInfo = dbSet.GetDbContext().GetTableInfo(); - await provider.BulkInsert(false, context, tableInfo, entities, options, onConflict, ctk); + var enumerable = provider.BulkInsertReturnEntities(sync, context, dbSet.GetDbContext().GetTableInfo(), entities, options, onConflict, ctk); + + var result = new List(); + await foreach (var item in enumerable.WithCancellation(ctk)) + { + result.Add(item); + } + + return result; } /// - /// Executes a bulk insert operation without returning the inserted/updated entities, from the DbContext. + /// Executes a bulk insert operation returning the inserted/updated entities, from the DbContext. /// - public static async Task ExecuteBulkInsertAsync(this DbContext dbContext, IEnumerable entities, Action? configure = null, OnConflictOptions? onConflict = null, CancellationToken cancellationToken = default) where T : class + public static IAsyncEnumerable ExecuteBulkInsertReturnEnumerableAsync( + this DbContext dbContext, + IEnumerable entities, + Action? configure = null, + OnConflictOptions? onConflict = null, + CancellationToken ctk = default + ) where T : class { - var dbSet = dbContext.Set(); - if (dbSet == null) - { - throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext."); - } + var dbSet = dbContext.Set() ?? throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext."); - await dbSet.ExecuteBulkInsertAsync(entities, configure, onConflict, cancellationToken); + return dbSet.ExecuteBulkInsertReturnEnumerableAsync(entities, configure, onConflict, ctk); } /// - /// Executes a bulk insert operation returning the inserted/updated entities, from the DbSet (synchronous variant). + /// Executes a bulk insert operation returning the inserted/updated entities, from the DbSet. /// - public static List ExecuteBulkInsertReturnEntities( + public static IAsyncEnumerable ExecuteBulkInsertReturnEnumerableAsync( this DbSet dbSet, IEnumerable entities, Action? configure = null, - OnConflictOptions? onConflict = null + OnConflictOptions? onConflict = null, + CancellationToken ctk = default ) where T : class { var provider = InitProvider(dbSet, configure, out var context, out var options); - var tableInfo = dbSet.GetDbContext().GetTableInfo(); - return provider.BulkInsertReturnEntities(true, context, tableInfo, entities, options, onConflict).GetAwaiter().GetResult(); + return provider.BulkInsertReturnEntities(false, context, dbSet.GetDbContext().GetTableInfo(), entities, options, onConflict, ctk); } /// - /// Executes a bulk insert operation returning the inserted/updated entities, from the DbContext (synchronous variant). + /// Executes a bulk insert operation without returning the inserted/updated entities, from the DbContext. /// - public static List ExecuteBulkInsertReturnEntities( + public static async Task ExecuteBulkInsertAsync( this DbContext dbContext, IEnumerable entities, Action? configure = null, - OnConflictOptions? onConflict = null + OnConflictOptions? onConflict = null, + CancellationToken ctk = default ) where T : class { - var dbSet = dbContext.Set(); - if (dbSet == null) - { - throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext."); - } + var dbSet = dbContext.Set() ?? throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext."); - return dbSet.ExecuteBulkInsertReturnEntities(entities, configure, onConflict); + await dbSet.ExecuteBulkInsertAsync(entities, configure, onConflict, ctk); } /// - /// Executes a bulk insert operation without returning the inserted/updated entities, from the DbSet (synchronous variant). + /// Executes a bulk insert operation without returning the inserted/updated entities, from the DbSet. /// - public static void ExecuteBulkInsert( + public static async Task ExecuteBulkInsertAsync( this DbSet dbSet, IEnumerable entities, Action? configure = null, - OnConflictOptions? onConflict = null + OnConflictOptions? onConflict = null, + CancellationToken ctk = default ) where T : class { var provider = InitProvider(dbSet, configure, out var context, out var options); - var tableInfo = dbSet.GetDbContext().GetTableInfo(); - provider.BulkInsert(true, context, tableInfo, entities, options, onConflict).GetAwaiter().GetResult(); + await provider.BulkInsert(false, context, dbSet.GetDbContext().GetTableInfo(), entities, options, onConflict, ctk); } /// @@ -134,15 +165,26 @@ public static void ExecuteBulkInsert( OnConflictOptions? onConflict = null ) where T : class { - var dbSet = dbContext.Set(); - if (dbSet == null) - { - throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext."); - } + var dbSet = dbContext.Set() ?? throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext."); dbSet.ExecuteBulkInsert(entities, configure, onConflict); } + /// + /// Executes a bulk insert operation without returning the inserted/updated entities, from the DbSet (synchronous variant). + /// + public static void ExecuteBulkInsert( + this DbSet dbSet, + IEnumerable entities, + Action? configure = null, + OnConflictOptions? onConflict = null + ) where T : class + { + var provider = InitProvider(dbSet, configure, out var context, out var options); + + provider.BulkInsert(true, context, dbSet.GetDbContext().GetTableInfo(), entities, options, onConflict).GetAwaiter().GetResult(); + } + private static DbContext GetDbContext(this DbSet dbSet) where T : class { IInfrastructure infrastructure = dbSet; diff --git a/tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Basic/BasicTestsBase.cs b/tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Basic/BasicTestsBase.cs index 5db1690..eaa1976 100644 --- a/tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Basic/BasicTestsBase.cs +++ b/tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Basic/BasicTestsBase.cs @@ -92,6 +92,32 @@ public async Task InsertsEntitiesAndReturn() Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity2"); } + [SkippableFact] + public async Task InsertsEntitiesAndReturnAsyncEnumerable() + { + Skip.If(_context.Database.ProviderName!.Contains("Mysql", StringComparison.InvariantCultureIgnoreCase)); + + // Arrange + var entities = new List + { + new TestEntity { TestRun = _run, Name = $"{_run}_Entity1" }, + new TestEntity { TestRun = _run, Name = $"{_run}_Entity2" } + }; + + // Act + var enumerable = _context.ExecuteBulkInsertReturnEnumerableAsync(entities); + var insertedEntities = new List(); + await foreach (var item in enumerable) + { + insertedEntities.Add(item); + } + + // Assert + Assert.Equal(2, insertedEntities.Count); + Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity1"); + Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity2"); + } + [SkippableFact] public void InsertsEntitiesAndReturn_Sync() {