Skip to content

Commit f037a4c

Browse files
committed
Add support for provider-specifig options
1 parent b72d189 commit f037a4c

13 files changed

Lines changed: 171 additions & 53 deletions

File tree

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
using PhenX.EntityFrameworkCore.BulkInsert.Options;
2+
3+
namespace PhenX.EntityFrameworkCore.BulkInsert.MySql;
4+
5+
/// <summary>
6+
/// Options specific to MySQL bulk insert.
7+
/// </summary>
8+
public class MySqlBulkInsertOptions : BulkInsertOptions
9+
{
10+
11+
}

src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlBulkInsertProvider.cs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
namespace PhenX.EntityFrameworkCore.BulkInsert.MySql;
1010

11-
internal class MySqlBulkInsertProvider : BulkInsertProviderBase<MySqlServerDialectBuilder>
11+
internal class MySqlBulkInsertProvider : BulkInsertProviderBase<MySqlServerDialectBuilder, MySqlBulkInsertOptions>
1212
{
1313
public MySqlBulkInsertProvider(ILogger<MySqlBulkInsertProvider>? logger = null) : base(logger)
1414
{
@@ -26,7 +26,7 @@ public MySqlBulkInsertProvider(ILogger<MySqlBulkInsertProvider>? logger = null)
2626
protected override string GetTempTableName(string tableName) => $"#_temp_bulk_insert_{tableName}";
2727

2828
/// <inheritdoc />
29-
public override BulkInsertOptions GetDefaultOptions() => new();
29+
protected override MySqlBulkInsertOptions GetDefaultOptions() => new();
3030

3131
/// <inheritdoc />
3232
public override Task<List<T>> BulkInsertReturnEntities<T>(
@@ -47,21 +47,25 @@ protected override async Task BulkInsert<T>(
4747
IEnumerable<T> entities,
4848
string tableName,
4949
PropertyAccessor[] properties,
50-
BulkInsertOptions options,
50+
MySqlBulkInsertOptions options,
5151
CancellationToken ctk
5252
)
5353
{
5454
var connection = (MySqlConnection)context.Database.GetDbConnection();
55-
var sqlTransaction = context.Database.CurrentTransaction!.GetDbTransaction()
55+
56+
var sqlTransaction = context.Database.CurrentTransaction?.GetDbTransaction()
5657
?? throw new InvalidOperationException("No open transaction found.");
58+
5759
if (sqlTransaction is not MySqlTransaction mySqlTransaction)
5860
{
5961
throw new InvalidOperationException($"Invalid transaction foud, got {sqlTransaction.GetType()}.");
6062
}
6163

62-
var bulkCopy = new MySqlBulkCopy(connection, mySqlTransaction);
63-
bulkCopy.DestinationTableName = tableName;
64-
bulkCopy.BulkCopyTimeout = options.GetCopyTimeoutInSeconds();
64+
var bulkCopy = new MySqlBulkCopy(connection, mySqlTransaction)
65+
{
66+
DestinationTableName = tableName,
67+
BulkCopyTimeout = options.GetCopyTimeoutInSeconds(),
68+
};
6569

6670
var sourceOrdinal = 0;
6771
foreach (var prop in properties)

src/PhenX.EntityFrameworkCore.BulkInsert.PostgreSql/PostgreSqlBulkInsertProvider.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
namespace PhenX.EntityFrameworkCore.BulkInsert.PostgreSql;
1111

1212
[UsedImplicitly]
13-
internal class PostgreSqlBulkInsertProvider : BulkInsertProviderBase<PostgreSqlDialectBuilder>
13+
internal class PostgreSqlBulkInsertProvider : BulkInsertProviderBase<PostgreSqlDialectBuilder, BulkInsertOptions>
1414
{
1515
public PostgreSqlBulkInsertProvider(ILogger<PostgreSqlBulkInsertProvider>? logger = null) : base(logger)
1616
{
@@ -32,7 +32,7 @@ private string GetBinaryImportCommand(DbContext context, Type entityType, string
3232
}
3333

3434
/// <inheritdoc />
35-
public override BulkInsertOptions GetDefaultOptions() => new()
35+
protected override BulkInsertOptions GetDefaultOptions() => new()
3636
{
3737
BatchSize = 50_000,
3838
};

src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerBulkInsertOptions.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ namespace PhenX.EntityFrameworkCore.BulkInsert.SqlServer;
1010
public class SqlServerBulkInsertOptions : BulkInsertOptions
1111
{
1212
/// <inheritdoc cref="SqlBulkCopyOptions"/>
13-
public SqlBulkCopyOptions CopyOptions { get; init; } = SqlBulkCopyOptions.Default;
13+
public SqlBulkCopyOptions CopyOptions { get; set; } = SqlBulkCopyOptions.Default;
14+
15+
/// <inheritdoc cref="SqlBulkCopy.EnableStreaming"/>
16+
public bool EnableStreaming { get; set; } = false;
1417

1518
}

src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerBulkInsertProvider.cs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
namespace PhenX.EntityFrameworkCore.BulkInsert.SqlServer;
1111

1212
[UsedImplicitly]
13-
internal class SqlServerBulkInsertProvider : BulkInsertProviderBase<SqlServerDialectBuilder>
13+
internal class SqlServerBulkInsertProvider : BulkInsertProviderBase<SqlServerDialectBuilder, SqlServerBulkInsertOptions>
1414
{
1515
public SqlServerBulkInsertProvider(ILogger<SqlServerBulkInsertProvider>? logger = null) : base(logger)
1616
{
@@ -27,7 +27,7 @@ public SqlServerBulkInsertProvider(ILogger<SqlServerBulkInsertProvider>? logger
2727
/// <inheritdoc />
2828
protected override string GetTempTableName(string tableName) => $"#_temp_bulk_insert_{tableName}";
2929

30-
public override BulkInsertOptions GetDefaultOptions() => new SqlServerBulkInsertOptions
30+
protected override SqlServerBulkInsertOptions GetDefaultOptions() => new()
3131
{
3232
BatchSize = 50_000,
3333
};
@@ -39,17 +39,19 @@ protected override async Task BulkInsert<T>(
3939
IEnumerable<T> entities,
4040
string tableName,
4141
PropertyAccessor[] properties,
42-
BulkInsertOptions options,
42+
SqlServerBulkInsertOptions options,
4343
CancellationToken ctk
4444
)
4545
{
4646
var connection = (SqlConnection) context.Database.GetDbConnection();
4747
var sqlTransaction = context.Database.CurrentTransaction!.GetDbTransaction() as SqlTransaction;
4848

49-
using var bulkCopy = new SqlBulkCopy(connection, SqlBulkCopyOptions.TableLock, sqlTransaction);
49+
using var bulkCopy = new SqlBulkCopy(connection, options.CopyOptions, sqlTransaction);
50+
5051
bulkCopy.DestinationTableName = tableName;
5152
bulkCopy.BatchSize = options.BatchSize;
5253
bulkCopy.BulkCopyTimeout = options.GetCopyTimeoutInSeconds();
54+
bulkCopy.EnableStreaming = options.EnableStreaming;
5355

5456
foreach (var prop in properties)
5557
{

src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteBulkInsertProvider.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
namespace PhenX.EntityFrameworkCore.BulkInsert.Sqlite;
1313

1414
[UsedImplicitly]
15-
internal class SqliteBulkInsertProvider : BulkInsertProviderBase<SqliteDialectBuilder>
15+
internal class SqliteBulkInsertProvider : BulkInsertProviderBase<SqliteDialectBuilder, BulkInsertOptions>
1616
{
1717
public SqliteBulkInsertProvider(ILogger<SqliteBulkInsertProvider>? logger = null) : base(logger)
1818
{
@@ -30,7 +30,7 @@ public SqliteBulkInsertProvider(ILogger<SqliteBulkInsertProvider>? logger = null
3030
protected override string AddTableCopyBulkInsertId => "--"; // No need to add an ID column in SQLite
3131

3232
/// <inheritdoc />
33-
public override BulkInsertOptions GetDefaultOptions() => new()
33+
protected override BulkInsertOptions GetDefaultOptions() => new()
3434
{
3535
BatchSize = 5,
3636
};

src/PhenX.EntityFrameworkCore.BulkInsert/Abstractions/IBulkInsertProvider.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ internal Task<List<T>> BulkInsertReturnEntities<T>(
1919
BulkInsertOptions options,
2020
OnConflictOptions? onConflict = null,
2121
CancellationToken ctk = default
22-
) where T : class;
22+
)
23+
where T : class;
2324

2425
/// <summary>
2526
/// Calls the provider to perform a bulk insert operation without returning the inserted entities.
@@ -31,10 +32,11 @@ internal Task BulkInsert<T>(
3132
BulkInsertOptions options,
3233
OnConflictOptions? onConflict = null,
3334
CancellationToken ctk = default
34-
) where T : class;
35+
)
36+
where T : class;
3537

3638
/// <summary>
3739
/// Make the default options for the provider, can be a subclass of <see cref="BulkInsertOptions"/>.
3840
/// </summary>
39-
internal BulkInsertOptions GetDefaultOptions();
41+
internal BulkInsertOptions InternalGetDefaultOptions();
4042
}

src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,26 +12,27 @@
1212

1313
namespace PhenX.EntityFrameworkCore.BulkInsert;
1414

15-
internal abstract class BulkInsertProviderBase<TDialect> : IBulkInsertProvider
15+
internal abstract class BulkInsertProviderBase<TDialect, TOptions> : IBulkInsertProvider
1616
where TDialect : SqlDialectBuilder, new()
17+
where TOptions : BulkInsertOptions, new()
1718
{
1819
protected readonly TDialect SqlDialect = new();
19-
private readonly ILogger<BulkInsertProviderBase<TDialect>>? Logger;
20+
private readonly ILogger<BulkInsertProviderBase<TDialect, TOptions>>? Logger;
2021

2122
protected virtual string BulkInsertId => "_bulk_insert_id";
2223

2324
protected abstract string CreateTableCopySql { get; }
2425
protected abstract string AddTableCopyBulkInsertId { get; }
2526

26-
protected BulkInsertProviderBase(ILogger<BulkInsertProviderBase<TDialect>>? logger = null)
27+
protected BulkInsertProviderBase(ILogger<BulkInsertProviderBase<TDialect, TOptions>>? logger = null)
2728
{
2829
Logger = logger;
2930
}
3031

3132
protected async Task<string> CreateTableCopyAsync<T>(
3233
bool sync,
3334
DbContext context,
34-
BulkInsertOptions options,
35+
TOptions options,
3536
CancellationToken cancellationToken = default) where T : class
3637
{
3738
var tableInfo = GetTableInfo(context, typeof(T));
@@ -148,11 +149,16 @@ public virtual async Task<List<T>> BulkInsertReturnEntities<T>(
148149
CancellationToken ctk = default
149150
) where T : class
150151
{
152+
if (options is not TOptions providerOptions)
153+
{
154+
throw new InvalidOperationException($"Invalid options type: {options.GetType().Name}. Expected: {typeof(TOptions).Name}");
155+
}
156+
151157
var (connection, wasClosed, transaction, wasBegan) = await context.GetConnection(sync, ctk);
152158

153-
var (tableName, _) = await PerformBulkInsertAsync(sync, context, entities, options, tempTableRequired: true, ctk: ctk);
159+
var (tableName, _) = await PerformBulkInsertAsync(sync, context, entities, providerOptions, tempTableRequired: true, ctk: ctk);
154160

155-
var result = await CopyFromTempTableAsync<T>(sync, context, tableName, true, options, onConflict, cancellationToken: ctk);
161+
var result = await CopyFromTempTableAsync<T>(sync, context, tableName, true, providerOptions, onConflict, cancellationToken: ctk);
156162

157163
await Finish(sync, connection, wasClosed, transaction, wasBegan, ctk);
158164

@@ -200,29 +206,36 @@ public virtual async Task BulkInsert<T>(
200206
CancellationToken ctk = default
201207
) where T : class
202208
{
209+
if (options is not TOptions providerOptions)
210+
{
211+
throw new InvalidOperationException($"Invalid options type: {options.GetType().Name}. Expected: {typeof(TOptions).Name}");
212+
}
213+
203214
if (onConflict != null)
204215
{
205216
var (connection, wasClosed, transaction, wasBegan) = await context.GetConnection(sync, ctk);
206217

207-
var (tableName, _) = await PerformBulkInsertAsync(sync, context, entities, options, tempTableRequired: true, ctk: ctk);
218+
var (tableName, _) = await PerformBulkInsertAsync(sync, context, entities, providerOptions, tempTableRequired: true, ctk: ctk);
208219

209220
await CopyFromTempTableAsync<T>(sync, context, tableName, false, options, onConflict, ctk);
210221

211222
await Finish(sync, connection, wasClosed, transaction, wasBegan, ctk);
212223
}
213224
else
214225
{
215-
await PerformBulkInsertAsync(sync, context, entities, options, tempTableRequired: false, ctk: ctk);
226+
await PerformBulkInsertAsync(sync, context, entities, providerOptions, tempTableRequired: false, ctk: ctk);
216227
}
217228
}
218229

219-
public abstract BulkInsertOptions GetDefaultOptions();
230+
public BulkInsertOptions InternalGetDefaultOptions() => GetDefaultOptions();
231+
232+
protected abstract TOptions GetDefaultOptions();
220233

221234
private async Task<(string TableName, DbConnection Connection)> PerformBulkInsertAsync<T>(
222235
bool sync,
223236
DbContext context,
224237
IEnumerable<T> entities,
225-
BulkInsertOptions options,
238+
TOptions options,
226239
bool tempTableRequired,
227240
CancellationToken ctk = default) where T : class
228241
{
@@ -258,7 +271,7 @@ protected abstract Task BulkInsert<T>(
258271
IEnumerable<T> entities,
259272
string tableName,
260273
PropertyAccessor[] properties,
261-
BulkInsertOptions options,
274+
TOptions options,
262275
CancellationToken ctk
263276
) where T : class;
264277

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
namespace PhenX.EntityFrameworkCore.BulkInsert.Enums;
2+
3+
/// <summary>
4+
/// Enumeration of supported database providers.
5+
/// </summary>
6+
public enum ProviderType
7+
{
8+
/// <summary>
9+
/// SQL Server provider.
10+
/// </summary>
11+
SqlServer,
12+
13+
/// <summary>
14+
/// PostgreSQL provider.
15+
/// </summary>
16+
PostgreSql,
17+
18+
/// <summary>
19+
/// SQLite provider.
20+
/// </summary>
21+
Sqlite,
22+
23+
/// <summary>
24+
/// MySQL provider.
25+
/// </summary>
26+
MySql,
27+
}

src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbContextExtensions.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
using Microsoft.EntityFrameworkCore.Metadata;
66
using Microsoft.EntityFrameworkCore.Storage;
77

8+
using PhenX.EntityFrameworkCore.BulkInsert.Enums;
9+
810
namespace PhenX.EntityFrameworkCore.BulkInsert.Extensions;
911

1012
internal static class DbContextExtensions
@@ -61,4 +63,17 @@ internal static IProperty[] GetProperties(this DbContext context, Type entityTyp
6163

6264
return (connection, wasClosed, transaction, wasBegan);
6365
}
66+
67+
/// <summary>
68+
/// Tells if the current provider is the specified provider type.
69+
/// </summary>
70+
internal static bool IsProvider(this DbContext context, ProviderType providerType)
71+
{
72+
if (context.Database.ProviderName == null)
73+
{
74+
throw new InvalidOperationException("Database provider name is null.");
75+
}
76+
77+
return context.Database.ProviderName.Contains(providerType.ToString(), StringComparison.OrdinalIgnoreCase);
78+
}
6479
}

0 commit comments

Comments
 (0)