Skip to content

Commit 7485d81

Browse files
Metadata.
1 parent a317c64 commit 7485d81

6 files changed

Lines changed: 61 additions & 46 deletions

File tree

src/PhenX.EntityFrameworkCore.BulkInsert.PostgreSql/PostgreSqlDbContextOptionsExtensions.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using Microsoft.EntityFrameworkCore;
2-
using Microsoft.EntityFrameworkCore.Infrastructure;
32

43
using PhenX.EntityFrameworkCore.BulkInsert.Extensions;
54

src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerDbContextOptionsExtensions.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using Microsoft.EntityFrameworkCore;
2-
using Microsoft.EntityFrameworkCore.Infrastructure;
32

43
using PhenX.EntityFrameworkCore.BulkInsert.Extensions;
54

src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerDialectBuilder.cs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ internal class SqlServerDialectBuilder : SqlDialectBuilder
1515
protected override bool SupportsMoveRows => false;
1616

1717
public override string BuildMoveDataSql<T>(
18-
TableMetadata source,
19-
string target,
18+
TableMetadata target,
19+
string source,
2020
IReadOnlyList<PropertyMetadata> insertedProperties,
2121
IReadOnlyList<PropertyMetadata> properties,
2222
BulkInsertOptions options,
@@ -32,21 +32,21 @@ public override string BuildMoveDataSql<T>(
3232

3333
if (options.CopyGeneratedColumns)
3434
{
35-
q.AppendLine($"SET IDENTITY_INSERT {target} ON;");
35+
q.AppendLine($"SET IDENTITY_INSERT {target.QuotedTableName} ON;");
3636
}
3737

3838
// Merge handling
3939
if (onConflict is OnConflictOptions<T> onConflictTyped && onConflictTyped.Match != null)
4040
{
41-
var matchColumns = GetColumns(source, onConflictTyped.Match);
41+
var matchColumns = GetColumns(target, onConflictTyped.Match);
4242
var matchOn = string.Join(" AND ",
4343
matchColumns.Select(col => $"TARGET.{col} = SOURCE.{col}"));
4444

4545
var updateSet = onConflictTyped.Update != null
46-
? string.Join(", ", GetUpdates(source, insertedProperties, onConflictTyped.Update))
46+
? string.Join(", ", GetUpdates(target, insertedProperties, onConflictTyped.Update))
4747
: null;
4848

49-
q.AppendLine($"MERGE INTO {target} AS TARGET");
49+
q.AppendLine($"MERGE INTO {target.QuotedTableName} AS TARGET");
5050
q.AppendLine(
5151
$"USING (SELECT {string.Join(", ", insertedColumns)} FROM {source}) AS SOURCE ({insertedColumnList})");
5252
q.AppendLine($"ON {matchOn}");
@@ -68,7 +68,7 @@ public override string BuildMoveDataSql<T>(
6868
// No conflict handling
6969
else
7070
{
71-
q.AppendLine($"INSERT INTO {target} ({insertedColumnList})");
71+
q.AppendLine($"INSERT INTO {target.QuotedTableName} ({insertedColumnList})");
7272

7373
if (columnList.Length != 0)
7474
{
@@ -85,7 +85,7 @@ public override string BuildMoveDataSql<T>(
8585

8686
if (options.CopyGeneratedColumns)
8787
{
88-
q.AppendLine($"SET IDENTITY_INSERT {target} OFF;");
88+
q.AppendLine($"SET IDENTITY_INSERT {target.QuotedTableName} OFF;");
8989
}
9090

9191
return q.ToString();

src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteDbContextOptionsExtensions.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using Microsoft.EntityFrameworkCore;
2-
using Microsoft.EntityFrameworkCore.Infrastructure;
32

43
using PhenX.EntityFrameworkCore.BulkInsert.Extensions;
54

src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -29,28 +29,18 @@ internal abstract class BulkInsertProviderBase<TDialect>(ILogger<BulkInsertProvi
2929
protected async Task<string> CreateTableCopyAsync<T>(
3030
bool sync,
3131
DbContext context,
32+
BulkInsertOptions options,
3233
TableMetadata tableInfo,
3334
CancellationToken cancellationToken = default) where T : class
34-
{
35-
var tempTableName = await CreateTemporaryTableAsync(sync, context, tableInfo, cancellationToken);
36-
37-
await AddBulkInsertIdColumn<T>(sync, context, tempTableName, cancellationToken);
38-
39-
return tempTableName;
40-
}
41-
42-
private async Task<string> CreateTemporaryTableAsync(
43-
bool sync,
44-
DbContext context,
45-
TableMetadata tableInfo,
46-
CancellationToken cancellationToken)
4735
{
4836
var tempTableName = SqlDialect.QuoteTableName(null, GetTempTableName(tableInfo.TableName));
49-
var tempColumns = string.Join(", ", tableInfo.GetProperties(false).Select(x => x.QuotedColumName));
37+
var tempColumns = string.Join(", ", tableInfo.GetProperties(options.CopyGeneratedColumns).Select(x => x.QuotedColumName));
5038

5139
var query = string.Format(CreateTableCopySql, tempTableName, tableInfo.QuotedTableName, tempColumns);
5240

5341
await ExecuteAsync(sync, context, query, cancellationToken);
42+
await AddBulkInsertIdColumn<T>(sync, context, tempTableName, cancellationToken);
43+
5444
return tempTableName;
5545
}
5646

@@ -156,17 +146,43 @@ public virtual async Task<List<T>> BulkInsertReturnEntities<T>(
156146
CancellationToken ctk = default
157147
) where T : class
158148
{
159-
var connectionInfo = await context.GetConnection(sync, ctk);
149+
List<T> result;
160150

161-
var (tableName, _) = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk);
151+
var connectionInfo = await context.GetConnection(sync, ctk);
152+
try
153+
{
154+
var (tableName, _) = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk);
162155

163-
var result = await CopyFromTempTableAsync<T>(sync, context, tableInfo, tableName, true, options, onConflict, cancellationToken: ctk);
156+
result = await CopyFromTempTableAsync<T>(sync, context, tableInfo, tableName, true, options, onConflict, cancellationToken: ctk);
164157

165-
await Finish(sync, connectionInfo, ctk);
158+
await Commit(sync, connectionInfo, ctk);
159+
}
160+
finally
161+
{
162+
await Finish(sync, connectionInfo, ctk);
163+
}
166164

167165
return result;
168166
}
169167

168+
private static async Task Commit(bool sync, ConnectionInfo connectionInfo, CancellationToken ctk)
169+
{
170+
var (_, _, transaction, wasBegan) = connectionInfo;
171+
172+
if (!wasBegan)
173+
{
174+
if (sync)
175+
{
176+
// ReSharper disable once MethodHasAsyncOverloadWithCancellation
177+
transaction.Commit();
178+
}
179+
else
180+
{
181+
await transaction.CommitAsync(ctk);
182+
}
183+
}
184+
}
185+
170186
private static async Task Finish(bool sync, ConnectionInfo connectionInfo, CancellationToken ctk)
171187
{
172188
var (connection, wasClosed, transaction, wasBegan) = connectionInfo;
@@ -176,12 +192,10 @@ private static async Task Finish(bool sync, ConnectionInfo connectionInfo, Cance
176192
if (sync)
177193
{
178194
// ReSharper disable once MethodHasAsyncOverloadWithCancellation
179-
transaction.Commit();
180195
transaction.Dispose();
181196
}
182197
else
183198
{
184-
await transaction.CommitAsync(ctk);
185199
await transaction.DisposeAsync();
186200
}
187201
}
@@ -213,12 +227,17 @@ public virtual async Task BulkInsert<T>(
213227
if (onConflict != null)
214228
{
215229
var connectionInfo = await context.GetConnection(sync, ctk);
230+
try
231+
{
232+
var (tableName, _) = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk);
216233

217-
var (tableName, _) = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk);
218-
219-
await CopyFromTempTableAsync<T>(sync, context, tableInfo, tableName, false, options, onConflict, ctk);
220-
221-
await Finish(sync, connectionInfo, ctk);
234+
await CopyFromTempTableAsync<T>(sync, context, tableInfo, tableName, false, options, onConflict, ctk);
235+
await Commit(sync, connectionInfo, ctk);
236+
}
237+
finally
238+
{
239+
await Finish(sync, connectionInfo, ctk);
240+
}
222241
}
223242
else
224243
{
@@ -243,7 +262,7 @@ public virtual async Task BulkInsert<T>(
243262
var connectionInfo = await context.GetConnection(sync, ctk);
244263

245264
var tableName = tempTableRequired
246-
? await CreateTableCopyAsync<T>(sync, context, tableInfo, ctk)
265+
? await CreateTableCopyAsync<T>(sync, context, options, tableInfo, ctk)
247266
: tableInfo.QuotedTableName;
248267

249268
var properties = tableInfo.GetProperties(options.CopyGeneratedColumns);

src/PhenX.EntityFrameworkCore.BulkInsert/Dialect/SqlDialectBuilder.cs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ internal abstract class SqlDialectBuilder
2626
/// <typeparam name="T">Entity type</typeparam>
2727
/// <returns>The SQL query</returns>
2828
public virtual string BuildMoveDataSql<T>(
29-
TableMetadata source,
30-
string target,
29+
TableMetadata target,
30+
string source,
3131
IReadOnlyList<PropertyMetadata> insertedProperties,
3232
IReadOnlyList<PropertyMetadata> properties,
3333
BulkInsertOptions options, OnConflictOptions? onConflict = null)
@@ -40,22 +40,21 @@ public virtual string BuildMoveDataSql<T>(
4040

4141
var q = new StringBuilder();
4242

43-
var sourceName = source.QuotedTableName;
4443
if (SupportsMoveRows && options.MoveRows)
4544
{
4645
q.AppendLine($"""
4746
WITH moved_rows AS (
48-
DELETE FROM {source.QuotedTableName}
47+
DELETE FROM {source}
4948
RETURNING {insertedColumnList}
5049
)
5150
""");
52-
sourceName = "moved_rows";
51+
source = "moved_rows";
5352
}
5453

5554
q.AppendLine($"""
56-
INSERT INTO {target} ({insertedColumnList})
55+
INSERT INTO {target.QuotedTableName} ({insertedColumnList})
5756
SELECT {insertedColumnList}
58-
FROM {sourceName}
57+
FROM {source}
5958
WHERE TRUE
6059
""");
6160

@@ -68,13 +67,13 @@ WHERE TRUE
6867
if (onConflictTyped.Match != null)
6968
{
7069
q.Append(' ');
71-
AppendConflictMatch(q, GetColumns(source, onConflictTyped.Match));
70+
AppendConflictMatch(q, GetColumns(target, onConflictTyped.Match));
7271
}
7372

7473
if (onConflictTyped.Update != null)
7574
{
7675
q.Append(' ');
77-
AppendOnConflictUpdate(q, GetUpdates(source, insertedProperties, onConflictTyped.Update));
76+
AppendOnConflictUpdate(q, GetUpdates(target, insertedProperties, onConflictTyped.Update));
7877
}
7978

8079
if (onConflictTyped.Condition != null)

0 commit comments

Comments
 (0)