@@ -29,28 +29,18 @@ internal abstract class BulkInsertProviderBase<TDialect>(ILogger<BulkInsertProvi
2929 protected async Task < string > CreateTableCopyAsync < T > (
3030 bool sync ,
3131 DbContext context ,
32+ BulkInsertOptions options ,
3233 TableMetadata tableInfo ,
3334 CancellationToken cancellationToken = default ) where T : class
34- {
35- var tempTableName = await CreateTemporaryTableAsync ( sync , context , tableInfo , cancellationToken ) ;
36-
37- await AddBulkInsertIdColumn < T > ( sync , context , tempTableName , cancellationToken ) ;
38-
39- return tempTableName ;
40- }
41-
42- private async Task < string > CreateTemporaryTableAsync (
43- bool sync ,
44- DbContext context ,
45- TableMetadata tableInfo ,
46- CancellationToken cancellationToken )
4735 {
4836 var tempTableName = SqlDialect . QuoteTableName ( null , GetTempTableName ( tableInfo . TableName ) ) ;
49- var tempColumns = string . Join ( ", " , tableInfo . GetProperties ( false ) . Select ( x => x . QuotedColumName ) ) ;
37+ var tempColumns = string . Join ( ", " , tableInfo . GetProperties ( options . CopyGeneratedColumns ) . Select ( x => x . QuotedColumName ) ) ;
5038
5139 var query = string . Format ( CreateTableCopySql , tempTableName , tableInfo . QuotedTableName , tempColumns ) ;
5240
5341 await ExecuteAsync ( sync , context , query , cancellationToken ) ;
42+ await AddBulkInsertIdColumn < T > ( sync , context , tempTableName , cancellationToken ) ;
43+
5444 return tempTableName ;
5545 }
5646
@@ -156,17 +146,43 @@ public virtual async Task<List<T>> BulkInsertReturnEntities<T>(
156146 CancellationToken ctk = default
157147 ) where T : class
158148 {
159- var connectionInfo = await context . GetConnection ( sync , ctk ) ;
149+ List < T > result ;
160150
161- var ( tableName , _) = await PerformBulkInsertAsync ( sync , context , tableInfo , entities , options , tempTableRequired : true , ctk : ctk ) ;
151+ var connectionInfo = await context . GetConnection ( sync , ctk ) ;
152+ try
153+ {
154+ var ( tableName , _) = await PerformBulkInsertAsync ( sync , context , tableInfo , entities , options , tempTableRequired : true , ctk : ctk ) ;
162155
163- var result = await CopyFromTempTableAsync < T > ( sync , context , tableInfo , tableName , true , options , onConflict , cancellationToken : ctk ) ;
156+ result = await CopyFromTempTableAsync < T > ( sync , context , tableInfo , tableName , true , options , onConflict , cancellationToken : ctk ) ;
164157
165- await Finish ( sync , connectionInfo , ctk ) ;
158+ await Commit ( sync , connectionInfo , ctk ) ;
159+ }
160+ finally
161+ {
162+ await Finish ( sync , connectionInfo , ctk ) ;
163+ }
166164
167165 return result ;
168166 }
169167
168+ private static async Task Commit ( bool sync , ConnectionInfo connectionInfo , CancellationToken ctk )
169+ {
170+ var ( _, _, transaction , wasBegan ) = connectionInfo ;
171+
172+ if ( ! wasBegan )
173+ {
174+ if ( sync )
175+ {
176+ // ReSharper disable once MethodHasAsyncOverloadWithCancellation
177+ transaction . Commit ( ) ;
178+ }
179+ else
180+ {
181+ await transaction . CommitAsync ( ctk ) ;
182+ }
183+ }
184+ }
185+
170186 private static async Task Finish ( bool sync , ConnectionInfo connectionInfo , CancellationToken ctk )
171187 {
172188 var ( connection , wasClosed , transaction , wasBegan ) = connectionInfo ;
@@ -176,12 +192,10 @@ private static async Task Finish(bool sync, ConnectionInfo connectionInfo, Cance
176192 if ( sync )
177193 {
178194 // ReSharper disable once MethodHasAsyncOverloadWithCancellation
179- transaction . Commit ( ) ;
180195 transaction . Dispose ( ) ;
181196 }
182197 else
183198 {
184- await transaction . CommitAsync ( ctk ) ;
185199 await transaction . DisposeAsync ( ) ;
186200 }
187201 }
@@ -213,12 +227,17 @@ public virtual async Task BulkInsert<T>(
213227 if ( onConflict != null )
214228 {
215229 var connectionInfo = await context . GetConnection ( sync , ctk ) ;
230+ try
231+ {
232+ var ( tableName , _) = await PerformBulkInsertAsync ( sync , context , tableInfo , entities , options , tempTableRequired : true , ctk : ctk ) ;
216233
217- var ( tableName , _) = await PerformBulkInsertAsync ( sync , context , tableInfo , entities , options , tempTableRequired : true , ctk : ctk ) ;
218-
219- await CopyFromTempTableAsync < T > ( sync , context , tableInfo , tableName , false , options , onConflict , ctk ) ;
220-
221- await Finish ( sync , connectionInfo , ctk ) ;
234+ await CopyFromTempTableAsync < T > ( sync , context , tableInfo , tableName , false , options , onConflict , ctk ) ;
235+ await Commit ( sync , connectionInfo , ctk ) ;
236+ }
237+ finally
238+ {
239+ await Finish ( sync , connectionInfo , ctk ) ;
240+ }
222241 }
223242 else
224243 {
@@ -243,7 +262,7 @@ public virtual async Task BulkInsert<T>(
243262 var connectionInfo = await context . GetConnection ( sync , ctk ) ;
244263
245264 var tableName = tempTableRequired
246- ? await CreateTableCopyAsync < T > ( sync , context , tableInfo , ctk )
265+ ? await CreateTableCopyAsync < T > ( sync , context , options , tableInfo , ctk )
247266 : tableInfo . QuotedTableName ;
248267
249268 var properties = tableInfo . GetProperties ( options . CopyGeneratedColumns ) ;
0 commit comments