Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,8 @@

namespace PhenX.EntityFrameworkCore.BulkInsert.MySql;

internal class MySqlBulkInsertProvider : BulkInsertProviderBase<MySqlServerDialectBuilder, MySqlBulkInsertOptions>
internal class MySqlBulkInsertProvider(ILogger<MySqlBulkInsertProvider> logger) : BulkInsertProviderBase<MySqlServerDialectBuilder, MySqlBulkInsertOptions>(logger)
{
public MySqlBulkInsertProvider(ILogger<MySqlBulkInsertProvider>? logger = null) : base(logger)
{
}

//language=sql
/// <inheritdoc />
protected override string AddTableCopyBulkInsertId => $"ALTER TABLE {{0}} ADD {BulkInsertId} INT AUTO_INCREMENT PRIMARY KEY;";
Expand Down Expand Up @@ -51,20 +47,16 @@ CancellationToken ctk
)
{
var connection = (MySqlConnection)context.Database.GetDbConnection();

var sqlTransaction = context.Database.CurrentTransaction?.GetDbTransaction()
var sqlTransaction = context.Database.CurrentTransaction!.GetDbTransaction()
?? throw new InvalidOperationException("No open transaction found.");

if (sqlTransaction is not MySqlTransaction mySqlTransaction)
{
throw new InvalidOperationException($"Invalid transaction foud, got {sqlTransaction.GetType()}.");
}

var bulkCopy = new MySqlBulkCopy(connection, mySqlTransaction)
{
DestinationTableName = tableName,
BulkCopyTimeout = options.GetCopyTimeoutInSeconds(),
};
var bulkCopy = new MySqlBulkCopy(connection, mySqlTransaction);
bulkCopy.DestinationTableName = tableName;
bulkCopy.BulkCopyTimeout = options.GetCopyTimeoutInSeconds();

var sourceOrdinal = 0;
foreach (var prop in properties)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,18 @@
using Microsoft.Extensions.Logging;

using Npgsql;
using Npgsql.EntityFrameworkCore.PostgreSQL.Storage.Internal.Mapping;

using NpgsqlTypes;

using PhenX.EntityFrameworkCore.BulkInsert.Metadata;
using PhenX.EntityFrameworkCore.BulkInsert.Options;

namespace PhenX.EntityFrameworkCore.BulkInsert.PostgreSql;

[UsedImplicitly]
internal class PostgreSqlBulkInsertProvider : BulkInsertProviderBase<PostgreSqlDialectBuilder, BulkInsertOptions>
internal class PostgreSqlBulkInsertProvider(ILogger<PostgreSqlBulkInsertProvider>? logger) : BulkInsertProviderBase<PostgreSqlDialectBuilder, BulkInsertOptions>(logger)

Check warning on line 19 in src/PhenX.EntityFrameworkCore.BulkInsert.PostgreSql/PostgreSqlBulkInsertProvider.cs

View workflow job for this annotation

GitHub Actions / build

Possible null reference argument for parameter 'logger' in 'BulkInsertProviderBase<PostgreSqlDialectBuilder, BulkInsertOptions>.BulkInsertProviderBase(ILogger<BulkInsertProviderBase<PostgreSqlDialectBuilder, BulkInsertOptions>> logger)'.

Check warning on line 19 in src/PhenX.EntityFrameworkCore.BulkInsert.PostgreSql/PostgreSqlBulkInsertProvider.cs

View workflow job for this annotation

GitHub Actions / build

Possible null reference argument for parameter 'logger' in 'BulkInsertProviderBase<PostgreSqlDialectBuilder, BulkInsertOptions>.BulkInsertProviderBase(ILogger<BulkInsertProviderBase<PostgreSqlDialectBuilder, BulkInsertOptions>> logger)'.

Check warning on line 19 in src/PhenX.EntityFrameworkCore.BulkInsert.PostgreSql/PostgreSqlBulkInsertProvider.cs

View workflow job for this annotation

GitHub Actions / build

Possible null reference argument for parameter 'logger' in 'BulkInsertProviderBase<PostgreSqlDialectBuilder, BulkInsertOptions>.BulkInsertProviderBase(ILogger<BulkInsertProviderBase<PostgreSqlDialectBuilder, BulkInsertOptions>> logger)'.

Check warning on line 19 in src/PhenX.EntityFrameworkCore.BulkInsert.PostgreSql/PostgreSqlBulkInsertProvider.cs

View workflow job for this annotation

GitHub Actions / build

Possible null reference argument for parameter 'logger' in 'BulkInsertProviderBase<PostgreSqlDialectBuilder, BulkInsertOptions>.BulkInsertProviderBase(ILogger<BulkInsertProviderBase<PostgreSqlDialectBuilder, BulkInsertOptions>> logger)'.
{
public PostgreSqlBulkInsertProvider(ILogger<PostgreSqlBulkInsertProvider>? logger = null) : base(logger)
{
}

//language=sql
/// <inheritdoc />
protected override string AddTableCopyBulkInsertId => $"ALTER TABLE {{0}} ADD COLUMN {BulkInsertId} SERIAL PRIMARY KEY;";
Expand Down Expand Up @@ -57,6 +56,9 @@
? connection.BeginBinaryImport(command)
: await connection.BeginBinaryImportAsync(command, ctk);

// The type mapping can be null for obvious types like string.
var columnTypes = columns.Select(GetPostgreSqlType).ToArray();

foreach (var entity in entities)
{
if (sync)
Expand All @@ -69,19 +71,40 @@
await writer.StartRowAsync(ctk);
}

var columnIndex = 0;
foreach (var column in columns)
{
var value = column.GetValue(entity);

// Get the actual type, so that the writer can do the conversation to the target type automatically.
var type = columnTypes[columnIndex];

if (sync)
{
// ReSharper disable once MethodHasAsyncOverloadWithCancellation
writer.Write(value);
if (type != null)
{
// ReSharper disable once MethodHasAsyncOverloadWithCancellation
writer.Write(value, type.Value);
}
else
{
// ReSharper disable once MethodHasAsyncOverloadWithCancellation
writer.Write(value);
}
}
else
{
await writer.WriteAsync(value, ctk);
if (type != null)
{
await writer.WriteAsync(value, type.Value, ctk);
}
else
{
await writer.WriteAsync(value, ctk);
}
}

columnIndex++;
}
}

Expand All @@ -97,6 +120,12 @@
await writer.CompleteAsync(ctk);
await writer.DisposeAsync();
}
}

private static NpgsqlDbType? GetPostgreSqlType(ColumnMetadata column)
{
var mapping = column.Property.GetRelationalTypeMapping() as NpgsqlTypeMapping;

return mapping?.NpgsqlDbType;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,12 @@
using Microsoft.Extensions.Logging;

using PhenX.EntityFrameworkCore.BulkInsert.Metadata;
using PhenX.EntityFrameworkCore.BulkInsert.Options;

namespace PhenX.EntityFrameworkCore.BulkInsert.SqlServer;

[UsedImplicitly]
internal class SqlServerBulkInsertProvider : BulkInsertProviderBase<SqlServerDialectBuilder, SqlServerBulkInsertOptions>
internal class SqlServerBulkInsertProvider(ILogger<SqlServerBulkInsertProvider>? logger) : BulkInsertProviderBase<SqlServerDialectBuilder, SqlServerBulkInsertOptions>(logger)

Check warning on line 13 in src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerBulkInsertProvider.cs

View workflow job for this annotation

GitHub Actions / build

Possible null reference argument for parameter 'logger' in 'BulkInsertProviderBase<SqlServerDialectBuilder, SqlServerBulkInsertOptions>.BulkInsertProviderBase(ILogger<BulkInsertProviderBase<SqlServerDialectBuilder, SqlServerBulkInsertOptions>> logger)'.

Check warning on line 13 in src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerBulkInsertProvider.cs

View workflow job for this annotation

GitHub Actions / build

Possible null reference argument for parameter 'logger' in 'BulkInsertProviderBase<SqlServerDialectBuilder, SqlServerBulkInsertOptions>.BulkInsertProviderBase(ILogger<BulkInsertProviderBase<SqlServerDialectBuilder, SqlServerBulkInsertOptions>> logger)'.

Check warning on line 13 in src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerBulkInsertProvider.cs

View workflow job for this annotation

GitHub Actions / build

Possible null reference argument for parameter 'logger' in 'BulkInsertProviderBase<SqlServerDialectBuilder, SqlServerBulkInsertOptions>.BulkInsertProviderBase(ILogger<BulkInsertProviderBase<SqlServerDialectBuilder, SqlServerBulkInsertOptions>> logger)'.
{
public SqlServerBulkInsertProvider(ILogger<SqlServerBulkInsertProvider>? logger = null) : base(logger)
{
}

//language=sql
/// <inheritdoc />
protected override string AddTableCopyBulkInsertId => $"ALTER TABLE {{0}} ADD {BulkInsertId} INT IDENTITY PRIMARY KEY;";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,9 @@
namespace PhenX.EntityFrameworkCore.BulkInsert.Sqlite;

[UsedImplicitly]
internal class SqliteBulkInsertProvider : BulkInsertProviderBase<SqliteDialectBuilder, BulkInsertOptions>
internal class SqliteBulkInsertProvider(ILogger<SqliteBulkInsertProvider>? logger) : BulkInsertProviderBase<SqliteDialectBuilder, BulkInsertOptions>(logger)

Check warning on line 16 in src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteBulkInsertProvider.cs

View workflow job for this annotation

GitHub Actions / build

Possible null reference argument for parameter 'logger' in 'BulkInsertProviderBase<SqliteDialectBuilder, BulkInsertOptions>.BulkInsertProviderBase(ILogger<BulkInsertProviderBase<SqliteDialectBuilder, BulkInsertOptions>> logger)'.

Check warning on line 16 in src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteBulkInsertProvider.cs

View workflow job for this annotation

GitHub Actions / build

Possible null reference argument for parameter 'logger' in 'BulkInsertProviderBase<SqliteDialectBuilder, BulkInsertOptions>.BulkInsertProviderBase(ILogger<BulkInsertProviderBase<SqliteDialectBuilder, BulkInsertOptions>> logger)'.

Check warning on line 16 in src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteBulkInsertProvider.cs

View workflow job for this annotation

GitHub Actions / build

Possible null reference argument for parameter 'logger' in 'BulkInsertProviderBase<SqliteDialectBuilder, BulkInsertOptions>.BulkInsertProviderBase(ILogger<BulkInsertProviderBase<SqliteDialectBuilder, BulkInsertOptions>> logger)'.
{
public SqliteBulkInsertProvider(ILogger<SqliteBulkInsertProvider>? logger = null) : base(logger)
{
}
private const int MaxParams = 1000;

/// <inheritdoc />
protected override string BulkInsertId => "rowid";
Expand All @@ -40,48 +38,30 @@
CancellationToken cancellationToken
) where T : class => Task.CompletedTask;

/// <summary>
/// Taken from https://github.com/dotnet/efcore/blob/667c569c49a1ab7e142621395d3f14f2af0508b4/src/Microsoft.Data.Sqlite.Core/SqliteValueBinder.cs#L231
/// As the method is not exposed in the public API, we need to copy it here.
/// </summary>
private static readonly Dictionary<Type, SqliteType> SqliteTypeMapping =
new()
{
{ typeof(bool), SqliteType.Integer },
{ typeof(byte), SqliteType.Integer },
{ typeof(byte[]), SqliteType.Blob },
{ typeof(char), SqliteType.Text },
{ typeof(DateTime), SqliteType.Text },
{ typeof(DateTimeOffset), SqliteType.Text },
{ typeof(DateOnly), SqliteType.Text },
{ typeof(TimeOnly), SqliteType.Text },
{ typeof(DBNull), SqliteType.Text },
{ typeof(decimal), SqliteType.Text },
{ typeof(double), SqliteType.Real },
{ typeof(float), SqliteType.Real },
{ typeof(Guid), SqliteType.Text },
{ typeof(int), SqliteType.Integer },
{ typeof(long), SqliteType.Integer },
{ typeof(sbyte), SqliteType.Integer },
{ typeof(short), SqliteType.Integer },
{ typeof(string), SqliteType.Text },
{ typeof(TimeSpan), SqliteType.Text },
{ typeof(uint), SqliteType.Integer },
{ typeof(ulong), SqliteType.Integer },
{ typeof(ushort), SqliteType.Integer }
};

private static SqliteType GetSqliteType(Type clrType)
private static SqliteType GetSqliteType(ColumnMetadata column)
{
var type = Nullable.GetUnderlyingType(clrType) ?? clrType;
type = type.IsEnum ? Enum.GetUnderlyingType(type) : type;
var storeType = column.Property.GetRelationalTypeMapping().StoreType;

if (SqliteTypeMapping.TryGetValue(type, out var sqliteType))
if (string.Equals(storeType, "INTEGER", StringComparison.OrdinalIgnoreCase))
{
return sqliteType;
return SqliteType.Integer;
}
else if (string.Equals(storeType, "FLOAT", StringComparison.OrdinalIgnoreCase))
{
return SqliteType.Real;
}
else if (string.Equals(storeType, "TEXT", StringComparison.OrdinalIgnoreCase))
{
return SqliteType.Text;
}
else if (string.Equals(storeType, "BLOB", StringComparison.OrdinalIgnoreCase))
{
return SqliteType.Blob;
}
else
{
throw new NotSupportedException($"Invalid store type '{storeType}' for property '{column.PropertyName}'");
}

throw new InvalidOperationException($"Unknown Sqlite type for {clrType}");
}

private static DbCommand GetInsertCommand(
Expand Down Expand Up @@ -147,15 +127,13 @@
CancellationToken ctk
) where T : class
{
const int maxParams = 1000;
var batchSize = options.BatchSize;
batchSize = Math.Min(batchSize, maxParams / columns.Count);
var batchSize = Math.Min(options.BatchSize, MaxParams / columns.Count);

// The StringBuilder can be resuse between the batches.
var sb = new StringBuilder();

var columnList = tableInfo.GetColumns(options.CopyGeneratedColumns);
var columnTypes = columnList.Select(c => GetSqliteType(c.ProviderClrType ?? c.ClrType)).ToArray();
var columnTypes = columnList.Select(GetSqliteType).ToArray();

await using var insertCommand =
GetInsertCommand(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;

using PhenX.EntityFrameworkCore.BulkInsert.Abstractions;

Expand All @@ -13,17 +16,16 @@ public DbContextOptionsExtensionInfo Info

public void ApplyServices(IServiceCollection services)
{
services.TryAddSingleton(typeof(ILogger<>), typeof(NullLogger<>));
services.AddSingleton<IBulkInsertProvider, TProvider>();
}

public void Validate(IDbContextOptions options)
{
}

private class BulkInsertOptionsExtensionInfo : DbContextOptionsExtensionInfo
private class BulkInsertOptionsExtensionInfo(IDbContextOptionsExtension extension) : DbContextOptionsExtensionInfo(extension)
{
public BulkInsertOptionsExtensionInfo(IDbContextOptionsExtension extension)
: base(extension) { }

/// <inheritdoc />
public override int GetServiceProviderHashCode() => 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

namespace PhenX.EntityFrameworkCore.BulkInsert;

internal abstract class BulkInsertProviderBase<TDialect, TOptions>(ILogger<BulkInsertProviderBase<TDialect, TOptions>>? logger = null) : IBulkInsertProvider
internal abstract class BulkInsertProviderBase<TDialect, TOptions>(ILogger<BulkInsertProviderBase<TDialect, TOptions>> logger) : IBulkInsertProvider
where TDialect : SqlDialectBuilder, new()
where TOptions : BulkInsertOptions, new()
{
Expand Down Expand Up @@ -56,7 +56,7 @@ public virtual async IAsyncEnumerable<T> BulkInsertReturnEntities<T>(
{
if (logger != null)
{
Log.UsingTempTablToReturnData(logger);
Log.UsingTempTableToReturnData(logger);
}

var tableName = await PerformBulkInsertAsync(sync, context, tableInfo, entities, providerOptions, tempTableRequired: true, ctk: ctk);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ internal static TableMetadata GetTableInfo<T>(this DbContext context)
internal static DbContextOptionsBuilder UseProvider<TProvider>(this DbContextOptionsBuilder optionsBuilder)
where TProvider : class, IBulkInsertProvider
{
var extension = optionsBuilder.Options.FindExtension<BulkInsertOptionsExtension<TProvider>>() ?? new BulkInsertOptionsExtension<TProvider>();

((IDbContextOptionsBuilderInfrastructure)optionsBuilder).AddOrUpdateExtension(extension);
((IDbContextOptionsBuilderInfrastructure)optionsBuilder).AddOrUpdateExtension(
optionsBuilder.Options.FindExtension<BulkInsertOptionsExtension<TProvider>>() ?? new());

((IDbContextOptionsBuilderInfrastructure)optionsBuilder).AddOrUpdateExtension(
optionsBuilder.Options.FindExtension<MetadataProviderExtension>() ?? new());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore;

using PhenX.EntityFrameworkCore.BulkInsert.Options;

Expand Down Expand Up @@ -107,7 +107,7 @@ public static IAsyncEnumerable<T> ExecuteBulkInsertReturnEnumerableAsync<T, TOpt
where T : class
where TOptions : BulkInsertOptions
{
var provider = InitProvider(dbSet, configure, out var context, out var options);
var (provider, context, options) = InitProvider(dbSet, configure);

return provider.BulkInsertReturnEntities(false, context, dbSet.GetDbContext().GetTableInfo<T>(), entities,
options, onConflict, ctk);
Expand Down Expand Up @@ -155,7 +155,7 @@ public static async Task ExecuteBulkInsertAsync<T, TOptions>(
where T : class
where TOptions : BulkInsertOptions
{
var provider = InitProvider(dbSet, configure, out var context, out var options);
var (provider, context, options) = InitProvider(dbSet, configure);

await provider.BulkInsert(false, context, dbSet.GetDbContext().GetTableInfo<T>(), entities, options, onConflict,
ctk);
Expand Down Expand Up @@ -202,7 +202,7 @@ public static void ExecuteBulkInsert<T, TOptions>(
where T : class
where TOptions : BulkInsertOptions
{
var provider = InitProvider(dbSet, configure, out var context, out var options);
var (provider, context, options) = InitProvider(dbSet, configure);

provider.BulkInsert(true, context, dbSet.GetDbContext().GetTableInfo<T>(), entities, options, onConflict)
.GetAwaiter().GetResult();
Expand Down
Loading