11using System . Runtime . CompilerServices ;
2+ using System . Runtime . InteropServices ;
23
34using Microsoft . EntityFrameworkCore ;
45using Microsoft . EntityFrameworkCore . Storage ;
@@ -43,14 +44,20 @@ protected override async IAsyncEnumerable<T> BulkInsertReturnEntities<T>(
4344 }
4445
4546 var tableName = await PerformBulkInsertAsync ( sync , context , tableInfo , entities , options , tempTableRequired : true , ctk : ctk ) ;
47+ try
48+ {
49+ var result =
50+ await CopyFromTempTableAsync < T , T > ( sync , context , tableInfo , tableName , true , options , onConflict , ctk : ctk )
51+ ?? throw new InvalidOperationException ( "Copy returns null enumerable." ) ;
4652
47- var result =
48- await CopyFromTempTableAsync < T , T > ( sync , context , tableInfo , tableName , true , options , onConflict , ctk : ctk )
49- ?? throw new InvalidOperationException ( "Copy returns null enumerable." ) ;
50-
51- await foreach ( var item in result . WithCancellation ( ctk ) )
53+ await foreach ( var item in result . WithCancellation ( ctk ) )
54+ {
55+ yield return item ;
56+ }
57+ }
58+ finally
5259 {
53- yield return item ;
60+ await PerformDropTempTableAsync ( sync , context , tableName ) ;
5461 }
5562
5663 // Commit the transaction if we own them.
@@ -71,6 +78,11 @@ protected override async Task BulkInsert<T>(
7178 OnConflictOptions < T > ? onConflict ,
7279 CancellationToken ctk ) where T : class
7380 {
81+ if ( entities . TryGetNonEnumeratedCount ( out var count ) && count == 0 )
82+ {
83+ throw new InvalidOperationException ( "No entities to insert." ) ;
84+ }
85+
7486 using var activity = Telemetry . ActivitySource . StartActivity ( "BulkInsert" ) ;
7587 activity ? . AddTag ( "tableName" , tableInfo . TableName ) ;
7688 activity ? . AddTag ( "synchronous" , sync ) ;
@@ -86,8 +98,14 @@ protected override async Task BulkInsert<T>(
8698 }
8799
88100 var tableName = await PerformBulkInsertAsync ( sync , context , tableInfo , entities , options , tempTableRequired : true , ctk : ctk ) ;
89-
90- await CopyFromTempTableAsync < T , T > ( sync , context , tableInfo , tableName , false , options , onConflict , ctk ) ;
101+ try
102+ {
103+ await CopyFromTempTableAsync < T , T > ( sync , context , tableInfo , tableName , false , options , onConflict , ctk ) ;
104+ }
105+ finally
106+ {
107+ await PerformDropTempTableAsync ( sync , context , tableName ) ;
108+ }
91109 }
92110 else
93111 {
@@ -107,7 +125,6 @@ protected override async Task BulkInsert<T>(
107125 await connection . Close ( sync , ctk ) ;
108126 }
109127 }
110-
111128 private async Task < string > PerformBulkInsertAsync < T > (
112129 bool sync ,
113130 DbContext context ,
@@ -117,11 +134,6 @@ private async Task<string> PerformBulkInsertAsync<T>(
117134 bool tempTableRequired ,
118135 CancellationToken ctk ) where T : class
119136 {
120- if ( entities . TryGetNonEnumeratedCount ( out var count ) && count == 0 )
121- {
122- throw new InvalidOperationException ( "No entities to insert." ) ;
123- }
124-
125137 var tableName = tempTableRequired
126138 ? await CreateTableCopyAsync < T > ( sync , context , options , tableInfo , ctk )
127139 : tableInfo . QuotedTableName ;
@@ -208,6 +220,33 @@ protected virtual async Task AddBulkInsertIdColumn<T>(
208220 return null ;
209221 }
210222
223+ private async Task PerformDropTempTableAsync ( bool sync , DbContext dbContext , string tableName )
224+ {
225+ try
226+ {
227+ await DropTempTableAsync ( sync , dbContext , tableName ) ;
228+ }
229+ catch ( Exception ex )
230+ {
231+ // The drop operation is not mandatory, therefore never fail the actual operation.
232+ if ( logger != null )
233+ {
234+ Log . DropTemporaryTableFailed ( logger , ex ) ;
235+ }
236+ }
237+ }
238+
239+ /// <summary>
240+ /// Drops the temporary table manually if needed.
241+ /// </summary>
242+ /// <param name="sync">Indicates if the operation is synchronous.</param>
243+ /// <param name="dbContext">The context.</param>
244+ /// <param name="tableName">The table name.</param>
245+ protected virtual Task DropTempTableAsync ( bool sync , DbContext dbContext , string tableName )
246+ {
247+ return Task . CompletedTask ;
248+ }
249+
211250 protected static async Task ExecuteAsync (
212251 bool sync ,
213252 DbContext context ,
0 commit comments