Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 126 additions & 145 deletions src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using System.Data.Common;
using System.Runtime.CompilerServices;

using Microsoft.EntityFrameworkCore;
Expand All @@ -13,10 +12,7 @@

namespace PhenX.EntityFrameworkCore.BulkInsert;

#pragma warning disable CS9113 // Parameter is unread.
internal abstract class BulkInsertProviderBase<TDialect>(ILogger<BulkInsertProviderBase<TDialect>>? logger = null) : IBulkInsertProvider
#pragma warning restore CS9113 // Parameter is unread.
where TDialect : SqlDialectBuilder, new()
internal abstract class BulkInsertProviderBase<TDialect>(ILogger<BulkInsertProviderBase<TDialect>>? logger = null) : IBulkInsertProvider where TDialect : SqlDialectBuilder, new()
{
protected readonly TDialect SqlDialect = new();

Expand All @@ -28,116 +24,6 @@ internal abstract class BulkInsertProviderBase<TDialect>(ILogger<BulkInsertProvi

SqlDialectBuilder IBulkInsertProvider.SqlDialect => SqlDialect;

protected async Task<string> CreateTableCopyAsync<T>(
bool sync,
DbContext context,
BulkInsertOptions options,
TableMetadata tableInfo,
CancellationToken ctk) where T : class
{
var tempTableName = SqlDialect.QuoteTableName(null, GetTempTableName(tableInfo.TableName));
var tempColumns = tableInfo.GetProperties(options.CopyGeneratedColumns);

var query = SqlDialect.CreateTableCopySql(tempTableName, tableInfo, tempColumns);

await ExecuteAsync(sync, context, query, ctk);
await AddBulkInsertIdColumn<T>(sync, context, tempTableName, ctk);

return tempTableName;
}

protected virtual async Task AddBulkInsertIdColumn<T>(
bool sync,
DbContext context,
string tempTableName,
CancellationToken ctk) where T : class
{
var alterQuery = string.Format(AddTableCopyBulkInsertId, tempTableName);

await ExecuteAsync(sync, context, alterQuery, ctk);
}

protected static async Task ExecuteAsync(
bool sync,
DbContext context,
string query,
CancellationToken ctk)
{
var command = context.Database.GetDbConnection().CreateCommand();
command.Transaction = context.Database.CurrentTransaction!.GetDbTransaction();
command.CommandText = query;

if (sync)
{
// ReSharper disable once MethodHasAsyncOverloadWithCancellation
command.ExecuteNonQuery();
}
else
{
await command.ExecuteNonQueryAsync(ctk);
}
}

public async Task<IAsyncEnumerable<T>?> CopyFromTempTableAsync<T>(
bool sync,
DbContext context,
TableMetadata tableInfo,
string tempTableName,
bool returnData,
BulkInsertOptions options,
OnConflictOptions? onConflict,
CancellationToken ctk) where T : class
{
return await CopyFromTempTableWithoutKeysAsync<T, T>(
sync,
context,
tableInfo,
tempTableName,
returnData,
options,
onConflict,
ctk);
}

private async Task<IAsyncEnumerable<TResult>?> CopyFromTempTableWithoutKeysAsync<T, TResult>(
bool sync,
DbContext context,
TableMetadata tableInfo,
string tempTableName,
bool returnData,
BulkInsertOptions options,
OnConflictOptions? onConflict,
CancellationToken ctk) where T : class where TResult : class
{
var query =
SqlDialect.BuildMoveDataSql<T>(
tableInfo,
tempTableName,
tableInfo.GetProperties(options.CopyGeneratedColumns),
returnData ? tableInfo.GetProperties() : [],
options,
onConflict);

if (returnData)
{
return Query(context, query);
}

// If not returning data, just execute the command
await ExecuteAsync(sync, context, query, ctk);
return null;

static IAsyncEnumerable<TResult> Query(DbContext context, string query)
{
// Use EF to execute the query and return the results
IQueryable<TResult> queryable = context
.Set<TResult>()
.FromSqlRaw(query);

return queryable.AsAsyncEnumerable();
}
}

public virtual async IAsyncEnumerable<T> BulkInsertReturnEntities<T>(
bool sync,
DbContext context,
Expand All @@ -147,15 +33,25 @@ public virtual async IAsyncEnumerable<T> BulkInsertReturnEntities<T>(
OnConflictOptions? onConflict,
[EnumeratorCancellation] CancellationToken ctk) where T : class
{
using var activity = Telemetry.ActivitySource.StartActivity("BulkInsertReturnEntities");
activity?.AddTag("tableName", tableInfo.TableName);
activity?.AddTag("synchronous", sync);

var connection = await context.GetConnection(sync, ctk);
try
{
var (tableName, _) = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk);
if (logger != null)
{
Log.UsingTempTablToReturnData(logger);
}

var result = await CopyFromTempTableAsync<T>(sync, context, tableInfo, tableName, true, options, onConflict, ctk)
?? throw new InvalidOperationException("Failed to get async enumerable.");
var tableName = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk);

await foreach (var item in result)
var result =
await CopyFromTempTableAsync<T, T>(sync, context, tableInfo, tableName, true, options, onConflict, ctk: ctk)
?? throw new InvalidOperationException("Copy returns null enumerable.");

await foreach (var item in result.WithCancellation(ctk))
{
yield return item;
}
Expand All @@ -178,30 +74,44 @@ public virtual async Task BulkInsert<T>(
OnConflictOptions? onConflict,
CancellationToken ctk) where T : class
{
if (onConflict != null)
using var activity = Telemetry.ActivitySource.StartActivity("BulkInsert");
activity?.AddTag("tableName", tableInfo.TableName);
activity?.AddTag("synchronous", sync);

var connection = await context.GetConnection(sync, ctk);
try
{
var connection = await context.GetConnection(sync, ctk);
try
if (onConflict != null)
{
var (tableName, _) = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk);
if (logger != null)
{
Log.UsingTempTableToResolveConflicts(logger);
}

await CopyFromTempTableAsync<T>(sync, context, tableInfo, tableName, false, options, onConflict, ctk);
var tableName = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk);

// Commit the transaction if we own them.
await connection.Commit(sync, ctk);
await CopyFromTempTableAsync<T, T>(sync, context, tableInfo, tableName, false, options, onConflict, ctk);
}
finally
else
{
await connection.Close(sync, ctk);
if (logger != null)
{
Log.UsingDirectInsert(logger);
}

await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: false, ctk: ctk);
}

// Commit the transaction if we own them.
await connection.Commit(sync, ctk);
}
else
finally
{
await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: false, ctk: ctk);
await connection.Close(sync, ctk);
}
}

private async Task<(string TableName, DbConnection Connection)> PerformBulkInsertAsync<T>(
private async Task<string> PerformBulkInsertAsync<T>(
bool sync,
DbContext context,
TableMetadata tableInfo,
Expand All @@ -215,27 +125,18 @@ public virtual async Task BulkInsert<T>(
throw new InvalidOperationException("No entities to insert.");
}

var connection = await context.GetConnection(sync, ctk);

var tableName = tempTableRequired
? await CreateTableCopyAsync<T>(sync, context, options, tableInfo, ctk)
: tableInfo.QuotedTableName;

var properties = tableInfo.GetProperties(options.CopyGeneratedColumns);

try
{
await BulkInsert(false, context, tableInfo, entities, tableName, properties, options, ctk);
using var activity = Telemetry.ActivitySource.StartActivity("Insert");
activity?.AddTag("tempTable", tempTableRequired);
activity?.AddTag("synchronous", sync);

// Commit the transaction if we own them.
await connection.Commit(sync, ctk);
}
finally
{
await connection.Close(sync, ctk);
}

return (tableName, connection.Connection);
await BulkInsert(false, context, tableInfo, entities, tableName, properties, options, ctk);
return tableName;
}

/// <summary>
Expand All @@ -250,4 +151,84 @@ protected abstract Task BulkInsert<T>(
IReadOnlyList<PropertyMetadata> properties,
BulkInsertOptions options,
CancellationToken ctk) where T : class;

protected async Task<string> CreateTableCopyAsync<T>(
bool sync,
DbContext context,
BulkInsertOptions options,
TableMetadata tableInfo,
CancellationToken ctk) where T : class
{
var tempTableName = SqlDialect.QuoteTableName(null, GetTempTableName(tableInfo.TableName));
var tempColumns = tableInfo.GetProperties(options.CopyGeneratedColumns);

var query = SqlDialect.CreateTableCopySql(tempTableName, tableInfo, tempColumns);

await ExecuteAsync(sync, context, query, ctk);
await AddBulkInsertIdColumn<T>(sync, context, tempTableName, ctk);

return tempTableName;
}

protected virtual async Task AddBulkInsertIdColumn<T>(
bool sync,
DbContext context,
string tempTableName,
CancellationToken ctk) where T : class
{
var alterQuery = string.Format(AddTableCopyBulkInsertId, tempTableName);

await ExecuteAsync(sync, context, alterQuery, ctk);
}

private async Task<IAsyncEnumerable<TResult>?> CopyFromTempTableAsync<T, TResult>(
bool sync,
DbContext context,
TableMetadata tableInfo,
string tempTableName,
bool returnData,
BulkInsertOptions options,
OnConflictOptions? onConflict,
CancellationToken ctk) where T : class where TResult : class
{
var query =
SqlDialect.BuildMoveDataSql<T>(
tableInfo,
tempTableName,
tableInfo.GetProperties(options.CopyGeneratedColumns),
returnData ? tableInfo.GetProperties() : [],
options,
onConflict);

if (returnData)
{
// Use EF to execute the query and return the results
return context.Set<TResult>().FromSqlRaw(query).AsAsyncEnumerable();
}

// If not returning data, just execute the command
await ExecuteAsync(sync, context, query, ctk);
return null;
}

protected static async Task ExecuteAsync(
bool sync,
DbContext context,
string query,
CancellationToken ctk)
{
var command = context.Database.GetDbConnection().CreateCommand();
command.Transaction = context.Database.CurrentTransaction!.GetDbTransaction();
command.CommandText = query;

if (sync)
{
// ReSharper disable once MethodHasAsyncOverloadWithCancellation
command.ExecuteNonQuery();
}
else
{
await command.ExecuteNonQueryAsync(ctk);
}
}
}
Loading