Skip to content

Commit f00e9eb

Browse files
author
fabien.menager
committed
Refactor DbContext extension methods to use DbContext instead of DbConnection and improve transaction handling
1 parent 5157d9d commit f00e9eb

7 files changed

Lines changed: 116 additions & 30 deletions

File tree

src/PhenX.EntityFrameworkCore.BulkInsert.PostgreSql/PostgreSqlBulkInsertProvider.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@ namespace PhenX.EntityFrameworkCore.BulkInsert.PostgreSql;
1212
internal class PostgreSqlBulkInsertProvider : BulkInsertProviderBase<PostgreSqlDialectBuilder>
1313
{
1414
//language=sql
15+
/// <inheritdoc />
1516
protected override string CreateTableCopySql => "CREATE TEMPORARY TABLE {0} AS TABLE {1} WITH NO DATA;";
1617

1718
//language=sql
19+
/// <inheritdoc />
1820
protected override string AddTableCopyBulkInsertId => $"ALTER TABLE {{0}} ADD COLUMN {BulkInsertId} SERIAL PRIMARY KEY;";
1921

2022
private string GetBinaryImportCommand(DbContext context, Type entityType, string tableName)
@@ -24,6 +26,7 @@ private string GetBinaryImportCommand(DbContext context, Type entityType, string
2426
return $"COPY {tableName} ({string.Join(", ", columns)}) FROM STDIN (FORMAT BINARY)";
2527
}
2628

29+
/// <inheritdoc />
2730
protected override async Task BulkInsert<T>(
2831
DbContext context,
2932
IEnumerable<T> entities,

src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerBulkInsertProvider.cs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
using Microsoft.Data.SqlClient;
44
using Microsoft.EntityFrameworkCore;
5+
using Microsoft.EntityFrameworkCore.Storage;
56

67
using PhenX.EntityFrameworkCore.BulkInsert.Options;
78

@@ -11,22 +12,25 @@ namespace PhenX.EntityFrameworkCore.BulkInsert.SqlServer;
1112
internal class SqlServerBulkInsertProvider : BulkInsertProviderBase<SqlServerDialectBuilder>
1213
{
1314
//language=sql
15+
/// <inheritdoc />
1416
protected override string CreateTableCopySql => "SELECT {2} INTO {0} FROM {1} WHERE 1 = 0;";
1517

1618
//language=sql
19+
/// <inheritdoc />
1720
protected override string AddTableCopyBulkInsertId => $"ALTER TABLE {{0}} ADD {BulkInsertId} INT IDENTITY PRIMARY KEY;";
1821

22+
/// <inheritdoc />
1923
protected override string GetTempTableName(string tableName) => $"#_temp_bulk_insert_{tableName}";
2024

25+
/// <inheritdoc />
2126
protected override async Task BulkInsert<T>(DbContext context, IEnumerable<T> entities,
2227
string tableName,
2328
PropertyAccessor[] properties, BulkInsertOptions options, CancellationToken ctk)
2429
{
2530
var connection = context.Database.GetDbConnection();
31+
var sqlTransaction = context.Database.CurrentTransaction!.GetDbTransaction() as SqlTransaction;
2632

27-
await using var t = (SqlTransaction) await connection.BeginTransactionAsync(ctk); // TODO option
28-
29-
using var bulkCopy = new SqlBulkCopy(connection as SqlConnection, SqlBulkCopyOptions.TableLock, t);
33+
using var bulkCopy = new SqlBulkCopy(connection as SqlConnection, SqlBulkCopyOptions.TableLock, sqlTransaction);
3034
bulkCopy.DestinationTableName = tableName;
3135
bulkCopy.BatchSize = options.BatchSize ?? 50_000;
3236
bulkCopy.BulkCopyTimeout = 60;
@@ -37,7 +41,5 @@ protected override async Task BulkInsert<T>(DbContext context, IEnumerable<T> en
3741
}
3842

3943
await bulkCopy.WriteToServerAsync(new EnumerableDataReader<T>(entities, properties), ctk);
40-
41-
await t.CommitAsync(ctk);
4244
}
4345
}

src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteBulkInsertProvider.cs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,19 @@ namespace PhenX.EntityFrameworkCore.BulkInsert.Sqlite;
1313
[UsedImplicitly]
1414
internal class SqliteBulkInsertProvider : BulkInsertProviderBase<SqliteDialectBuilder>
1515
{
16+
/// <inheritdoc />
1617
protected override string BulkInsertId => "rowid";
1718

1819
//language=sql
20+
/// <inheritdoc />
1921
protected override string CreateTableCopySql => "CREATE TEMP TABLE {0} AS SELECT * FROM {1} WHERE 0;";
2022

2123
//language=sql
24+
/// <inheritdoc />
2225
protected override string AddTableCopyBulkInsertId => "--"; // No need to add an ID column in SQLite
2326

24-
protected override Task AddBulkInsertIdColumn<T>(DbConnection connection, CancellationToken cancellationToken,
27+
/// <inheritdoc />
28+
protected override Task AddBulkInsertIdColumn<T>(DbContext context, CancellationToken cancellationToken,
2529
string tempTableName) where T : class
2630
{
2731
return Task.CompletedTask;
@@ -111,13 +115,10 @@ private DbCommand GetInsertCommand(DbContext context, Type entityType, string ta
111115
return cmd;
112116
}
113117

118+
/// <inheritdoc />
114119
protected override async Task BulkInsert<T>(DbContext context, IEnumerable<T> entities,
115120
string tableName, PropertyAccessor[] properties, BulkInsertOptions options, CancellationToken ctk) where T : class
116121
{
117-
var connection = context.Database.GetDbConnection();
118-
119-
await using var transaction = await connection.BeginTransactionAsync(ctk);
120-
121122
const int maxParams = 1000;
122123
var batchSize = options.BatchSize ?? 5;
123124
batchSize = Math.Min(batchSize, maxParams / properties.Length);
@@ -142,8 +143,6 @@ protected override async Task BulkInsert<T>(DbContext context, IEnumerable<T> en
142143
await partialInsertCommand.ExecuteNonQueryAsync(ctk);
143144
}
144145
}
145-
146-
await transaction.CommitAsync(ctk);
147146
}
148147

149148
private static void FillValues<T>(T[] chunk, DbParameterCollection parameters, PropertyAccessor[] properties) where T : class

src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
using Microsoft.EntityFrameworkCore;
44
using Microsoft.EntityFrameworkCore.Metadata;
5+
using Microsoft.EntityFrameworkCore.Storage;
56

67
using PhenX.EntityFrameworkCore.BulkInsert.Abstractions;
78
using PhenX.EntityFrameworkCore.BulkInsert.Dialect;
@@ -31,27 +32,28 @@ protected async Task<string> CreateTableCopyAsync<T>(
3132

3233
var keptColumns = string.Join(", ", GetQuotedColumns(context, typeof(T), false));
3334
var query = string.Format(CreateTableCopySql, tempTableName, tableName, keptColumns);
34-
await ExecuteAsync(connection, query, cancellationToken);
35+
await ExecuteAsync(context, query, cancellationToken);
3536

36-
await AddBulkInsertIdColumn<T>(connection, cancellationToken, tempTableName);
37+
await AddBulkInsertIdColumn<T>(context, cancellationToken, tempTableName);
3738

3839
return tempTableName;
3940
}
4041

41-
protected virtual async Task AddBulkInsertIdColumn<T>(DbConnection connection, CancellationToken cancellationToken,
42+
protected virtual async Task AddBulkInsertIdColumn<T>(DbContext context, CancellationToken cancellationToken,
4243
string tempTableName) where T : class
4344
{
4445
var alterQuery = string.Format(AddTableCopyBulkInsertId, tempTableName);
45-
await ExecuteAsync(connection, alterQuery, cancellationToken);
46+
await ExecuteAsync(context, alterQuery, cancellationToken);
4647
}
4748

4849
protected virtual string GetTempTableName(string tableName) => $"_temp_bulk_insert_{tableName}";
4950

5051
protected string Quote(string name) => SqlDialect.Quote(name);
5152

52-
protected static async Task ExecuteAsync(DbConnection connection, string query, CancellationToken cancellationToken = default)
53+
protected static async Task ExecuteAsync(DbContext context, string query, CancellationToken cancellationToken = default)
5354
{
54-
var command = connection.CreateCommand();
55+
var command = context.Database.GetDbConnection().CreateCommand();
56+
command.Transaction = context.Database.CurrentTransaction!.GetDbTransaction();
5557
command.CommandText = query;
5658

5759
await command.ExecuteNonQueryAsync(cancellationToken);
@@ -103,7 +105,7 @@ private async Task<List<TResult>> CopyFromTempTableWithoutKeysAsync<T, TResult>(
103105
}
104106

105107
// If not returning data, just execute the command
106-
await ExecuteAsync(connection, query, cancellationToken);
108+
await ExecuteAsync(context, query, cancellationToken);
107109
return [];
108110
}
109111

@@ -115,12 +117,17 @@ public async Task<List<T>> BulkInsertWithIdentityAsync<T>(
115117
CancellationToken ctk = default
116118
) where T : class
117119
{
118-
var (connection, wasClosed) = await context.GetConnection(ctk);
120+
var (connection, wasClosed, transaction, wasBegan) = await context.GetConnection(ctk);
119121

120122
var (tableName, _) = await PerformBulkInsertAsync(context, entities, options, tempTableRequired: true, ctk: ctk);
121123

122124
var result = await CopyFromTempTableAsync<T>(context, connection, tableName, true, options, onConflict, cancellationToken: ctk);
123125

126+
if (!wasBegan)
127+
{
128+
await transaction.CommitAsync(ctk);
129+
}
130+
124131
if (wasClosed)
125132
{
126133
await connection.CloseAsync();
@@ -139,12 +146,17 @@ public async Task BulkInsertWithoutReturnAsync<T>(
139146
{
140147
if (onConflict != null)
141148
{
142-
var (connection, wasClosed) = await context.GetConnection(ctk);
149+
var (connection, wasClosed, transaction, wasBegan) = await context.GetConnection(ctk);
143150

144151
var (tableName, _) = await PerformBulkInsertAsync(context, entities, options, tempTableRequired: true, ctk: ctk);
145152

146153
await CopyFromTempTableAsync<T>(context, connection, tableName, false, options, onConflict, ctk);
147154

155+
if (!wasBegan)
156+
{
157+
await transaction.CommitAsync(ctk);
158+
}
159+
148160
if (wasClosed)
149161
{
150162
await connection.CloseAsync();
@@ -156,7 +168,8 @@ public async Task BulkInsertWithoutReturnAsync<T>(
156168
}
157169
}
158170

159-
private async Task<(string TableName, DbConnection Connection)> PerformBulkInsertAsync<T>(DbContext context,
171+
private async Task<(string TableName, DbConnection Connection)> PerformBulkInsertAsync<T>(
172+
DbContext context,
160173
IEnumerable<T> entities,
161174
BulkInsertOptions options,
162175
bool tempTableRequired,
@@ -167,7 +180,7 @@ public async Task BulkInsertWithoutReturnAsync<T>(
167180
throw new InvalidOperationException("No entities to insert.");
168181
}
169182

170-
var (connection, wasClosed) = await context.GetConnection(ctk);
183+
var (connection, wasClosed, transaction, wasBegan) = await context.GetConnection(ctk);
171184

172185
var tableName = tempTableRequired
173186
? await CreateTableCopyAsync<T>(context, connection, ctk)
@@ -180,6 +193,11 @@ public async Task BulkInsertWithoutReturnAsync<T>(
180193

181194
await BulkInsert(context, entities, tableName, properties, options, ctk);
182195

196+
if (!wasBegan)
197+
{
198+
await transaction.CommitAsync(ctk);
199+
}
200+
183201
if (wasClosed)
184202
{
185203
await connection.CloseAsync();
@@ -188,8 +206,17 @@ public async Task BulkInsertWithoutReturnAsync<T>(
188206
return (tableName, connection);
189207
}
190208

191-
protected abstract Task BulkInsert<T>(DbContext context, IEnumerable<T> entities,
192-
string tableName, PropertyAccessor[] properties, BulkInsertOptions options, CancellationToken ctk) where T : class;
209+
/// <summary>
210+
/// The main bulk insert method: will insert either in a temp table or directly in the target table.
211+
/// </summary>
212+
protected abstract Task BulkInsert<T>(
213+
DbContext context,
214+
IEnumerable<T> entities,
215+
string tableName,
216+
PropertyAccessor[] properties,
217+
BulkInsertOptions options,
218+
CancellationToken ctk
219+
) where T : class;
193220

194221
/// <summary>
195222
/// Get table information for the given entity type : schema name, table name and primary key.

src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbContextExtensions.cs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
using Microsoft.EntityFrameworkCore;
55
using Microsoft.EntityFrameworkCore.Metadata;
6+
using Microsoft.EntityFrameworkCore.Storage;
67

78
namespace PhenX.EntityFrameworkCore.BulkInsert.Extensions;
89

@@ -11,7 +12,7 @@ internal static class DbContextExtensions
1112
/// <summary>
1213
/// Gets cached properties for an entity type, using reflection if not already cached.
1314
/// </summary>
14-
public static IProperty[] GetProperties(this DbContext context, Type entityType, bool includeGenerated = true)
15+
internal static IProperty[] GetProperties(this DbContext context, Type entityType, bool includeGenerated = true)
1516
{
1617
var entityTypeInfo = context.Model.FindEntityType(entityType) ?? throw new InvalidOperationException($"Could not determine entity type for type {entityType.Name}");
1718

@@ -21,7 +22,7 @@ public static IProperty[] GetProperties(this DbContext context, Type entityType,
2122
.ToArray();
2223
}
2324

24-
public static async Task<(DbConnection connection, bool wasClosed)> GetConnection(this DbContext context, CancellationToken ctk = default)
25+
internal static async Task<(DbConnection connection, bool wasClosed, IDbContextTransaction transaction, bool wasBegan)> GetConnection(this DbContext context, CancellationToken ctk = default)
2526
{
2627
var connection = context.Database.GetDbConnection();
2728
var wasClosed = connection.State == ConnectionState.Closed;
@@ -31,6 +32,15 @@ public static IProperty[] GetProperties(this DbContext context, Type entityType,
3132
await connection.OpenAsync(ctk);
3233
}
3334

34-
return (connection, wasClosed);
35+
var wasBegan = true;
36+
var transaction = context.Database.CurrentTransaction;
37+
38+
if (transaction == null)
39+
{
40+
wasBegan = false;
41+
transaction = await context.Database.BeginTransactionAsync(ctk);
42+
}
43+
44+
return (connection, wasClosed, transaction, wasBegan);
3545
}
3646
}

tests/PhenX.EntityFrameworkCore.BulkInsert.Benchmark/LibComparator.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public abstract class LibComparator
1414
public int N;
1515

1616
private IList<TestEntity> data = [];
17-
protected TestDbContext DbContext;
17+
protected TestDbContext DbContext { get; set; } = null!;
1818

1919
[IterationSetup]
2020
public void IterationSetup()

tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Basic/BasicTestsBase.cs

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using System;
1+
using System;
22
using System.Collections.Generic;
33
using System.Linq;
44
using System.Threading.Tasks;
@@ -257,6 +257,51 @@ public async Task InsertAndRead_EntityWithValueConverters()
257257
Assert.Contains(inserted, e => e.Name == "Entity2" && e.CreatedAt == now.AddDays(-1));
258258
}
259259

260+
[Fact]
261+
public async Task BulkInsert_WithOpenTransaction_CommitsSuccessfully()
262+
{
263+
// Arrange
264+
var entities = new List<TestEntity>
265+
{
266+
new TestEntity { Name = "EntityWithTx1" },
267+
new TestEntity { Name = "EntityWithTx2" }
268+
};
269+
270+
await using var transaction = await DbContainer.DbContext.Database.BeginTransactionAsync();
271+
272+
await DbContainer.DbContext.ExecuteInsertAsync(entities);
273+
274+
await transaction.CommitAsync();
275+
276+
// Assert
277+
var insertedEntities = DbContainer.DbContext.TestEntities.ToList();
278+
Assert.Contains(insertedEntities, e => e.Name == "EntityWithTx1");
279+
Assert.Contains(insertedEntities, e => e.Name == "EntityWithTx2");
280+
}
281+
282+
[Fact]
283+
public async Task BulkInsert_WithOpenTransaction_RollsBackOnFailure()
284+
{
285+
// Arrange
286+
var entities = new List<TestEntity>
287+
{
288+
new TestEntity { Name = "EntityWithTxFail1" },
289+
new TestEntity { Name = "EntityWithTxFail2" }
290+
};
291+
292+
await using var transaction = await DbContainer.DbContext.Database.BeginTransactionAsync();
293+
294+
await DbContainer.DbContext.ExecuteInsertAsync(entities);
295+
296+
await transaction.RollbackAsync();
297+
298+
// Assert
299+
DbContainer.DbContext.ChangeTracker.Clear();
300+
var insertedEntities = DbContainer.DbContext.TestEntities.ToList();
301+
Assert.DoesNotContain(insertedEntities, e => e.Name == "EntityWithTxFail1");
302+
Assert.DoesNotContain(insertedEntities, e => e.Name == "EntityWithTxFail2");
303+
}
304+
260305
public Task InitializeAsync() => DbContainer.InitializeAsync();
261306

262307
public Task DisposeAsync() => DbContainer.DisposeAsync();

0 commit comments

Comments
 (0)