Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public MySqlBulkInsertProvider(ILogger<MySqlBulkInsertProvider>? logger = null)
protected override string GetTempTableName(string tableName) => $"#_temp_bulk_insert_{tableName}";

/// <inheritdoc />
public override Task<List<T>> BulkInsertReturnEntities<T>(
public override IAsyncEnumerable<T> BulkInsertReturnEntities<T>(
bool sync,
DbContext context,
TableMetadata tableInfo,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ internal interface IBulkInsertProvider
/// <summary>
/// Calls the provider to perform a bulk insert operation.
/// </summary>
internal Task<List<T>> BulkInsertReturnEntities<T>(
internal IAsyncEnumerable<T> BulkInsertReturnEntities<T>(
bool sync,
DbContext context,
TableMetadata tableInfo,
Expand Down
34 changes: 16 additions & 18 deletions src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Data.Common;
using System.Runtime.CompilerServices;

using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Storage;
Expand Down Expand Up @@ -77,7 +78,7 @@ protected static async Task ExecuteAsync(
}
}

public async Task<List<T>> CopyFromTempTableAsync<T>(
public async Task<IAsyncEnumerable<T>?> CopyFromTempTableAsync<T>(
bool sync,
DbContext context,
TableMetadata tableInfo,
Expand All @@ -98,7 +99,7 @@ public async Task<List<T>> CopyFromTempTableAsync<T>(
ctk);
}

private async Task<List<TResult>> CopyFromTempTableWithoutKeysAsync<T, TResult>(
private async Task<IAsyncEnumerable<TResult>?> CopyFromTempTableWithoutKeysAsync<T, TResult>(
bool sync,
DbContext context,
TableMetadata tableInfo,
Expand All @@ -119,46 +120,45 @@ private async Task<List<TResult>> CopyFromTempTableWithoutKeysAsync<T, TResult>(

if (returnData)
{
return await QueryAsync(sync, context, query, ctk);
return Query(context, query);
}

// If not returning data, just execute the command
await ExecuteAsync(sync, context, query, ctk);
return [];
return null;

static async Task<List<TResult>> QueryAsync(bool sync, DbContext context, string query, CancellationToken cancellationToken)
static IAsyncEnumerable<TResult> Query(DbContext context, string query)
{
// Use EF to execute the query and return the results
IQueryable<TResult> queryable = context
.Set<TResult>()
.FromSqlRaw(query);

if (sync)
{
return queryable.ToList();
}

return await queryable.ToListAsync(cancellationToken: cancellationToken);
return queryable.AsAsyncEnumerable();
}
}

public virtual async Task<List<T>> BulkInsertReturnEntities<T>(
public virtual async IAsyncEnumerable<T> BulkInsertReturnEntities<T>(
bool sync,
DbContext context,
TableMetadata tableInfo,
IEnumerable<T> entities,
BulkInsertOptions options,
OnConflictOptions? onConflict,
CancellationToken ctk) where T : class
[EnumeratorCancellation] CancellationToken ctk) where T : class
{
List<T> result;

var connection = await context.GetConnection(sync, ctk);
try
{
var (tableName, _) = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk);

result = await CopyFromTempTableAsync<T>(sync, context, tableInfo, tableName, true, options, onConflict, ctk: ctk);
var result = await CopyFromTempTableAsync<T>(sync, context, tableInfo, tableName, true, options, onConflict, ctk)
?? throw new InvalidOperationException("Failed to get async enumerable.");

await foreach (var item in result)
{
yield return item;
}

// Commit the transaction if we own them.
await connection.Commit(sync, ctk);
Expand All @@ -167,8 +167,6 @@ public virtual async Task<List<T>> BulkInsertReturnEntities<T>(
{
await connection.Close(sync, ctk);
}

return result;
}

public virtual async Task BulkInsert<T>(
Expand Down
142 changes: 92 additions & 50 deletions src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbSetExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,116 +12,147 @@ namespace PhenX.EntityFrameworkCore.BulkInsert.Extensions;
public static class DbSetExtensions
{
/// <summary>
/// Executes a bulk insert operation returning the inserted/updated entities, from the DbSet.
/// Executes a bulk insert operation returning the inserted/updated entities, from the DbSet (synchronous variant).
/// </summary>
public static async Task<List<T>> ExecuteBulkInsertReturnEntitiesAsync<T>(
public static List<T> ExecuteBulkInsertReturnEntities<T>(
this DbSet<T> dbSet,
IEnumerable<T> entities,
Action<BulkInsertOptions>? configure = null,
OnConflictOptions<T>? onConflict = null,
CancellationToken ctk = default
OnConflictOptions<T>? onConflict = null
) where T : class
{
var provider = InitProvider(dbSet, configure, out var context, out var options);
var tableInfo = dbSet.GetDbContext().GetTableInfo<T>();
return dbSet.ExecuteBulkInsertReturnEntitiesCoreAsync(true, entities, configure, onConflict, default).GetAwaiter().GetResult();
}

/// <summary>
/// Executes a bulk insert operation returning the inserted/updated entities, from the DbContext (synchronous variant).
/// </summary>
public static List<T> ExecuteBulkInsertReturnEntities<T>(
this DbContext dbContext,
IEnumerable<T> entities,
Action<BulkInsertOptions>? configure = null,
OnConflictOptions<T>? onConflict = null
) where T : class
{
var dbSet = dbContext.Set<T>() ?? throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");

return await provider.BulkInsertReturnEntities(false, context, tableInfo, entities, options, onConflict, ctk);
return dbSet.ExecuteBulkInsertReturnEntitiesCoreAsync(true, entities, configure, onConflict, default).GetAwaiter().GetResult();
}

/// <summary>
/// Executes a bulk insert operation returning the inserted/updated entities, from the DbContext.
/// </summary>
public static async Task<List<T>> ExecuteBulkInsertReturnEntitiesAsync<T>(this DbContext dbContext, IEnumerable<T> entities, Action<BulkInsertOptions>? configure = null, OnConflictOptions<T>? onConflict = null, CancellationToken cancellationToken = default) where T : class
public static Task<List<T>> ExecuteBulkInsertReturnEntitiesAsync<T>(
this DbContext dbContext,
IEnumerable<T> entities,
Action<BulkInsertOptions>? configure = null,
OnConflictOptions<T>? onConflict = null,
CancellationToken ctk = default
) where T : class
{
var dbSet = dbContext.Set<T>();
if (dbSet == null)
{
throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");
}
var dbSet = dbContext.Set<T>() ?? throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");

return await dbSet.ExecuteBulkInsertReturnEntitiesAsync(entities, configure, onConflict, cancellationToken);
return dbSet.ExecuteBulkInsertReturnEntitiesCoreAsync(false, entities, configure, onConflict, ctk);
}

/// <summary>
/// Executes a bulk insert operation without returning the inserted/updated entities, from the DbSet.
/// Executes a bulk insert operation returning the inserted/updated entities, from the DbSet.
/// </summary>
public static async Task ExecuteBulkInsertAsync<T>(
public static Task<List<T>> ExecuteBulkInsertReturnEntitiesAsync<T>(
this DbSet<T> dbSet,
IEnumerable<T> entities,
Action<BulkInsertOptions>? configure = null,
OnConflictOptions<T>? onConflict = null,
CancellationToken ctk = default
) where T : class
{
return dbSet.ExecuteBulkInsertReturnEntitiesCoreAsync(false, entities, configure, onConflict, ctk);
}

private static async Task<List<T>> ExecuteBulkInsertReturnEntitiesCoreAsync<T>(
this DbSet<T> dbSet,
bool sync,
IEnumerable<T> entities,
Action<BulkInsertOptions>? configure,
OnConflictOptions<T>? onConflict,
CancellationToken ctk
) where T : class
{
var provider = InitProvider(dbSet, configure, out var context, out var options);
var tableInfo = dbSet.GetDbContext().GetTableInfo<T>();

await provider.BulkInsert(false, context, tableInfo, entities, options, onConflict, ctk);
var enumerable = provider.BulkInsertReturnEntities(sync, context, dbSet.GetDbContext().GetTableInfo<T>(), entities, options, onConflict, ctk);

var result = new List<T>();
await foreach (var item in enumerable.WithCancellation(ctk))
{
result.Add(item);
}

return result;
}

/// <summary>
/// Executes a bulk insert operation without returning the inserted/updated entities, from the DbContext.
/// Executes a bulk insert operation returning the inserted/updated entities, from the DbContext.
/// </summary>
public static async Task ExecuteBulkInsertAsync<T>(this DbContext dbContext, IEnumerable<T> entities, Action<BulkInsertOptions>? configure = null, OnConflictOptions<T>? onConflict = null, CancellationToken cancellationToken = default) where T : class
public static IAsyncEnumerable<T> ExecuteBulkInsertReturnEnumerableAsync<T>(
this DbContext dbContext,
IEnumerable<T> entities,
Action<BulkInsertOptions>? configure = null,
OnConflictOptions<T>? onConflict = null,
CancellationToken ctk = default
) where T : class
{
var dbSet = dbContext.Set<T>();
if (dbSet == null)
{
throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");
}
var dbSet = dbContext.Set<T>() ?? throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");

await dbSet.ExecuteBulkInsertAsync(entities, configure, onConflict, cancellationToken);
return dbSet.ExecuteBulkInsertReturnEnumerableAsync(entities, configure, onConflict, ctk);
}

/// <summary>
/// Executes a bulk insert operation returning the inserted/updated entities, from the DbSet (synchronous variant).
/// Executes a bulk insert operation returning the inserted/updated entities, from the DbSet.
/// </summary>
public static List<T> ExecuteBulkInsertReturnEntities<T>(
public static IAsyncEnumerable<T> ExecuteBulkInsertReturnEnumerableAsync<T>(
this DbSet<T> dbSet,
IEnumerable<T> entities,
Action<BulkInsertOptions>? configure = null,
OnConflictOptions<T>? onConflict = null
OnConflictOptions<T>? onConflict = null,
CancellationToken ctk = default
) where T : class
{
var provider = InitProvider(dbSet, configure, out var context, out var options);
var tableInfo = dbSet.GetDbContext().GetTableInfo<T>();

return provider.BulkInsertReturnEntities(true, context, tableInfo, entities, options, onConflict).GetAwaiter().GetResult();
return provider.BulkInsertReturnEntities(false, context, dbSet.GetDbContext().GetTableInfo<T>(), entities, options, onConflict, ctk);
}

/// <summary>
/// Executes a bulk insert operation returning the inserted/updated entities, from the DbContext (synchronous variant).
/// Executes a bulk insert operation without returning the inserted/updated entities, from the DbContext.
/// </summary>
public static List<T> ExecuteBulkInsertReturnEntities<T>(
public static async Task ExecuteBulkInsertAsync<T>(
this DbContext dbContext,
IEnumerable<T> entities,
Action<BulkInsertOptions>? configure = null,
OnConflictOptions<T>? onConflict = null
OnConflictOptions<T>? onConflict = null,
CancellationToken ctk = default
) where T : class
{
var dbSet = dbContext.Set<T>();
if (dbSet == null)
{
throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");
}
var dbSet = dbContext.Set<T>() ?? throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");

return dbSet.ExecuteBulkInsertReturnEntities(entities, configure, onConflict);
await dbSet.ExecuteBulkInsertAsync(entities, configure, onConflict, ctk);
}

/// <summary>
/// Executes a bulk insert operation without returning the inserted/updated entities, from the DbSet (synchronous variant).
/// Executes a bulk insert operation without returning the inserted/updated entities, from the DbSet.
/// </summary>
public static void ExecuteBulkInsert<T>(
public static async Task ExecuteBulkInsertAsync<T>(
this DbSet<T> dbSet,
IEnumerable<T> entities,
Action<BulkInsertOptions>? configure = null,
OnConflictOptions<T>? onConflict = null
OnConflictOptions<T>? onConflict = null,
CancellationToken ctk = default
) where T : class
{
var provider = InitProvider(dbSet, configure, out var context, out var options);
var tableInfo = dbSet.GetDbContext().GetTableInfo<T>();

provider.BulkInsert(true, context, tableInfo, entities, options, onConflict).GetAwaiter().GetResult();
await provider.BulkInsert(false, context, dbSet.GetDbContext().GetTableInfo<T>(), entities, options, onConflict, ctk);
}

/// <summary>
Expand All @@ -134,15 +165,26 @@ public static void ExecuteBulkInsert<T>(
OnConflictOptions<T>? onConflict = null
) where T : class
{
var dbSet = dbContext.Set<T>();
if (dbSet == null)
{
throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");
}
var dbSet = dbContext.Set<T>() ?? throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");

dbSet.ExecuteBulkInsert(entities, configure, onConflict);
}

/// <summary>
/// Executes a bulk insert operation without returning the inserted/updated entities, from the DbSet (synchronous variant).
/// </summary>
public static void ExecuteBulkInsert<T>(
this DbSet<T> dbSet,
IEnumerable<T> entities,
Action<BulkInsertOptions>? configure = null,
OnConflictOptions<T>? onConflict = null
) where T : class
{
var provider = InitProvider(dbSet, configure, out var context, out var options);

provider.BulkInsert(true, context, dbSet.GetDbContext().GetTableInfo<T>(), entities, options, onConflict).GetAwaiter().GetResult();
}

private static DbContext GetDbContext<T>(this DbSet<T> dbSet) where T : class
{
IInfrastructure<IServiceProvider> infrastructure = dbSet;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,32 @@ public async Task InsertsEntitiesAndReturn()
Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity2");
}

[SkippableFact]
public async Task InsertsEntitiesAndReturnAsyncEnumerable()
{
Skip.If(_context.Database.ProviderName!.Contains("Mysql", StringComparison.InvariantCultureIgnoreCase));

// Arrange
var entities = new List<TestEntity>
{
new TestEntity { TestRun = _run, Name = $"{_run}_Entity1" },
new TestEntity { TestRun = _run, Name = $"{_run}_Entity2" }
};

// Act
var enumerable = _context.ExecuteBulkInsertReturnEnumerableAsync(entities);
var insertedEntities = new List<TestEntity>();
await foreach (var item in enumerable)
{
insertedEntities.Add(item);
}

// Assert
Assert.Equal(2, insertedEntities.Count);
Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity1");
Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity2");
}

[SkippableFact]
public void InsertsEntitiesAndReturn_Sync()
{
Expand Down