Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ internal class SqliteBulkInsertProvider(ILogger<SqliteBulkInsertProvider>? logge
/// <inheritdoc />
protected override string BulkInsertId => "rowid";

//language=sql
/// <inheritdoc />
protected override string AddTableCopyBulkInsertId => "--"; // No need to add an ID column in SQLite

/// <inheritdoc />
protected override string GetTempTableName(string tableName) => $"_temp_bulk_insert_test_entity_{Guid.NewGuid():N}";

/// <inheritdoc />
protected override BulkInsertOptions CreateDefaultOptions() => new()
{
Expand Down Expand Up @@ -116,6 +118,12 @@ private static DbCommand GetInsertCommand(
return command;
}

/// <inheritdoc />
protected override Task DropTempTableAsync(bool sync, DbContext dbContext, string tableName)
{
return ExecuteAsync(sync, dbContext, $"DROP TABLE IF EXISTS {tableName}", default);
}

/// <inheritdoc />
protected override async Task BulkInsert<T>(
bool sync,
Expand Down
67 changes: 53 additions & 14 deletions src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Storage;
Expand Down Expand Up @@ -43,14 +44,20 @@ protected override async IAsyncEnumerable<T> BulkInsertReturnEntities<T>(
}

var tableName = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk);
try
{
var result =
await CopyFromTempTableAsync<T, T>(sync, context, tableInfo, tableName, true, options, onConflict, ctk: ctk)
?? throw new InvalidOperationException("Copy returns null enumerable.");

var result =
await CopyFromTempTableAsync<T, T>(sync, context, tableInfo, tableName, true, options, onConflict, ctk: ctk)
?? throw new InvalidOperationException("Copy returns null enumerable.");

await foreach (var item in result.WithCancellation(ctk))
await foreach (var item in result.WithCancellation(ctk))
{
yield return item;
}
}
finally
{
yield return item;
await PerformDropTempTableAsync(sync, context, tableName);
}

// Commit the transaction if we own them.
Expand All @@ -71,6 +78,11 @@ protected override async Task BulkInsert<T>(
OnConflictOptions<T>? onConflict,
CancellationToken ctk) where T : class
{
if (entities.TryGetNonEnumeratedCount(out var count) && count == 0)
{
throw new InvalidOperationException("No entities to insert.");
}

using var activity = Telemetry.ActivitySource.StartActivity("BulkInsert");
activity?.AddTag("tableName", tableInfo.TableName);
activity?.AddTag("synchronous", sync);
Expand All @@ -86,8 +98,14 @@ protected override async Task BulkInsert<T>(
}

var tableName = await PerformBulkInsertAsync(sync, context, tableInfo, entities, options, tempTableRequired: true, ctk: ctk);

await CopyFromTempTableAsync<T, T>(sync, context, tableInfo, tableName, false, options, onConflict, ctk);
try
{
await CopyFromTempTableAsync<T, T>(sync, context, tableInfo, tableName, false, options, onConflict, ctk);
}
finally
{
await PerformDropTempTableAsync(sync, context, tableName);
}
}
else
{
Expand All @@ -107,7 +125,6 @@ protected override async Task BulkInsert<T>(
await connection.Close(sync, ctk);
}
}

private async Task<string> PerformBulkInsertAsync<T>(
bool sync,
DbContext context,
Expand All @@ -117,11 +134,6 @@ private async Task<string> PerformBulkInsertAsync<T>(
bool tempTableRequired,
CancellationToken ctk) where T : class
{
if (entities.TryGetNonEnumeratedCount(out var count) && count == 0)
{
throw new InvalidOperationException("No entities to insert.");
}

var tableName = tempTableRequired
? await CreateTableCopyAsync<T>(sync, context, options, tableInfo, ctk)
: tableInfo.QuotedTableName;
Expand Down Expand Up @@ -208,6 +220,33 @@ protected virtual async Task AddBulkInsertIdColumn<T>(
return null;
}

private async Task PerformDropTempTableAsync(bool sync, DbContext dbContext, string tableName)
{
try
{
await DropTempTableAsync(sync, dbContext, tableName);
}
catch (Exception ex)
{
// The drop operation is not mandatory, therefore never fail the actual operation.
if (logger != null)
{
Log.DropTemporaryTableFailed(logger, ex);
}
}
}

/// <summary>
/// Drops the temporary table manually if needed.
/// </summary>
/// <param name="sync">Indicates if the operation is synchronous.</param>
/// <param name="dbContext">The context.</param>
/// <param name="tableName">The table name.</param>
protected virtual Task DropTempTableAsync(bool sync, DbContext dbContext, string tableName)
{
return Task.CompletedTask;
}

protected static async Task ExecuteAsync(
bool sync,
DbContext context,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ public IAsyncEnumerable<T> BulkInsertReturnEntities<T>(
throw new InvalidOperationException($"Invalid options type: {options.GetType().Name}. Expected: {typeof(TOptions).Name}");
}

if (entities.TryGetNonEnumeratedCount(out var count) && count == 0)
{
throw new InvalidOperationException("No entities to insert.");
}

return BulkInsertReturnEntities(sync, context, tableInfo, entities, providerOptions, onConflict, ctk);
}

Expand All @@ -62,6 +67,11 @@ public Task BulkInsert<T>(
throw new InvalidOperationException($"Invalid options type: {options.GetType().Name}. Expected: {typeof(TOptions).Name}");
}

if (entities.TryGetNonEnumeratedCount(out var count) && count == 0)
{
throw new InvalidOperationException("No entities to insert.");
}

return BulkInsert(sync, context, tableInfo, entities, providerOptions, onConflict, ctk);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public virtual string BuildMoveDataSql<T>(

if (returnedColumns.Count != 0)
{
q.Append("RETURNING ");
q.Append(" RETURNING ");
q.AppendJoin(", ", returnedColumns.Select(p => p.QuotedColumName));
q.AppendLine();
}
Expand Down
6 changes: 6 additions & 0 deletions src/PhenX.EntityFrameworkCore.BulkInsert/Log.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,10 @@ internal static partial class Log
Level = LogLevel.Trace,
Message = "Insert to table directly")]
public static partial void UsingDirectInsert(ILogger logger);

[LoggerMessage(
EventId = 1003,
Level = LogLevel.Error,
Message = "Failed to drop temporary table.")]
public static partial void DropTemporaryTableFailed(ILogger logger, Exception exception);
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace PhenX.EntityFrameworkCore.BulkInsert.Tests.DbContext;
[PrimaryKey(nameof(Id))]
[Index(nameof(Name), IsUnique = true)]
[Table("test_entity")]
public class TestEntity
public class TestEntity : TestEntityBase
{
public int Id { get; set; }

Expand All @@ -19,9 +19,6 @@ public class TestEntity
[Column("some_price")]
public decimal Price { get; set; }

[Column("test_run")]
public Guid TestRun { get; set; }

[Column("the_identifier")]
public Guid Identifier { get; set; }

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
using System.ComponentModel.DataAnnotations.Schema;

namespace PhenX.EntityFrameworkCore.BulkInsert.Tests.DbContext;

public abstract class TestEntityBase
{
[Column("test_run")]
public Guid TestRun { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
namespace PhenX.EntityFrameworkCore.BulkInsert.Tests.DbContext;

[Table("test_entity_with_converters")]
public class TestEntityWithConverters
public class TestEntityWithConverters : TestEntityBase
{
public int Id { get; set; }

Expand All @@ -14,8 +14,5 @@ public class TestEntityWithConverters

[Column("created_at")]
public DateTime CreatedAt { get; set; }

[Column("test_run")]
public Guid TestRun { get; set; }
}

Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
using System.ComponentModel.DataAnnotations;
using System.ComponentModel.DataAnnotations;
using System.ComponentModel.DataAnnotations.Schema;

using NetTopologySuite.Geometries;

namespace PhenX.EntityFrameworkCore.BulkInsert.Tests.DbContext;

[Table("test_entity_geo")]
public class TestEntityWithGeo
public class TestEntityWithGeo : TestEntityBase
{
[Key]
public int Id { get; set; }

public Geometry GeoObject { get; set; } = null!;

[Column("test_run")]
public Guid TestRun { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,12 @@
namespace PhenX.EntityFrameworkCore.BulkInsert.Tests.DbContext;

[Table("test_entity_guids")]
public class TestEntityWithGuidId
public class TestEntityWithGuidId : TestEntityBase
{
[Key]
public Guid Id { get; set; }

[Column("name")]
[MaxLength(100)]
public string Name { get; set; } = string.Empty;

[Column("test_run")]
public Guid TestRun { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@
namespace PhenX.EntityFrameworkCore.BulkInsert.Tests.DbContext;

[Table("test_entity_json")]
public class TestEntityWithJson
public class TestEntityWithJson : TestEntityBase
{
[Key]
public int Id { get; set; }

public List<int> Json { get; set; } = [];

[Column("test_run")]
public Guid TestRun { get; set; }
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFrameworks>net8.0;net9.0</TargetFrameworks>
Expand All @@ -13,8 +13,10 @@
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="FluentAssertions" Version="7.2.0" />
Comment thread
PhenX marked this conversation as resolved.
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.13.0" />
<PackageReference Include="xunit" Version="2.9.3" />
<PackageReference Include="Xunit.Combinatorial" Version="1.6.24" />
<PackageReference Include="xunit.runner.visualstudio" Version="3.0.2">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
Expand Down
72 changes: 72 additions & 0 deletions tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/TestHelpers.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
using Microsoft.EntityFrameworkCore;

using PhenX.EntityFrameworkCore.BulkInsert.Enums;
using PhenX.EntityFrameworkCore.BulkInsert.Extensions;
using PhenX.EntityFrameworkCore.BulkInsert.Options;
using PhenX.EntityFrameworkCore.BulkInsert.Tests.DbContext;

using Xunit;

namespace PhenX.EntityFrameworkCore.BulkInsert.Tests;

public enum InsertStrategy
{
Insert,
InsertReturn,
InsertAsync,
InsertReturnAsync
}

public static class TestHelpers
{
public static async Task<List<T>> InsertWithStrategyAsync<T>(
this TestDbContextBase dbContext,
InsertStrategy strategy,
List<T> entities,
Action<BulkInsertOptions>? configure = null,
OnConflictOptions<T>? onConflict = null)
where T : TestEntityBase
{
Skip.If(strategy is InsertStrategy.InsertReturn or InsertStrategy.InsertReturnAsync && dbContext.IsProvider(ProviderType.MySql));

var runId = Guid.NewGuid();
if (entities.Any(x => x.TestRun == default))
{
foreach (var entity in entities)
{
if (entity.TestRun == default)
{
entity.TestRun = runId;
}
}
}
else if (entities.Count > 0)
{
runId = entities[0].TestRun;
}

var actualConfigure = configure ?? (_ => { });
try
{
switch (strategy)
{
case InsertStrategy.InsertReturn:
return dbContext.ExecuteBulkInsertReturnEntities(entities, actualConfigure, onConflict);
case InsertStrategy.InsertReturnAsync:
return await dbContext.ExecuteBulkInsertReturnEntitiesAsync(entities, actualConfigure, onConflict);
case InsertStrategy.Insert:
dbContext.ExecuteBulkInsert(entities, actualConfigure, onConflict);
return dbContext.Set<T>().Where(x => x.TestRun == runId).ToList();
case InsertStrategy.InsertAsync:
await dbContext.ExecuteBulkInsertAsync(entities, actualConfigure, onConflict);
return await dbContext.Set<T>().Where(x => x.TestRun == runId).ToListAsync();
default:
throw new NotImplementedException();
}
}
finally
{
dbContext.ChangeTracker.Clear();
}
}
}
Loading