Skip to content

Commit c360fd1

Browse files
Merge branch 'async-enumerable' into logging
# Conflicts: # src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs
2 parents 542eece + f1ef1b1 commit c360fd1

5 files changed

Lines changed: 135 additions & 66 deletions

File tree

src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlBulkInsertProvider.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ public MySqlBulkInsertProvider(ILogger<MySqlBulkInsertProvider>? logger = null)
2323
protected override string GetTempTableName(string tableName) => $"#_temp_bulk_insert_{tableName}";
2424

2525
/// <inheritdoc />
26-
public override Task<List<T>> BulkInsertReturnEntities<T>(
26+
public override IAsyncEnumerable<T> BulkInsertReturnEntities<T>(
2727
bool sync,
2828
DbContext context,
2929
TableMetadata tableInfo,

src/PhenX.EntityFrameworkCore.BulkInsert/Abstractions/IBulkInsertProvider.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ internal interface IBulkInsertProvider
1414
/// <summary>
1515
/// Calls the provider to perform a bulk insert operation.
1616
/// </summary>
17-
internal Task<List<T>> BulkInsertReturnEntities<T>(
17+
internal IAsyncEnumerable<T> BulkInsertReturnEntities<T>(
1818
bool sync,
1919
DbContext context,
2020
TableMetadata tableInfo,

src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using System.Runtime.CompilerServices;
2+
13
using Microsoft.EntityFrameworkCore;
24
using Microsoft.EntityFrameworkCore.Storage;
35
using Microsoft.Extensions.Logging;
@@ -22,14 +24,14 @@ namespace PhenX.EntityFrameworkCore.BulkInsert;
2224

2325
SqlDialectBuilder IBulkInsertProvider.SqlDialect => SqlDialect;
2426

25-
public virtual async Task<List<T>> BulkInsertReturnEntities<T>(
27+
public virtual async IAsyncEnumerable<T> BulkInsertReturnEntities<T>(
2628
bool sync,
2729
DbContext context,
2830
TableMetadata tableInfo,
2931
IEnumerable<T> entities,
3032
BulkInsertOptions options,
3133
OnConflictOptions? onConflict,
32-
CancellationToken ctk) where T : class
34+
[EnumeratorCancellation] CancellationToken ctk) where T : class
3335
{
3436
using var activity = Telemetry.ActivitySource.StartActivity("BulkInsertReturnEntities");
3537
activity?.AddTag("tableName", tableInfo.TableName);
@@ -45,11 +47,17 @@ public virtual async Task<List<T>> BulkInsertReturnEntities<T>(
4547

4648
var tableName = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk);
4749

48-
var result = await CopyFromTempTableAsync<T, T>(sync, context, tableInfo, tableName, true, options, onConflict, ctk: ctk);
50+
var result =
51+
await CopyFromTempTableAsync<T, T>(sync, context, tableInfo, tableName, true, options, onConflict, ctk: ctk)
52+
?? throw new InvalidOperationException("Copy returns null enumerable.");
53+
54+
await foreach (var item in result.WithCancellation(ctk))
55+
{
56+
yield return item;
57+
}
4958

5059
// Commit the transaction if we own them.
5160
await connection.Commit(sync, ctk);
52-
return result;
5361
}
5462
finally
5563
{
@@ -173,7 +181,7 @@ protected virtual async Task AddBulkInsertIdColumn<T>(
173181
await ExecuteAsync(sync, context, alterQuery, ctk);
174182
}
175183

176-
private async Task<List<TResult>> CopyFromTempTableAsync<T, TResult>(
184+
private async Task<IAsyncEnumerable<TResult>?> CopyFromTempTableAsync<T, TResult>(
177185
bool sync,
178186
DbContext context,
179187
TableMetadata tableInfo,
@@ -195,19 +203,12 @@ private async Task<List<TResult>> CopyFromTempTableAsync<T, TResult>(
195203
if (returnData)
196204
{
197205
// Use EF to execute the query and return the results
198-
var queryable = context.Set<TResult>().FromSqlRaw(query);
199-
200-
if (sync)
201-
{
202-
return [.. queryable];
203-
}
204-
205-
return await queryable.ToListAsync(ctk);
206+
return context.Set<TResult>().FromSqlRaw(query).AsAsyncEnumerable();
206207
}
207208

208209
// If not returning data, just execute the command
209210
await ExecuteAsync(sync, context, query, ctk);
210-
return [];
211+
return null;
211212
}
212213

213214
protected static async Task ExecuteAsync(

src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbSetExtensions.cs

Lines changed: 92 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -12,116 +12,147 @@ namespace PhenX.EntityFrameworkCore.BulkInsert.Extensions;
1212
public static class DbSetExtensions
1313
{
1414
/// <summary>
15-
/// Executes a bulk insert operation returning the inserted/updated entities, from the DbSet.
15+
/// Executes a bulk insert operation returning the inserted/updated entities, from the DbSet (synchronous variant).
1616
/// </summary>
17-
public static async Task<List<T>> ExecuteBulkInsertReturnEntitiesAsync<T>(
17+
public static List<T> ExecuteBulkInsertReturnEntities<T>(
1818
this DbSet<T> dbSet,
1919
IEnumerable<T> entities,
2020
Action<BulkInsertOptions>? configure = null,
21-
OnConflictOptions<T>? onConflict = null,
22-
CancellationToken ctk = default
21+
OnConflictOptions<T>? onConflict = null
2322
) where T : class
2423
{
25-
var provider = InitProvider(dbSet, configure, out var context, out var options);
26-
var tableInfo = dbSet.GetDbContext().GetTableInfo<T>();
24+
return dbSet.ExecuteBulkInsertReturnEntitiesCoreAsync(true, entities, configure, onConflict, default).GetAwaiter().GetResult();
25+
}
26+
27+
/// <summary>
28+
/// Executes a bulk insert operation returning the inserted/updated entities, from the DbContext (synchronous variant).
29+
/// </summary>
30+
public static List<T> ExecuteBulkInsertReturnEntities<T>(
31+
this DbContext dbContext,
32+
IEnumerable<T> entities,
33+
Action<BulkInsertOptions>? configure = null,
34+
OnConflictOptions<T>? onConflict = null
35+
) where T : class
36+
{
37+
var dbSet = dbContext.Set<T>() ?? throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");
2738

28-
return await provider.BulkInsertReturnEntities(false, context, tableInfo, entities, options, onConflict, ctk);
39+
return dbSet.ExecuteBulkInsertReturnEntitiesCoreAsync(true, entities, configure, onConflict, default).GetAwaiter().GetResult();
2940
}
3041

3142
/// <summary>
3243
/// Executes a bulk insert operation returning the inserted/updated entities, from the DbContext.
3344
/// </summary>
34-
public static async Task<List<T>> ExecuteBulkInsertReturnEntitiesAsync<T>(this DbContext dbContext, IEnumerable<T> entities, Action<BulkInsertOptions>? configure = null, OnConflictOptions<T>? onConflict = null, CancellationToken cancellationToken = default) where T : class
45+
public static Task<List<T>> ExecuteBulkInsertReturnEntitiesAsync<T>(
46+
this DbContext dbContext,
47+
IEnumerable<T> entities,
48+
Action<BulkInsertOptions>? configure = null,
49+
OnConflictOptions<T>? onConflict = null,
50+
CancellationToken ctk = default
51+
) where T : class
3552
{
36-
var dbSet = dbContext.Set<T>();
37-
if (dbSet == null)
38-
{
39-
throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");
40-
}
53+
var dbSet = dbContext.Set<T>() ?? throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");
4154

42-
return await dbSet.ExecuteBulkInsertReturnEntitiesAsync(entities, configure, onConflict, cancellationToken);
55+
return dbSet.ExecuteBulkInsertReturnEntitiesCoreAsync(false, entities, configure, onConflict, ctk);
4356
}
4457

4558
/// <summary>
46-
/// Executes a bulk insert operation without returning the inserted/updated entities, from the DbSet.
59+
/// Executes a bulk insert operation returning the inserted/updated entities, from the DbSet.
4760
/// </summary>
48-
public static async Task ExecuteBulkInsertAsync<T>(
61+
public static Task<List<T>> ExecuteBulkInsertReturnEntitiesAsync<T>(
4962
this DbSet<T> dbSet,
5063
IEnumerable<T> entities,
5164
Action<BulkInsertOptions>? configure = null,
5265
OnConflictOptions<T>? onConflict = null,
5366
CancellationToken ctk = default
5467
) where T : class
68+
{
69+
return dbSet.ExecuteBulkInsertReturnEntitiesCoreAsync(false, entities, configure, onConflict, ctk);
70+
}
71+
72+
private static async Task<List<T>> ExecuteBulkInsertReturnEntitiesCoreAsync<T>(
73+
this DbSet<T> dbSet,
74+
bool sync,
75+
IEnumerable<T> entities,
76+
Action<BulkInsertOptions>? configure,
77+
OnConflictOptions<T>? onConflict,
78+
CancellationToken ctk
79+
) where T : class
5580
{
5681
var provider = InitProvider(dbSet, configure, out var context, out var options);
57-
var tableInfo = dbSet.GetDbContext().GetTableInfo<T>();
5882

59-
await provider.BulkInsert(false, context, tableInfo, entities, options, onConflict, ctk);
83+
var enumerable = provider.BulkInsertReturnEntities(sync, context, dbSet.GetDbContext().GetTableInfo<T>(), entities, options, onConflict, ctk);
84+
85+
var result = new List<T>();
86+
await foreach (var item in enumerable.WithCancellation(ctk))
87+
{
88+
result.Add(item);
89+
}
90+
91+
return result;
6092
}
6193

6294
/// <summary>
63-
/// Executes a bulk insert operation without returning the inserted/updated entities, from the DbContext.
95+
/// Executes a bulk insert operation returning the inserted/updated entities, from the DbContext.
6496
/// </summary>
65-
public static async Task ExecuteBulkInsertAsync<T>(this DbContext dbContext, IEnumerable<T> entities, Action<BulkInsertOptions>? configure = null, OnConflictOptions<T>? onConflict = null, CancellationToken cancellationToken = default) where T : class
97+
public static IAsyncEnumerable<T> ExecuteBulkInsertReturnEnumerableAsync<T>(
98+
this DbContext dbContext,
99+
IEnumerable<T> entities,
100+
Action<BulkInsertOptions>? configure = null,
101+
OnConflictOptions<T>? onConflict = null,
102+
CancellationToken ctk = default
103+
) where T : class
66104
{
67-
var dbSet = dbContext.Set<T>();
68-
if (dbSet == null)
69-
{
70-
throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");
71-
}
105+
var dbSet = dbContext.Set<T>() ?? throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");
72106

73-
await dbSet.ExecuteBulkInsertAsync(entities, configure, onConflict, cancellationToken);
107+
return dbSet.ExecuteBulkInsertReturnEnumerableAsync(entities, configure, onConflict, ctk);
74108
}
75109

76110
/// <summary>
77-
/// Executes a bulk insert operation returning the inserted/updated entities, from the DbSet (synchronous variant).
111+
/// Executes a bulk insert operation returning the inserted/updated entities, from the DbSet.
78112
/// </summary>
79-
public static List<T> ExecuteBulkInsertReturnEntities<T>(
113+
public static IAsyncEnumerable<T> ExecuteBulkInsertReturnEnumerableAsync<T>(
80114
this DbSet<T> dbSet,
81115
IEnumerable<T> entities,
82116
Action<BulkInsertOptions>? configure = null,
83-
OnConflictOptions<T>? onConflict = null
117+
OnConflictOptions<T>? onConflict = null,
118+
CancellationToken ctk = default
84119
) where T : class
85120
{
86121
var provider = InitProvider(dbSet, configure, out var context, out var options);
87-
var tableInfo = dbSet.GetDbContext().GetTableInfo<T>();
88122

89-
return provider.BulkInsertReturnEntities(true, context, tableInfo, entities, options, onConflict).GetAwaiter().GetResult();
123+
return provider.BulkInsertReturnEntities(false, context, dbSet.GetDbContext().GetTableInfo<T>(), entities, options, onConflict, ctk);
90124
}
91125

92126
/// <summary>
93-
/// Executes a bulk insert operation returning the inserted/updated entities, from the DbContext (synchronous variant).
127+
/// Executes a bulk insert operation without returning the inserted/updated entities, from the DbContext.
94128
/// </summary>
95-
public static List<T> ExecuteBulkInsertReturnEntities<T>(
129+
public static async Task ExecuteBulkInsertAsync<T>(
96130
this DbContext dbContext,
97131
IEnumerable<T> entities,
98132
Action<BulkInsertOptions>? configure = null,
99-
OnConflictOptions<T>? onConflict = null
133+
OnConflictOptions<T>? onConflict = null,
134+
CancellationToken ctk = default
100135
) where T : class
101136
{
102-
var dbSet = dbContext.Set<T>();
103-
if (dbSet == null)
104-
{
105-
throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");
106-
}
137+
var dbSet = dbContext.Set<T>() ?? throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");
107138

108-
return dbSet.ExecuteBulkInsertReturnEntities(entities, configure, onConflict);
139+
await dbSet.ExecuteBulkInsertAsync(entities, configure, onConflict, ctk);
109140
}
110141

111142
/// <summary>
112-
/// Executes a bulk insert operation without returning the inserted/updated entities, from the DbSet (synchronous variant).
143+
/// Executes a bulk insert operation without returning the inserted/updated entities, from the DbSet.
113144
/// </summary>
114-
public static void ExecuteBulkInsert<T>(
145+
public static async Task ExecuteBulkInsertAsync<T>(
115146
this DbSet<T> dbSet,
116147
IEnumerable<T> entities,
117148
Action<BulkInsertOptions>? configure = null,
118-
OnConflictOptions<T>? onConflict = null
149+
OnConflictOptions<T>? onConflict = null,
150+
CancellationToken ctk = default
119151
) where T : class
120152
{
121153
var provider = InitProvider(dbSet, configure, out var context, out var options);
122-
var tableInfo = dbSet.GetDbContext().GetTableInfo<T>();
123154

124-
provider.BulkInsert(true, context, tableInfo, entities, options, onConflict).GetAwaiter().GetResult();
155+
await provider.BulkInsert(false, context, dbSet.GetDbContext().GetTableInfo<T>(), entities, options, onConflict, ctk);
125156
}
126157

127158
/// <summary>
@@ -134,15 +165,26 @@ public static void ExecuteBulkInsert<T>(
134165
OnConflictOptions<T>? onConflict = null
135166
) where T : class
136167
{
137-
var dbSet = dbContext.Set<T>();
138-
if (dbSet == null)
139-
{
140-
throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");
141-
}
168+
var dbSet = dbContext.Set<T>() ?? throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");
142169

143170
dbSet.ExecuteBulkInsert(entities, configure, onConflict);
144171
}
145172

173+
/// <summary>
174+
/// Executes a bulk insert operation without returning the inserted/updated entities, from the DbSet (synchronous variant).
175+
/// </summary>
176+
public static void ExecuteBulkInsert<T>(
177+
this DbSet<T> dbSet,
178+
IEnumerable<T> entities,
179+
Action<BulkInsertOptions>? configure = null,
180+
OnConflictOptions<T>? onConflict = null
181+
) where T : class
182+
{
183+
var provider = InitProvider(dbSet, configure, out var context, out var options);
184+
185+
provider.BulkInsert(true, context, dbSet.GetDbContext().GetTableInfo<T>(), entities, options, onConflict).GetAwaiter().GetResult();
186+
}
187+
146188
private static DbContext GetDbContext<T>(this DbSet<T> dbSet) where T : class
147189
{
148190
IInfrastructure<IServiceProvider> infrastructure = dbSet;

tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Basic/BasicTestsBase.cs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,32 @@ public async Task InsertsEntitiesAndReturn()
9292
Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity2");
9393
}
9494

95+
[SkippableFact]
96+
public async Task InsertsEntitiesAndReturnAsyncEnumerable()
97+
{
98+
Skip.If(_context.Database.ProviderName!.Contains("Mysql", StringComparison.InvariantCultureIgnoreCase));
99+
100+
// Arrange
101+
var entities = new List<TestEntity>
102+
{
103+
new TestEntity { TestRun = _run, Name = $"{_run}_Entity1" },
104+
new TestEntity { TestRun = _run, Name = $"{_run}_Entity2" }
105+
};
106+
107+
// Act
108+
var enumerable = _context.ExecuteBulkInsertReturnEnumerableAsync(entities);
109+
var insertedEntities = new List<TestEntity>();
110+
await foreach (var item in enumerable)
111+
{
112+
insertedEntities.Add(item);
113+
}
114+
115+
// Assert
116+
Assert.Equal(2, insertedEntities.Count);
117+
Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity1");
118+
Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity2");
119+
}
120+
95121
[SkippableFact]
96122
public void InsertsEntitiesAndReturn_Sync()
97123
{

0 commit comments

Comments
 (0)