Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
using PhenX.EntityFrameworkCore.BulkInsert.Options;

namespace PhenX.EntityFrameworkCore.BulkInsert.MySql;

/// <summary>
/// Options specific to MySQL bulk insert.
/// </summary>
public class MySqlBulkInsertOptions : BulkInsertOptions
{

}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

namespace PhenX.EntityFrameworkCore.BulkInsert.MySql;

internal class MySqlBulkInsertProvider : BulkInsertProviderBase<MySqlServerDialectBuilder>
internal class MySqlBulkInsertProvider : BulkInsertProviderBase<MySqlServerDialectBuilder, MySqlBulkInsertOptions>
{
public MySqlBulkInsertProvider(ILogger<MySqlBulkInsertProvider>? logger = null) : base(logger)
{
Expand All @@ -25,6 +25,9 @@ public MySqlBulkInsertProvider(ILogger<MySqlBulkInsertProvider>? logger = null)
/// <inheritdoc />
protected override string GetTempTableName(string tableName) => $"#_temp_bulk_insert_{tableName}";

/// <inheritdoc />
protected override MySqlBulkInsertOptions GetDefaultOptions() => new();
Comment thread
SebastianStehle marked this conversation as resolved.
Outdated

/// <inheritdoc />
public override Task<List<T>> BulkInsertReturnEntities<T>(
bool sync,
Expand All @@ -44,21 +47,25 @@ protected override async Task BulkInsert<T>(
IEnumerable<T> entities,
string tableName,
PropertyAccessor[] properties,
BulkInsertOptions options,
MySqlBulkInsertOptions options,
CancellationToken ctk
)
{
var connection = (MySqlConnection)context.Database.GetDbConnection();
var sqlTransaction = context.Database.CurrentTransaction!.GetDbTransaction()

var sqlTransaction = context.Database.CurrentTransaction?.GetDbTransaction()
?? throw new InvalidOperationException("No open transaction found.");

if (sqlTransaction is not MySqlTransaction mySqlTransaction)
{
throw new InvalidOperationException($"Invalid transaction foud, got {sqlTransaction.GetType()}.");
}

var bulkCopy = new MySqlBulkCopy(connection, mySqlTransaction);
bulkCopy.DestinationTableName = tableName;
bulkCopy.BulkCopyTimeout = options.GetCopyTimeoutInSeconds();
var bulkCopy = new MySqlBulkCopy(connection, mySqlTransaction)
{
DestinationTableName = tableName,
BulkCopyTimeout = options.GetCopyTimeoutInSeconds(),
};

var sourceOrdinal = 0;
foreach (var prop in properties)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
namespace PhenX.EntityFrameworkCore.BulkInsert.PostgreSql;

[UsedImplicitly]
internal class PostgreSqlBulkInsertProvider : BulkInsertProviderBase<PostgreSqlDialectBuilder>
internal class PostgreSqlBulkInsertProvider : BulkInsertProviderBase<PostgreSqlDialectBuilder, BulkInsertOptions>
{
public PostgreSqlBulkInsertProvider(ILogger<PostgreSqlBulkInsertProvider>? logger = null) : base(logger)
{
Expand All @@ -31,6 +31,12 @@ private string GetBinaryImportCommand(DbContext context, Type entityType, string
return $"COPY {tableName} ({string.Join(", ", columns)}) FROM STDIN (FORMAT BINARY)";
}

/// <inheritdoc />
protected override BulkInsertOptions GetDefaultOptions() => new()
{
BatchSize = 50_000,
};

/// <inheritdoc />
protected override async Task BulkInsert<T>(
bool sync,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
using Microsoft.Data.SqlClient;

using PhenX.EntityFrameworkCore.BulkInsert.Options;

namespace PhenX.EntityFrameworkCore.BulkInsert.SqlServer;

/// <summary>
/// Options specific to SQL Server bulk insert.
/// </summary>
public class SqlServerBulkInsertOptions : BulkInsertOptions
{
/// <inheritdoc cref="SqlBulkCopyOptions"/>
public SqlBulkCopyOptions CopyOptions { get; set; } = SqlBulkCopyOptions.Default;

/// <inheritdoc cref="SqlBulkCopy.EnableStreaming"/>
public bool EnableStreaming { get; set; } = false;

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
namespace PhenX.EntityFrameworkCore.BulkInsert.SqlServer;

[UsedImplicitly]
internal class SqlServerBulkInsertProvider : BulkInsertProviderBase<SqlServerDialectBuilder>
internal class SqlServerBulkInsertProvider : BulkInsertProviderBase<SqlServerDialectBuilder, SqlServerBulkInsertOptions>
{
public SqlServerBulkInsertProvider(ILogger<SqlServerBulkInsertProvider>? logger = null) : base(logger)
{
Expand All @@ -27,24 +27,31 @@ public SqlServerBulkInsertProvider(ILogger<SqlServerBulkInsertProvider>? logger
/// <inheritdoc />
protected override string GetTempTableName(string tableName) => $"#_temp_bulk_insert_{tableName}";

protected override SqlServerBulkInsertOptions GetDefaultOptions() => new()
{
BatchSize = 50_000,
};

/// <inheritdoc />
protected override async Task BulkInsert<T>(
bool sync,
DbContext context,
IEnumerable<T> entities,
string tableName,
PropertyAccessor[] properties,
BulkInsertOptions options,
SqlServerBulkInsertOptions options,
CancellationToken ctk
)
{
var connection = (SqlConnection) context.Database.GetDbConnection();
var sqlTransaction = context.Database.CurrentTransaction!.GetDbTransaction() as SqlTransaction;

using var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.TableLock, sqlTransaction);
using var bulkCopy = new SqlBulkCopy(connection, options.CopyOptions, sqlTransaction);

bulkCopy.DestinationTableName = tableName;
bulkCopy.BatchSize = options.BatchSize ?? 50_000;
bulkCopy.BatchSize = options.BatchSize;
bulkCopy.BulkCopyTimeout = options.GetCopyTimeoutInSeconds();
bulkCopy.EnableStreaming = options.EnableStreaming;

foreach (var prop in properties)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
namespace PhenX.EntityFrameworkCore.BulkInsert.Sqlite;

[UsedImplicitly]
internal class SqliteBulkInsertProvider : BulkInsertProviderBase<SqliteDialectBuilder>
internal class SqliteBulkInsertProvider : BulkInsertProviderBase<SqliteDialectBuilder, BulkInsertOptions>
{
public SqliteBulkInsertProvider(ILogger<SqliteBulkInsertProvider>? logger = null) : base(logger)
{
Expand All @@ -29,6 +29,12 @@ public SqliteBulkInsertProvider(ILogger<SqliteBulkInsertProvider>? logger = null
/// <inheritdoc />
protected override string AddTableCopyBulkInsertId => "--"; // No need to add an ID column in SQLite

/// <inheritdoc />
protected override BulkInsertOptions GetDefaultOptions() => new()
{
BatchSize = 5,
};

/// <inheritdoc />
protected override Task AddBulkInsertIdColumn<T>(
bool sync,
Expand Down Expand Up @@ -134,7 +140,7 @@ CancellationToken ctk
) where T : class
{
const int maxParams = 1000;
var batchSize = options.BatchSize ?? 5;
var batchSize = options.BatchSize;
batchSize = Math.Min(batchSize, maxParams / properties.Length);

await using var insertCommand = GetInsertCommand(context, typeof(T), tableName, options, batchSize);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ internal Task<List<T>> BulkInsertReturnEntities<T>(
BulkInsertOptions options,
OnConflictOptions? onConflict = null,
CancellationToken ctk = default
) where T : class;
)
where T : class;

/// <summary>
/// Calls the provider to perform a bulk insert operation without returning the inserted entities.
Expand All @@ -31,5 +32,11 @@ internal Task BulkInsert<T>(
BulkInsertOptions options,
OnConflictOptions? onConflict = null,
CancellationToken ctk = default
) where T : class;
)
where T : class;

/// <summary>
/// Make the default options for the provider, can be a subclass of <see cref="BulkInsertOptions"/>.
/// </summary>
internal BulkInsertOptions InternalGetDefaultOptions();
}
35 changes: 25 additions & 10 deletions src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,27 @@

namespace PhenX.EntityFrameworkCore.BulkInsert;

internal abstract class BulkInsertProviderBase<TDialect> : IBulkInsertProvider
internal abstract class BulkInsertProviderBase<TDialect, TOptions> : IBulkInsertProvider
where TDialect : SqlDialectBuilder, new()
where TOptions : BulkInsertOptions, new()
{
protected readonly TDialect SqlDialect = new();
private readonly ILogger<BulkInsertProviderBase<TDialect>>? Logger;
private readonly ILogger<BulkInsertProviderBase<TDialect, TOptions>>? Logger;

protected virtual string BulkInsertId => "_bulk_insert_id";

protected abstract string CreateTableCopySql { get; }
protected abstract string AddTableCopyBulkInsertId { get; }

protected BulkInsertProviderBase(ILogger<BulkInsertProviderBase<TDialect>>? logger = null)
protected BulkInsertProviderBase(ILogger<BulkInsertProviderBase<TDialect, TOptions>>? logger = null)
{
Logger = logger;
}

protected async Task<string> CreateTableCopyAsync<T>(
bool sync,
DbContext context,
BulkInsertOptions options,
TOptions options,
CancellationToken cancellationToken = default) where T : class
{
var tableInfo = GetTableInfo(context, typeof(T));
Expand Down Expand Up @@ -148,11 +149,16 @@ public virtual async Task<List<T>> BulkInsertReturnEntities<T>(
CancellationToken ctk = default
) where T : class
{
if (options is not TOptions providerOptions)
Comment thread
SebastianStehle marked this conversation as resolved.
{
throw new InvalidOperationException($"Invalid options type: {options.GetType().Name}. Expected: {typeof(TOptions).Name}");
}

var (connection, wasClosed, transaction, wasBegan) = await context.GetConnection(sync, ctk);

var (tableName, _) = await PerformBulkInsertAsync(sync, context, entities, options, tempTableRequired: true, ctk: ctk);
var (tableName, _) = await PerformBulkInsertAsync(sync, context, entities, providerOptions, tempTableRequired: true, ctk: ctk);

var result = await CopyFromTempTableAsync<T>(sync, context, tableName, true, options, onConflict, cancellationToken: ctk);
var result = await CopyFromTempTableAsync<T>(sync, context, tableName, true, providerOptions, onConflict, cancellationToken: ctk);

await Finish(sync, connection, wasClosed, transaction, wasBegan, ctk);

Expand Down Expand Up @@ -200,27 +206,36 @@ public virtual async Task BulkInsert<T>(
CancellationToken ctk = default
) where T : class
{
if (options is not TOptions providerOptions)
{
throw new InvalidOperationException($"Invalid options type: {options.GetType().Name}. Expected: {typeof(TOptions).Name}");
}

if (onConflict != null)
{
var (connection, wasClosed, transaction, wasBegan) = await context.GetConnection(sync, ctk);

var (tableName, _) = await PerformBulkInsertAsync(sync, context, entities, options, tempTableRequired: true, ctk: ctk);
var (tableName, _) = await PerformBulkInsertAsync(sync, context, entities, providerOptions, tempTableRequired: true, ctk: ctk);

await CopyFromTempTableAsync<T>(sync, context, tableName, false, options, onConflict, ctk);

await Finish(sync, connection, wasClosed, transaction, wasBegan, ctk);
}
else
{
await PerformBulkInsertAsync(sync, context, entities, options, tempTableRequired: false, ctk: ctk);
await PerformBulkInsertAsync(sync, context, entities, providerOptions, tempTableRequired: false, ctk: ctk);
}
}

public BulkInsertOptions InternalGetDefaultOptions() => GetDefaultOptions();

protected abstract TOptions GetDefaultOptions();

private async Task<(string TableName, DbConnection Connection)> PerformBulkInsertAsync<T>(
bool sync,
DbContext context,
IEnumerable<T> entities,
BulkInsertOptions options,
TOptions options,
bool tempTableRequired,
CancellationToken ctk = default) where T : class
{
Expand Down Expand Up @@ -256,7 +271,7 @@ protected abstract Task BulkInsert<T>(
IEnumerable<T> entities,
string tableName,
PropertyAccessor[] properties,
BulkInsertOptions options,
TOptions options,
CancellationToken ctk
) where T : class;

Expand Down
27 changes: 27 additions & 0 deletions src/PhenX.EntityFrameworkCore.BulkInsert/Enums/ProviderType.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
namespace PhenX.EntityFrameworkCore.BulkInsert.Enums;

/// <summary>
/// Enumeration of supported database providers.
/// </summary>
public enum ProviderType
{
/// <summary>
/// SQL Server provider.
/// </summary>
SqlServer,

/// <summary>
/// PostgreSQL provider.
/// </summary>
PostgreSql,

/// <summary>
/// SQLite provider.
/// </summary>
Sqlite,

/// <summary>
/// MySQL provider.
/// </summary>
MySql,
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Storage;

using PhenX.EntityFrameworkCore.BulkInsert.Enums;

namespace PhenX.EntityFrameworkCore.BulkInsert.Extensions;

internal static class DbContextExtensions
Expand Down Expand Up @@ -61,4 +63,17 @@ internal static IProperty[] GetProperties(this DbContext context, Type entityTyp

return (connection, wasClosed, transaction, wasBegan);
}

/// <summary>
/// Tells if the current provider is the specified provider type.
/// </summary>
internal static bool IsProvider(this DbContext context, ProviderType providerType)
{
if (context.Database.ProviderName == null)
{
throw new InvalidOperationException("Database provider name is null.");
}

return context.Database.ProviderName.Contains(providerType.ToString(), StringComparison.OrdinalIgnoreCase);
Comment thread
SebastianStehle marked this conversation as resolved.
}
}
Loading