Skip to content

Commit d8ac159

Browse files
Merge branch 'master' of github.com:SebastianStehle/PhenX.EntityFrameworkCore.BulkInsert into generic-base-provider
# Conflicts: # src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlBulkInsertProvider.cs # src/PhenX.EntityFrameworkCore.BulkInsert.PostgreSql/PostgreSqlBulkInsertProvider.cs # src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerBulkInsertProvider.cs # src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteBulkInsertProvider.cs # src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertOptionsExtension.cs # src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs # src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/PublicExtensions.cs # tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Basic/BasicTestsBase.cs # tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Basic/BasicTestsMySql.cs # tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Basic/BasicTestsPostgreSql.cs # tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Basic/BasicTestsSqlServer.cs # tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Basic/BasicTestsSqlite.cs
2 parents 21fe004 + 685ff84 commit d8ac159

22 files changed

Lines changed: 267 additions & 117 deletions

File tree

src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlBulkInsertProvider.cs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using JetBrains.Annotations;
2+
13
using Microsoft.EntityFrameworkCore;
24
using Microsoft.EntityFrameworkCore.Storage;
35
using Microsoft.Extensions.Logging;
@@ -9,7 +11,8 @@
911

1012
namespace PhenX.EntityFrameworkCore.BulkInsert.MySql;
1113

12-
internal class MySqlBulkInsertProvider(ILogger<MySqlBulkInsertProvider>? logger = null) : BulkInsertProviderBase<MySqlServerDialectBuilder, MySqlBulkInsertOptions>(logger)
14+
[UsedImplicitly]
15+
internal class MySqlBulkInsertProvider(ILogger<MySqlBulkInsertProvider> logger) : BulkInsertProviderBase<MySqlServerDialectBuilder, MySqlBulkInsertOptions>(logger)
1316
{
1417
//language=sql
1518
/// <inheritdoc />
@@ -47,20 +50,16 @@ CancellationToken ctk
4750
)
4851
{
4952
var connection = (MySqlConnection)context.Database.GetDbConnection();
50-
51-
var sqlTransaction = context.Database.CurrentTransaction?.GetDbTransaction()
53+
var sqlTransaction = context.Database.CurrentTransaction!.GetDbTransaction()
5254
?? throw new InvalidOperationException("No open transaction found.");
53-
5455
if (sqlTransaction is not MySqlTransaction mySqlTransaction)
5556
{
5657
throw new InvalidOperationException($"Invalid transaction foud, got {sqlTransaction.GetType()}.");
5758
}
5859

59-
var bulkCopy = new MySqlBulkCopy(connection, mySqlTransaction)
60-
{
61-
DestinationTableName = tableName,
62-
BulkCopyTimeout = options.GetCopyTimeoutInSeconds(),
63-
};
60+
var bulkCopy = new MySqlBulkCopy(connection, mySqlTransaction);
61+
bulkCopy.DestinationTableName = tableName;
62+
bulkCopy.BulkCopyTimeout = options.GetCopyTimeoutInSeconds();
6463

6564
var sourceOrdinal = 0;
6665
foreach (var prop in properties)

src/PhenX.EntityFrameworkCore.BulkInsert.PostgreSql/PostgreSqlBulkInsertProvider.cs

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,17 @@
66
using Microsoft.Extensions.Logging;
77

88
using Npgsql;
9+
using Npgsql.EntityFrameworkCore.PostgreSQL.Storage.Internal.Mapping;
10+
11+
using NpgsqlTypes;
912

1013
using PhenX.EntityFrameworkCore.BulkInsert.Metadata;
1114
using PhenX.EntityFrameworkCore.BulkInsert.Options;
1215

1316
namespace PhenX.EntityFrameworkCore.BulkInsert.PostgreSql;
1417

1518
[UsedImplicitly]
16-
internal class PostgreSqlBulkInsertProvider(ILogger<PostgreSqlBulkInsertProvider>? logger = null) : BulkInsertProviderBase<PostgreSqlDialectBuilder, BulkInsertOptions>(logger)
19+
internal class PostgreSqlBulkInsertProvider(ILogger<PostgreSqlBulkInsertProvider>? logger) : BulkInsertProviderBase<PostgreSqlDialectBuilder, BulkInsertOptions>(logger)
1720
{
1821
//language=sql
1922
/// <inheritdoc />
@@ -53,6 +56,9 @@ protected override async Task BulkInsert<T>(
5356
? connection.BeginBinaryImport(command)
5457
: await connection.BeginBinaryImportAsync(command, ctk);
5558

59+
// The type mapping can be null for obvious types like string.
60+
var columnTypes = columns.Select(GetPostgreSqlType).ToArray();
61+
5662
foreach (var entity in entities)
5763
{
5864
if (sync)
@@ -65,19 +71,40 @@ protected override async Task BulkInsert<T>(
6571
await writer.StartRowAsync(ctk);
6672
}
6773

74+
var columnIndex = 0;
6875
foreach (var column in columns)
6976
{
7077
var value = column.GetValue(entity);
7178

79+
// Get the actual type, so that the writer can do the conversation to the target type automatically.
80+
var type = columnTypes[columnIndex];
81+
7282
if (sync)
7383
{
74-
// ReSharper disable once MethodHasAsyncOverloadWithCancellation
75-
writer.Write(value);
84+
if (type != null)
85+
{
86+
// ReSharper disable once MethodHasAsyncOverloadWithCancellation
87+
writer.Write(value, type.Value);
88+
}
89+
else
90+
{
91+
// ReSharper disable once MethodHasAsyncOverloadWithCancellation
92+
writer.Write(value);
93+
}
7694
}
7795
else
7896
{
79-
await writer.WriteAsync(value, ctk);
97+
if (type != null)
98+
{
99+
await writer.WriteAsync(value, type.Value, ctk);
100+
}
101+
else
102+
{
103+
await writer.WriteAsync(value, ctk);
104+
}
80105
}
106+
107+
columnIndex++;
81108
}
82109
}
83110

@@ -93,6 +120,12 @@ protected override async Task BulkInsert<T>(
93120
await writer.CompleteAsync(ctk);
94121
await writer.DisposeAsync();
95122
}
123+
}
124+
125+
private static NpgsqlDbType? GetPostgreSqlType(ColumnMetadata column)
126+
{
127+
var mapping = column.Property.GetRelationalTypeMapping() as NpgsqlTypeMapping;
96128

129+
return mapping?.NpgsqlDbType;
97130
}
98131
}

src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerBulkInsertProvider.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,12 @@
66
using Microsoft.Extensions.Logging;
77

88
using PhenX.EntityFrameworkCore.BulkInsert.Metadata;
9-
using PhenX.EntityFrameworkCore.BulkInsert.Options;
109

1110
namespace PhenX.EntityFrameworkCore.BulkInsert.SqlServer;
1211

1312
[UsedImplicitly]
14-
internal class SqlServerBulkInsertProvider(ILogger<SqlServerBulkInsertProvider>? logger = null) : BulkInsertProviderBase<SqlServerDialectBuilder, SqlServerBulkInsertOptions>(logger)
13+
internal class SqlServerBulkInsertProvider(ILogger<SqlServerBulkInsertProvider>? logger) : BulkInsertProviderBase<SqlServerDialectBuilder, SqlServerBulkInsertOptions>(logger)
1514
{
16-
1715
//language=sql
1816
/// <inheritdoc />
1917
protected override string AddTableCopyBulkInsertId => $"ALTER TABLE {{0}} ADD {BulkInsertId} INT IDENTITY PRIMARY KEY;";

src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteBulkInsertProvider.cs

Lines changed: 24 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
namespace PhenX.EntityFrameworkCore.BulkInsert.Sqlite;
1414

1515
[UsedImplicitly]
16-
internal class SqliteBulkInsertProvider(ILogger<SqliteBulkInsertProvider>? logger = null) : BulkInsertProviderBase<SqliteDialectBuilder, BulkInsertOptions>(logger)
16+
internal class SqliteBulkInsertProvider(ILogger<SqliteBulkInsertProvider>? logger) : BulkInsertProviderBase<SqliteDialectBuilder, BulkInsertOptions>(logger)
1717
{
18+
private const int MaxParams = 1000;
1819

1920
/// <inheritdoc />
2021
protected override string BulkInsertId => "rowid";
@@ -37,48 +38,30 @@ protected override Task AddBulkInsertIdColumn<T>(
3738
CancellationToken cancellationToken
3839
) where T : class => Task.CompletedTask;
3940

40-
/// <summary>
41-
/// Taken from https://github.com/dotnet/efcore/blob/667c569c49a1ab7e142621395d3f14f2af0508b4/src/Microsoft.Data.Sqlite.Core/SqliteValueBinder.cs#L231
42-
/// As the method is not exposed in the public API, we need to copy it here.
43-
/// </summary>
44-
private static readonly Dictionary<Type, SqliteType> SqliteTypeMapping =
45-
new()
46-
{
47-
{ typeof(bool), SqliteType.Integer },
48-
{ typeof(byte), SqliteType.Integer },
49-
{ typeof(byte[]), SqliteType.Blob },
50-
{ typeof(char), SqliteType.Text },
51-
{ typeof(DateTime), SqliteType.Text },
52-
{ typeof(DateTimeOffset), SqliteType.Text },
53-
{ typeof(DateOnly), SqliteType.Text },
54-
{ typeof(TimeOnly), SqliteType.Text },
55-
{ typeof(DBNull), SqliteType.Text },
56-
{ typeof(decimal), SqliteType.Text },
57-
{ typeof(double), SqliteType.Real },
58-
{ typeof(float), SqliteType.Real },
59-
{ typeof(Guid), SqliteType.Text },
60-
{ typeof(int), SqliteType.Integer },
61-
{ typeof(long), SqliteType.Integer },
62-
{ typeof(sbyte), SqliteType.Integer },
63-
{ typeof(short), SqliteType.Integer },
64-
{ typeof(string), SqliteType.Text },
65-
{ typeof(TimeSpan), SqliteType.Text },
66-
{ typeof(uint), SqliteType.Integer },
67-
{ typeof(ulong), SqliteType.Integer },
68-
{ typeof(ushort), SqliteType.Integer }
69-
};
70-
71-
private static SqliteType GetSqliteType(Type clrType)
41+
private static SqliteType GetSqliteType(ColumnMetadata column)
7242
{
73-
var type = Nullable.GetUnderlyingType(clrType) ?? clrType;
74-
type = type.IsEnum ? Enum.GetUnderlyingType(type) : type;
43+
var storeType = column.Property.GetRelationalTypeMapping().StoreType;
7544

76-
if (SqliteTypeMapping.TryGetValue(type, out var sqliteType))
45+
if (string.Equals(storeType, "INTEGER", StringComparison.OrdinalIgnoreCase))
7746
{
78-
return sqliteType;
47+
return SqliteType.Integer;
48+
}
49+
else if (string.Equals(storeType, "FLOAT", StringComparison.OrdinalIgnoreCase))
50+
{
51+
return SqliteType.Real;
52+
}
53+
else if (string.Equals(storeType, "TEXT", StringComparison.OrdinalIgnoreCase))
54+
{
55+
return SqliteType.Text;
56+
}
57+
else if (string.Equals(storeType, "BLOB", StringComparison.OrdinalIgnoreCase))
58+
{
59+
return SqliteType.Blob;
60+
}
61+
else
62+
{
63+
throw new NotSupportedException($"Invalid store type '{storeType}' for property '{column.PropertyName}'");
7964
}
80-
81-
throw new InvalidOperationException($"Unknown Sqlite type for {clrType}");
8265
}
8366

8467
private static DbCommand GetInsertCommand(
@@ -144,15 +127,13 @@ protected override async Task BulkInsert<T>(
144127
CancellationToken ctk
145128
) where T : class
146129
{
147-
const int maxParams = 1000;
148-
var batchSize = options.BatchSize;
149-
batchSize = Math.Min(batchSize, maxParams / columns.Count);
130+
var batchSize = Math.Min(options.BatchSize, MaxParams / columns.Count);
150131

151132
// The StringBuilder can be resuse between the batches.
152133
var sb = new StringBuilder();
153134

154135
var columnList = tableInfo.GetColumns(options.CopyGeneratedColumns);
155-
var columnTypes = columnList.Select(c => GetSqliteType(c.ProviderClrType ?? c.ClrType)).ToArray();
136+
var columnTypes = columnList.Select(GetSqliteType).ToArray();
156137

157138
await using var insertCommand =
158139
GetInsertCommand(

src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertOptionsExtension.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
using Microsoft.EntityFrameworkCore.Infrastructure;
22
using Microsoft.Extensions.DependencyInjection;
3+
using Microsoft.Extensions.DependencyInjection.Extensions;
4+
using Microsoft.Extensions.Logging;
5+
using Microsoft.Extensions.Logging.Abstractions;
36

47
using PhenX.EntityFrameworkCore.BulkInsert.Abstractions;
58

@@ -13,6 +16,7 @@ public DbContextOptionsExtensionInfo Info
1316

1417
public void ApplyServices(IServiceCollection services)
1518
{
19+
services.TryAddSingleton(typeof(ILogger<>), typeof(NullLogger<>));
1620
services.AddSingleton<IBulkInsertProvider, TProvider>();
1721
}
1822

src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
namespace PhenX.EntityFrameworkCore.BulkInsert;
1313

14-
internal abstract class BulkInsertProviderBase<TDialect, TOptions>(ILogger? logger = null) : BulkInsertProviderUntyped<TDialect, TOptions>
14+
internal abstract class BulkInsertProviderBase<TDialect, TOptions>(ILogger? logger) : BulkInsertProviderUntyped<TDialect, TOptions>
1515
where TDialect : SqlDialectBuilder, new()
1616
where TOptions : BulkInsertOptions, new()
1717
{
@@ -39,7 +39,7 @@ protected override async IAsyncEnumerable<T> BulkInsertReturnEntities<T>(
3939
{
4040
if (logger != null)
4141
{
42-
Log.UsingTempTablToReturnData(logger);
42+
Log.UsingTempTableToReturnData(logger);
4343
}
4444

4545
var tableName = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk);

src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/InternalExtensions.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@ internal static TableMetadata GetTableInfo<T>(this DbContext context)
2222
internal static DbContextOptionsBuilder UseProvider<TProvider>(this DbContextOptionsBuilder optionsBuilder)
2323
where TProvider : class, IBulkInsertProvider
2424
{
25-
var extension = optionsBuilder.Options.FindExtension<BulkInsertOptionsExtension<TProvider>>() ?? new BulkInsertOptionsExtension<TProvider>();
26-
27-
((IDbContextOptionsBuilderInfrastructure)optionsBuilder).AddOrUpdateExtension(extension);
25+
((IDbContextOptionsBuilderInfrastructure)optionsBuilder).AddOrUpdateExtension(
26+
optionsBuilder.Options.FindExtension<BulkInsertOptionsExtension<TProvider>>() ?? new());
2827

2928
((IDbContextOptionsBuilderInfrastructure)optionsBuilder).AddOrUpdateExtension(
3029
optionsBuilder.Options.FindExtension<MetadataProviderExtension>() ?? new());

src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/PublicExtensions.DbSet.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ public static IAsyncEnumerable<T> ExecuteBulkInsertReturnEnumerableAsync<T, TOpt
107107
where T : class
108108
where TOptions : BulkInsertOptions
109109
{
110-
var provider = InitProvider(dbSet, configure, out var context, out var options);
110+
var (provider, context, options) = InitProvider(dbSet, configure);
111111

112112
return provider.BulkInsertReturnEntities(false, context, dbSet.GetDbContext().GetTableInfo<T>(), entities,
113113
options, onConflict, cancellationToken);
@@ -155,7 +155,7 @@ public static async Task ExecuteBulkInsertAsync<T, TOptions>(
155155
where T : class
156156
where TOptions : BulkInsertOptions
157157
{
158-
var provider = InitProvider(dbSet, configure, out var context, out var options);
158+
var (provider, context, options) = InitProvider(dbSet, configure);
159159

160160
await provider.BulkInsert(false, context, dbSet.GetDbContext().GetTableInfo<T>(), entities, options, onConflict,
161161
cancellationToken);
@@ -202,7 +202,7 @@ public static void ExecuteBulkInsert<T, TOptions>(
202202
where T : class
203203
where TOptions : BulkInsertOptions
204204
{
205-
var provider = InitProvider(dbSet, configure, out var context, out var options);
205+
var (provider, context, options) = InitProvider(dbSet, configure);
206206

207207
provider.BulkInsert(true, context, dbSet.GetDbContext().GetTableInfo<T>(), entities, options, onConflict)
208208
.GetAwaiter().GetResult();

src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/PublicExtensions.cs

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,22 @@ namespace PhenX.EntityFrameworkCore.BulkInsert.Extensions;
1111
/// </summary>
1212
public static partial class PublicExtensions
1313
{
14-
private static async Task<List<T>> ExecuteBulkInsertReturnEntitiesCoreAsync<T, TOptions>(
15-
this DbSet<T> dbSet,
14+
private static async Task<List<TEntity>> ExecuteBulkInsertReturnEntitiesCoreAsync<TEntity, TOptions>(
15+
this DbSet<TEntity> dbSet,
1616
bool sync,
17-
IEnumerable<T> entities,
17+
IEnumerable<TEntity> entities,
1818
Action<TOptions> configure,
19-
OnConflictOptions<T>? onConflict,
19+
OnConflictOptions<TEntity>? onConflict,
2020
CancellationToken ctk
2121
)
22-
where T : class
22+
where TEntity : class
2323
where TOptions : BulkInsertOptions
2424
{
25-
var provider = InitProvider(dbSet, configure, out var context, out var options);
25+
var (provider, context, options) = InitProvider(dbSet, configure);
2626

27-
var enumerable = provider.BulkInsertReturnEntities(sync, context, dbSet.GetDbContext().GetTableInfo<T>(), entities, options, onConflict, ctk);
27+
var enumerable = provider.BulkInsertReturnEntities(sync, context, dbSet.GetDbContext().GetTableInfo<TEntity>(), entities, options, onConflict, ctk);
2828

29-
var result = new List<T>();
29+
var result = new List<TEntity>();
3030
await foreach (var item in enumerable.WithCancellation(ctk))
3131
{
3232
result.Add(item);
@@ -41,27 +41,22 @@ private static DbContext GetDbContext<T>(this DbSet<T> dbSet) where T : class
4141
return (infrastructure.Instance.GetService(typeof(ICurrentDbContext)) as ICurrentDbContext)!.Context;
4242
}
4343

44-
private static IBulkInsertProvider InitProvider<T, TOptions>(
44+
private static (IBulkInsertProvider, DbContext, TOptions) InitProvider<T, TOptions>(
4545
DbSet<T> dbSet,
46-
Action<TOptions>? configure,
47-
out DbContext context,
48-
out TOptions options
46+
Action<TOptions>? configure
4947
)
5048
where T : class where TOptions : BulkInsertOptions
5149
{
52-
context = dbSet.GetDbContext();
50+
var context = dbSet.GetDbContext();
5351
var provider = context.GetService<IBulkInsertProvider>();
54-
55-
var defaultOptions = provider.CreateDefaultOptions();
56-
57-
if (defaultOptions is not TOptions castedOptions)
52+
var options = provider.CreateDefaultOptions();
53+
if (options is not TOptions castedOptions)
5854
{
59-
throw new InvalidOperationException($"Options type mismatch. Expected {defaultOptions.GetType().Name}, but got {typeof(TOptions).Name}.");
55+
throw new InvalidOperationException($"Options type mismatch. Expected {options.GetType().Name}, but got {typeof(TOptions).Name}.");
6056
}
6157

62-
options = castedOptions;
63-
configure?.Invoke(options);
58+
configure?.Invoke(castedOptions);
6459

65-
return provider;
60+
return (provider, context, castedOptions);
6661
}
6762
}

src/PhenX.EntityFrameworkCore.BulkInsert/Log.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ internal static partial class Log
88
EventId = 1000,
99
Level = LogLevel.Trace,
1010
Message = "Using temporary table to return data")]
11-
public static partial void UsingTempTablToReturnData(ILogger logger);
11+
public static partial void UsingTempTableToReturnData(ILogger logger);
1212

1313
[LoggerMessage(
1414
EventId = 1001,

0 commit comments

Comments
 (0)