Skip to content

Commit 3dc6f95

Browse files
Async enumerable support.
1 parent 656459a commit 3dc6f95

5 files changed

Lines changed: 136 additions & 70 deletions

File tree

src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlBulkInsertProvider.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ public MySqlBulkInsertProvider(ILogger<MySqlBulkInsertProvider>? logger = null)
2727
protected override string GetTempTableName(string tableName) => $"#_temp_bulk_insert_{tableName}";
2828

2929
/// <inheritdoc />
30-
public override Task<List<T>> BulkInsertReturnEntities<T>(
30+
public override IAsyncEnumerable<T> BulkInsertReturnEntities<T>(
3131
bool sync,
3232
DbContext context,
3333
TableMetadata tableInfo,

src/PhenX.EntityFrameworkCore.BulkInsert/Abstractions/IBulkInsertProvider.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ internal interface IBulkInsertProvider
1414
/// <summary>
1515
/// Calls the provider to perform a bulk insert operation.
1616
/// </summary>
17-
internal Task<List<T>> BulkInsertReturnEntities<T>(
17+
internal IAsyncEnumerable<T> BulkInsertReturnEntities<T>(
1818
bool sync,
1919
DbContext context,
2020
TableMetadata tableInfo,

src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System.Data.Common;
2+
using System.Runtime.CompilerServices;
23

34
using Microsoft.EntityFrameworkCore;
45
using Microsoft.EntityFrameworkCore.Storage;
@@ -73,7 +74,7 @@ protected static async Task ExecuteAsync(bool sync, DbContext context, string qu
7374
}
7475
}
7576

76-
public async Task<List<T>> CopyFromTempTableAsync<T>(
77+
public async Task<IAsyncEnumerable<T>?> CopyFromTempTableAsync<T>(
7778
bool sync,
7879
DbContext context,
7980
TableMetadata tableInfo,
@@ -94,7 +95,7 @@ public async Task<List<T>> CopyFromTempTableAsync<T>(
9495
cancellationToken: cancellationToken);
9596
}
9697

97-
private async Task<List<TResult>> CopyFromTempTableWithoutKeysAsync<T, TResult>(
98+
private async Task<IAsyncEnumerable<TResult>?> CopyFromTempTableWithoutKeysAsync<T, TResult>(
9899
bool sync,
99100
DbContext context,
100101
TableMetadata tableInfo,
@@ -113,47 +114,46 @@ private async Task<List<TResult>> CopyFromTempTableWithoutKeysAsync<T, TResult>(
113114

114115
if (returnData)
115116
{
116-
return await QueryAsync(sync, context, query, cancellationToken);
117+
return Query(context, query);
117118
}
118119

119120
// If not returning data, just execute the command
120121
await ExecuteAsync(sync, context, query, cancellationToken);
121-
return [];
122+
return null;
122123

123-
static async Task<List<TResult>> QueryAsync(bool sync, DbContext context, string query, CancellationToken cancellationToken)
124+
static IAsyncEnumerable<TResult> Query(DbContext context, string query)
124125
{
125126
// Use EF to execute the query and return the results
126127
IQueryable<TResult> queryable = context
127128
.Set<TResult>()
128129
.FromSqlRaw(query);
129130

130-
if (sync)
131-
{
132-
return queryable.ToList();
133-
}
134-
135-
return await queryable.ToListAsync(cancellationToken: cancellationToken);
131+
return queryable.AsAsyncEnumerable();
136132
}
137133
}
138134

139-
public virtual async Task<List<T>> BulkInsertReturnEntities<T>(
135+
public virtual async IAsyncEnumerable<T> BulkInsertReturnEntities<T>(
140136
bool sync,
141137
DbContext context,
142138
TableMetadata tableInfo,
143139
IEnumerable<T> entities,
144140
BulkInsertOptions options,
145141
OnConflictOptions? onConflict = null,
146-
CancellationToken ctk = default
142+
[EnumeratorCancellation] CancellationToken ctk = default
147143
) where T : class
148144
{
149-
List<T> result;
150-
151145
var connectionInfo = await context.GetConnection(sync, ctk);
152146
try
153147
{
154148
var (tableName, _) = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk);
155149

156-
result = await CopyFromTempTableAsync<T>(sync, context, tableInfo, tableName, true, options, onConflict, cancellationToken: ctk);
150+
var result = await CopyFromTempTableAsync<T>(sync, context, tableInfo, tableName, true, options, onConflict, cancellationToken: ctk)
151+
?? throw new InvalidOperationException("Failed to get async enumerable.");
152+
153+
await foreach (var item in result)
154+
{
155+
yield return item;
156+
}
157157

158158
// Commit the transaction if we own them.
159159
await Commit(sync, connectionInfo, ctk);
@@ -162,8 +162,6 @@ public virtual async Task<List<T>> BulkInsertReturnEntities<T>(
162162
{
163163
await Finish(sync, connectionInfo, ctk);
164164
}
165-
166-
return result;
167165
}
168166

169167
private static async Task Commit(bool sync, ConnectionInfo connectionInfo, CancellationToken ctk)

src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbSetExtensions.cs

Lines changed: 92 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -12,116 +12,147 @@ namespace PhenX.EntityFrameworkCore.BulkInsert.Extensions;
1212
public static class DbSetExtensions
1313
{
1414
/// <summary>
15-
/// Executes a bulk insert operation returning the inserted/updated entities, from the DbSet.
15+
/// Executes a bulk insert operation returning the inserted/updated entities, from the DbSet (synchronous variant).
1616
/// </summary>
17-
public static async Task<List<T>> ExecuteBulkInsertReturnEntitiesAsync<T>(
17+
public static List<T> ExecuteBulkInsertReturnEntities<T>(
1818
this DbSet<T> dbSet,
1919
IEnumerable<T> entities,
2020
Action<BulkInsertOptions>? configure = null,
21-
OnConflictOptions? onConflict = null,
22-
CancellationToken ctk = default
21+
OnConflictOptions? onConflict = null
2322
) where T : class
2423
{
25-
var provider = InitProvider(dbSet, configure, out var context, out var options);
26-
var tableInfo = dbSet.GetDbContext().GetTableInfo<T>();
24+
return dbSet.ExecuteBulkInsertReturnEntitiesCoreAsync(true, entities, configure, onConflict, default).GetAwaiter().GetResult();
25+
}
26+
27+
/// <summary>
28+
/// Executes a bulk insert operation returning the inserted/updated entities, from the DbContext (synchronous variant).
29+
/// </summary>
30+
public static List<T> ExecuteBulkInsertReturnEntities<T>(
31+
this DbContext dbContext,
32+
IEnumerable<T> entities,
33+
Action<BulkInsertOptions>? configure = null,
34+
OnConflictOptions? onConflict = null
35+
) where T : class
36+
{
37+
var dbSet = dbContext.Set<T>() ?? throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");
2738

28-
return await provider.BulkInsertReturnEntities(false, context, tableInfo, entities, options, onConflict, ctk);
39+
return dbSet.ExecuteBulkInsertReturnEntitiesCoreAsync(true, entities, configure, onConflict, default).GetAwaiter().GetResult();
2940
}
3041

3142
/// <summary>
3243
/// Executes a bulk insert operation returning the inserted/updated entities, from the DbContext.
3344
/// </summary>
34-
public static async Task<List<T>> ExecuteBulkInsertReturnEntitiesAsync<T>(this DbContext dbContext, IEnumerable<T> entities, Action<BulkInsertOptions>? configure = null, OnConflictOptions? onConflict = null, CancellationToken cancellationToken = default) where T : class
45+
public static Task<List<T>> ExecuteBulkInsertReturnEntitiesAsync<T>(
46+
this DbContext dbContext,
47+
IEnumerable<T> entities,
48+
Action<BulkInsertOptions>? configure = null,
49+
OnConflictOptions? onConflict = null,
50+
CancellationToken ctk = default
51+
) where T : class
3552
{
36-
var dbSet = dbContext.Set<T>();
37-
if (dbSet == null)
38-
{
39-
throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");
40-
}
53+
var dbSet = dbContext.Set<T>() ?? throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");
4154

42-
return await dbSet.ExecuteBulkInsertReturnEntitiesAsync(entities, configure, onConflict, cancellationToken);
55+
return dbSet.ExecuteBulkInsertReturnEntitiesCoreAsync(false, entities, configure, onConflict, ctk);
4356
}
4457

4558
/// <summary>
46-
/// Executes a bulk insert operation without returning the inserted/updated entities, from the DbSet.
59+
/// Executes a bulk insert operation returning the inserted/updated entities, from the DbSet.
4760
/// </summary>
48-
public static async Task ExecuteBulkInsertAsync<T>(
61+
public static Task<List<T>> ExecuteBulkInsertReturnEntitiesAsync<T>(
4962
this DbSet<T> dbSet,
5063
IEnumerable<T> entities,
5164
Action<BulkInsertOptions>? configure = null,
5265
OnConflictOptions? onConflict = null,
5366
CancellationToken ctk = default
5467
) where T : class
68+
{
69+
return dbSet.ExecuteBulkInsertReturnEntitiesCoreAsync(false, entities, configure, onConflict, ctk);
70+
}
71+
72+
private static async Task<List<T>> ExecuteBulkInsertReturnEntitiesCoreAsync<T>(
73+
this DbSet<T> dbSet,
74+
bool sync,
75+
IEnumerable<T> entities,
76+
Action<BulkInsertOptions>? configure,
77+
OnConflictOptions? onConflict,
78+
CancellationToken ctk
79+
) where T : class
5580
{
5681
var provider = InitProvider(dbSet, configure, out var context, out var options);
57-
var tableInfo = dbSet.GetDbContext().GetTableInfo<T>();
5882

59-
await provider.BulkInsert(false, context, tableInfo, entities, options, onConflict, ctk);
83+
var enumerable = provider.BulkInsertReturnEntities(sync, context, dbSet.GetDbContext().GetTableInfo<T>(), entities, options, onConflict, ctk);
84+
85+
var result = new List<T>();
86+
await foreach (var item in enumerable.WithCancellation(ctk))
87+
{
88+
result.Add(item);
89+
}
90+
91+
return result;
6092
}
6193

6294
/// <summary>
63-
/// Executes a bulk insert operation without returning the inserted/updated entities, from the DbContext.
95+
/// Executes a bulk insert operation returning the inserted/updated entities, from the DbContext.
6496
/// </summary>
65-
public static async Task ExecuteBulkInsertAsync<T>(this DbContext dbContext, IEnumerable<T> entities, Action<BulkInsertOptions>? configure = null, OnConflictOptions? onConflict = null, CancellationToken cancellationToken = default) where T : class
97+
public static IAsyncEnumerable<T> ExecuteBulkInsertReturnEnumerableAsync<T>(
98+
this DbContext dbContext,
99+
IEnumerable<T> entities,
100+
Action<BulkInsertOptions>? configure = null,
101+
OnConflictOptions? onConflict = null,
102+
CancellationToken ctk = default
103+
) where T : class
66104
{
67-
var dbSet = dbContext.Set<T>();
68-
if (dbSet == null)
69-
{
70-
throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");
71-
}
105+
var dbSet = dbContext.Set<T>() ?? throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");
72106

73-
await dbSet.ExecuteBulkInsertAsync(entities, configure, onConflict, cancellationToken);
107+
return dbSet.ExecuteBulkInsertReturnEnumerableAsync(entities, configure, onConflict, ctk);
74108
}
75109

76110
/// <summary>
77-
/// Executes a bulk insert operation returning the inserted/updated entities, from the DbSet (synchronous variant).
111+
/// Executes a bulk insert operation returning the inserted/updated entities, from the DbSet.
78112
/// </summary>
79-
public static List<T> ExecuteBulkInsertReturnEntities<T>(
113+
public static IAsyncEnumerable<T> ExecuteBulkInsertReturnEnumerableAsync<T>(
80114
this DbSet<T> dbSet,
81115
IEnumerable<T> entities,
82116
Action<BulkInsertOptions>? configure = null,
83-
OnConflictOptions? onConflict = null
117+
OnConflictOptions? onConflict = null,
118+
CancellationToken ctk = default
84119
) where T : class
85120
{
86121
var provider = InitProvider(dbSet, configure, out var context, out var options);
87-
var tableInfo = dbSet.GetDbContext().GetTableInfo<T>();
88122

89-
return provider.BulkInsertReturnEntities(true, context, tableInfo, entities, options, onConflict).GetAwaiter().GetResult();
123+
return provider.BulkInsertReturnEntities(false, context, dbSet.GetDbContext().GetTableInfo<T>(), entities, options, onConflict, ctk);
90124
}
91125

92126
/// <summary>
93-
/// Executes a bulk insert operation returning the inserted/updated entities, from the DbContext (synchronous variant).
127+
/// Executes a bulk insert operation without returning the inserted/updated entities, from the DbContext.
94128
/// </summary>
95-
public static List<T> ExecuteBulkInsertReturnEntities<T>(
129+
public static async Task ExecuteBulkInsertAsync<T>(
96130
this DbContext dbContext,
97131
IEnumerable<T> entities,
98132
Action<BulkInsertOptions>? configure = null,
99-
OnConflictOptions? onConflict = null
133+
OnConflictOptions? onConflict = null,
134+
CancellationToken ctk = default
100135
) where T : class
101136
{
102-
var dbSet = dbContext.Set<T>();
103-
if (dbSet == null)
104-
{
105-
throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");
106-
}
137+
var dbSet = dbContext.Set<T>() ?? throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");
107138

108-
return dbSet.ExecuteBulkInsertReturnEntities(entities, configure, onConflict);
139+
await dbSet.ExecuteBulkInsertAsync(entities, configure, onConflict, ctk);
109140
}
110141

111142
/// <summary>
112-
/// Executes a bulk insert operation without returning the inserted/updated entities, from the DbSet (synchronous variant).
143+
/// Executes a bulk insert operation without returning the inserted/updated entities, from the DbSet.
113144
/// </summary>
114-
public static void ExecuteBulkInsert<T>(
145+
public static async Task ExecuteBulkInsertAsync<T>(
115146
this DbSet<T> dbSet,
116147
IEnumerable<T> entities,
117148
Action<BulkInsertOptions>? configure = null,
118-
OnConflictOptions? onConflict = null
149+
OnConflictOptions? onConflict = null,
150+
CancellationToken ctk = default
119151
) where T : class
120152
{
121153
var provider = InitProvider(dbSet, configure, out var context, out var options);
122-
var tableInfo = dbSet.GetDbContext().GetTableInfo<T>();
123154

124-
provider.BulkInsert(true, context, tableInfo, entities, options, onConflict).GetAwaiter().GetResult();
155+
await provider.BulkInsert(false, context, dbSet.GetDbContext().GetTableInfo<T>(), entities, options, onConflict, ctk);
125156
}
126157

127158
/// <summary>
@@ -134,15 +165,26 @@ public static void ExecuteBulkInsert<T>(
134165
OnConflictOptions? onConflict = null
135166
) where T : class
136167
{
137-
var dbSet = dbContext.Set<T>();
138-
if (dbSet == null)
139-
{
140-
throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");
141-
}
168+
var dbSet = dbContext.Set<T>() ?? throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");
142169

143170
dbSet.ExecuteBulkInsert(entities, configure, onConflict);
144171
}
145172

173+
/// <summary>
174+
/// Executes a bulk insert operation without returning the inserted/updated entities, from the DbSet (synchronous variant).
175+
/// </summary>
176+
public static void ExecuteBulkInsert<T>(
177+
this DbSet<T> dbSet,
178+
IEnumerable<T> entities,
179+
Action<BulkInsertOptions>? configure = null,
180+
OnConflictOptions? onConflict = null
181+
) where T : class
182+
{
183+
var provider = InitProvider(dbSet, configure, out var context, out var options);
184+
185+
provider.BulkInsert(true, context, dbSet.GetDbContext().GetTableInfo<T>(), entities, options, onConflict).GetAwaiter().GetResult();
186+
}
187+
146188
private static DbContext GetDbContext<T>(this DbSet<T> dbSet) where T : class
147189
{
148190
IInfrastructure<IServiceProvider> infrastructure = dbSet;

tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Basic/BasicTestsBase.cs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,32 @@ public async Task InsertsEntitiesAndReturn()
9292
Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity2");
9393
}
9494

95+
[SkippableFact]
96+
public async Task InsertsEntitiesAndReturnAsyncENumerable()
97+
{
98+
Skip.If(_context.Database.ProviderName!.Contains("Mysql", StringComparison.InvariantCultureIgnoreCase));
99+
100+
// Arrange
101+
var entities = new List<TestEntity>
102+
{
103+
new TestEntity { TestRun = _run, Name = $"{_run}_Entity1" },
104+
new TestEntity { TestRun = _run, Name = $"{_run}_Entity2" }
105+
};
106+
107+
// Act
108+
var enumerable = _context.ExecuteBulkInsertReturnEnumerableAsync(entities);
109+
var insertedEntities = new List<TestEntity>();
110+
await foreach (var item in enumerable)
111+
{
112+
insertedEntities.Add(item);
113+
}
114+
115+
// Assert
116+
Assert.Equal(2, insertedEntities.Count);
117+
Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity1");
118+
Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity2");
119+
}
120+
95121
[SkippableFact]
96122
public void InsertsEntitiesAndReturn_Sync()
97123
{

0 commit comments

Comments
 (0)