diff --git a/README.md b/README.md index 9a5a0f0..7414107 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ Install-Package PhenX.EntityFrameworkCore.BulkInsert.MySql ## Usage -1. Register the bulk insert provider in your `DbContextOptions`: +Register the bulk insert provider in your `DbContextOptions`: ```csharp services.AddDbContext(options => @@ -62,7 +62,7 @@ services.AddDbContext(options => }); ``` -2. Use the bulk insert extension method: +### Very basic usage ```csharp // Asynchronously @@ -72,7 +72,7 @@ await dbContext.ExecuteBulkInsertAsync(entities); dbContext.ExecuteBulkInsert(entities); ``` -3. You can also configure the bulk insert options: +### Bulk insert with options ```csharp // Common options @@ -103,12 +103,47 @@ await dbContext.ExecuteBulkInsertAsync(entities, o => }); ``` -4. You can also return the inserted entities (slower): +### Returning inserted entities ```csharp await dbContext.ExecuteBulkInsertReturnEntitiesAsync(entities); ``` +### Conflict resolution / merge / upsert + +Conflict resolution works by specifying columns that should be used to detect conflicts and the action to take when +a conflict is detected (e.g., update existing rows), using the `onConflict` parameter. + + * The conflicting columns are specified with the `Match` property and must have a unique constraint in the database. + * The action to take when a conflict is detected is specified with the `Update` property. If not specified, the default action is to do nothing (i.e., skip the conflicting rows). + * You can also specify the condition for the update action using either the `Where` or the `RawWhere` property. If not specified, the update action will be applied to all conflicting rows. + +```csharp +await dbContext.ExecuteBulkInsertAsync(entities, onConflict: new OnConflictOptions +{ + Match = e => new + { + e.Name, + // ...other columns to match on + }, + + // Optional: specify the update action, if not specified, the default action is to do nothing + // Excluded is the row being inserted which is in conflict, and Inserted is the row already in the database. + Update = (inserted, excluded) => new TestEntity + { + Price = inserted.Price // Update the Price column with the new value + }, + + // Optional: specify the condition for the update action + // Excluded is the row being inserted which is in conflict, and Inserted is the row already in the database. + // Using raw SQL condition + RawWhere = (insertedTable, excludedTable) => $"{excludedTable}.some_price > {insertedTable}.some_price", + + // OR using a lambda expression + Where = (inserted, excluded) => excluded.Price > inserted.Price, +}); +``` + ## Roadmap - [ ] [Add support for navigation properties](https://github.com/PhenX/PhenX.EntityFrameworkCore.BulkInsert/issues/2) diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlDialectBuilder.cs b/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlDialectBuilder.cs index 62ed057..c16cae1 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlDialectBuilder.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlDialectBuilder.cs @@ -1,5 +1,7 @@ using System.Text; +using Microsoft.EntityFrameworkCore; + using PhenX.EntityFrameworkCore.BulkInsert.Dialect; using PhenX.EntityFrameworkCore.BulkInsert.Metadata; using PhenX.EntityFrameworkCore.BulkInsert.Options; @@ -12,14 +14,22 @@ internal class MySqlServerDialectBuilder : SqlDialectBuilder protected override string CloseDelimiter => "`"; + /// protected override bool SupportsMoveRows => false; + /// + protected override bool SupportsInsertIntoAlias => false; + public override string CreateTableCopySql(string tempNameName, TableMetadata tableInfo, IReadOnlyList columns) { return $"CREATE TEMPORARY TABLE {tempNameName} SELECT * FROM {tableInfo.QuotedTableName} WHERE 1 = 0;"; } - protected override void AppendConflictCondition(StringBuilder sql, OnConflictOptions onConflictTyped) + protected override void AppendConflictCondition( + StringBuilder sql, + TableMetadata target, + DbContext context, + OnConflictOptions onConflictTyped) { throw new NotSupportedException("Conflict conditions are not supported in MYSQL"); } diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerDialectBuilder.cs b/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerDialectBuilder.cs index e23c346..8ed724c 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerDialectBuilder.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerDialectBuilder.cs @@ -1,5 +1,7 @@ using System.Text; +using Microsoft.EntityFrameworkCore; + using PhenX.EntityFrameworkCore.BulkInsert.Dialect; using PhenX.EntityFrameworkCore.BulkInsert.Metadata; using PhenX.EntityFrameworkCore.BulkInsert.Options; @@ -34,7 +36,10 @@ public override string CreateTableCopySql(string tempTableName, TableMetadata ta return q.ToString(); } + protected override string Trim(string lhs) => $"TRIM({lhs})"; + public override string BuildMoveDataSql( + DbContext context, TableMetadata target, string source, IReadOnlyList insertedColumns, @@ -66,24 +71,37 @@ public override string BuildMoveDataSql( throw new InvalidOperationException("Table has no primary key that can be used for conflict detection."); } - q.AppendLine($"MERGE INTO {target.QuotedTableName} AS TARGET"); + q.AppendLine($"MERGE INTO {target.QuotedTableName} AS {PseudoTableInserted}"); q.Append("USING (SELECT "); q.AppendColumns(insertedColumns); - q.Append($" FROM {source}) AS SOURCE ("); + q.Append($" FROM {source}) AS {PseudoTableExcluded} ("); q.AppendColumns(insertedColumns); q.AppendLine(")"); q.Append("ON "); - q.AppendJoin(" AND ", matchColumns, (b, col) => b.Append($"TARGET.{col} = SOURCE.{col}")); + q.AppendJoin(" AND ", matchColumns, (b, col) => b.Append($"{PseudoTableInserted}.{col} = {PseudoTableExcluded}.{col}")); q.AppendLine(); if (onConflictTyped.Update != null) { var columns = target.GetColumns(false); - q.AppendLine("WHEN MATCHED THEN UPDATE SET "); - q.AppendJoin(", ", GetUpdates(target, columns, onConflictTyped.Update)); + q.AppendLine("WHEN MATCHED "); + + if (onConflictTyped.RawWhere != null || onConflictTyped.Where != null) + { + if (onConflictTyped is { RawWhere: not null, Where: not null }) + { + throw new ArgumentException("Cannot specify both RawWhere and Where in OnConflictOptions."); + } + + q.Append("AND "); + AppendConflictCondition(q, target, context, onConflictTyped); + } + + q.AppendLine("THEN UPDATE SET "); + q.AppendJoin(", ", GetUpdates(context, target, columns, onConflictTyped.Update)); q.AppendLine(); } @@ -92,13 +110,13 @@ public override string BuildMoveDataSql( q.AppendLine(")"); q.Append("VALUES ("); - q.AppendJoin(", ", insertedColumns, (b, col) => b.Append($"SOURCE.{col.QuotedColumName}")); + q.AppendJoin(", ", insertedColumns, (b, col) => b.Append($"{PseudoTableExcluded}.{col.QuotedColumName}")); q.AppendLine(")"); if (returnedColumns.Count != 0) { q.Append("OUTPUT "); - q.AppendJoin(", ", returnedColumns, (b, col) => b.Append($"INSERTED.{col.QuotedColumName} AS {col.QuotedColumName}")); + q.AppendJoin(", ", returnedColumns, (b, col) => b.Append($"{PseudoTableInserted}.{col.QuotedColumName} AS {col.QuotedColumName}")); q.AppendLine(); } } @@ -113,7 +131,7 @@ public override string BuildMoveDataSql( if (returnedColumns.Count != 0) { q.Append("OUTPUT "); - q.AppendJoin(", ", returnedColumns, (b, col) => b.Append($"INSERTED.{col.QuotedColumName} AS {col.QuotedColumName}")); + q.AppendJoin(", ", returnedColumns, (b, col) => b.Append($"{PseudoTableInserted}.{col.QuotedColumName} AS {col.QuotedColumName}")); q.AppendLine(); } @@ -131,12 +149,6 @@ public override string BuildMoveDataSql( q.AppendLine($"SET IDENTITY_INSERT {target.QuotedTableName} OFF;"); } - var result = q.ToString(); - return result; - } - - protected override string GetExcludedColumnName(string columnName) - { - return $"SOURCE.{columnName}"; + return q.ToString(); } } diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteDialectBuilder.cs b/src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteDialectBuilder.cs index 06e3406..0184904 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteDialectBuilder.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteDialectBuilder.cs @@ -15,4 +15,6 @@ public override string CreateTableCopySql(string tempNameName, TableMetadata tab { return $"CREATE TEMP TABLE {tempNameName} AS SELECT * FROM {tableInfo.QuotedTableName} WHERE 0;"; } + + protected override string Trim(string lhs) => $"TRIM({lhs})"; } diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs index 8103304..fa9826a 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/BulkInsertProviderBase.cs @@ -1,5 +1,4 @@ using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore.Storage; @@ -202,6 +201,7 @@ protected virtual async Task AddBulkInsertIdColumn( { var query = SqlDialect.BuildMoveDataSql( + context, tableInfo, tempTableName, tableInfo.GetColumns(options.CopyGeneratedColumns), diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Dialect/SqlDialectBuilder.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Dialect/SqlDialectBuilder.cs index 4fcf0ce..fcfa3aa 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/Dialect/SqlDialectBuilder.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Dialect/SqlDialectBuilder.cs @@ -1,6 +1,8 @@ using System.Linq.Expressions; using System.Text; +using Microsoft.EntityFrameworkCore; + using PhenX.EntityFrameworkCore.BulkInsert.Metadata; using PhenX.EntityFrameworkCore.BulkInsert.Options; @@ -8,17 +10,31 @@ namespace PhenX.EntityFrameworkCore.BulkInsert.Dialect; internal abstract class SqlDialectBuilder { + protected const string PseudoTableInserted = "INSERTED"; + protected const string PseudoTableExcluded = "EXCLUDED"; + protected abstract string OpenDelimiter { get; } protected abstract string CloseDelimiter { get; } protected virtual string ConcatOperator => "||"; + + /// + /// Indicates whether the dialect supports moving rows from temporary table to the final table, in order to + /// theoretically reduce disk space requirements. + /// protected virtual bool SupportsMoveRows => true; + /// + /// Indicates whether the dialect supports INSERT INTO table AS alias. + /// + protected virtual bool SupportsInsertIntoAlias => true; + public abstract string CreateTableCopySql(string tempNameName, TableMetadata tableInfo, IReadOnlyList columns); /// /// Builds the SQL for moving data from one table to another. /// + /// DB context /// Source table /// Target table name /// Columns to be inserted @@ -28,6 +44,7 @@ internal abstract class SqlDialectBuilder /// Entity type /// The SQL query public virtual string BuildMoveDataSql( + DbContext context, TableMetadata target, string source, IReadOnlyList insertedColumns, @@ -47,7 +64,14 @@ public virtual string BuildMoveDataSql( } // INSERT INTO {target} ({columns}) SELECT {columns} FROM {source} WHERE TRUE - q.Append($"INSERT INTO {target.QuotedTableName} ("); + q.Append($"INSERT INTO {target.QuotedTableName}"); + + if (SupportsInsertIntoAlias) + { + q.Append($" AS {PseudoTableInserted}"); + } + + q.AppendLine(" ("); q.AppendColumns(insertedColumns); q.AppendLine(")"); q.Append("SELECT "); @@ -66,13 +90,18 @@ public virtual string BuildMoveDataSql( if (onConflictTyped.Update != null) { q.Append(' '); - AppendOnConflictUpdate(q, GetUpdates(target, insertedColumns, onConflictTyped.Update)); + AppendOnConflictUpdate(q, GetUpdates(context, target, insertedColumns, onConflictTyped.Update)); } - if (onConflictTyped.Condition != null) + if (onConflictTyped.RawWhere != null || onConflictTyped.Where != null) { - q.Append(' '); - AppendConflictCondition(q, onConflictTyped); + if (onConflictTyped is { RawWhere: not null, Where: not null }) + { + throw new ArgumentException("Cannot specify both RawWhere and Where in OnConflictOptions."); + } + + q.Append(" WHERE "); + AppendConflictCondition(q, target, context, onConflictTyped); } } else @@ -91,8 +120,7 @@ public virtual string BuildMoveDataSql( q.AppendLine(";"); - var result = q.ToString(); - return result; + return q.ToString(); } protected virtual void AppendDoNothing(StringBuilder sql, IEnumerable insertedColumns) @@ -106,6 +134,8 @@ protected virtual void AppendOnConflictUpdate(StringBuilder sql, IEnumerable $"BTRIM({lhs})"; + protected virtual void AppendConflictMatch(StringBuilder sql, TableMetadata target, OnConflictOptions conflict) { if (conflict.Match != null) @@ -122,18 +152,32 @@ protected virtual void AppendOnConflictStatement(StringBuilder sql) sql.AppendLine("ON CONFLICT"); } - protected virtual void AppendConflictCondition(StringBuilder sql, OnConflictOptions onConflictTyped) + protected virtual void AppendConflictCondition(StringBuilder sql, TableMetadata target, DbContext context, + OnConflictOptions onConflictTyped) { - sql.AppendLine($"WHERE {onConflictTyped.Condition}"); + var condition = ""; + + if (onConflictTyped.RawWhere != null) + { + condition = onConflictTyped.RawWhere(PseudoTableInserted, PseudoTableExcluded); + } + else if (onConflictTyped.Where != null) + { + condition = ToSqlExpression(context, target, onConflictTyped.Where); + } + + sql.Append(condition).AppendLine(); } /// - /// Get the name of the excluded column for the ON CONFLICT clause. + /// Get the name of the INSERTED column (data already in the table) for the ON CONFLICT clause. /// - protected virtual string GetExcludedColumnName(string columnName) - { - return $"EXCLUDED.{Quote(columnName)}"; - } + protected virtual string GetInsertedColumnName(string columnName) => $"{PseudoTableInserted}.{Quote(columnName)}"; + + /// + /// Get the name of the EXCLUDED column (data conflicting with table) for the ON CONFLICT clause. + /// + protected virtual string GetExcludedColumnName(string columnName) => $"{PseudoTableExcluded}.{Quote(columnName)}"; /// /// Quotes a column name using database-specific delimiters. @@ -178,15 +222,20 @@ protected static string[] GetColumns(TableMetadata table, Expression e.Prop1); /// /// - protected IEnumerable GetUpdates(TableMetadata table, IEnumerable columns, Expression> update) + protected IEnumerable GetUpdates(DbContext context, TableMetadata table, IEnumerable columns, Expression> update) { + if (update is not LambdaExpression lambda) + { + throw new ArgumentException("Update expression must be a lambda expression."); + } + switch (update.Body) { case NewExpression { Members: not null } newExpr: { foreach (var arg in newExpr.Arguments.Zip(newExpr.Members, (expr, member) => (expr, member))) { - yield return $"{table.GetColumnName(arg.member.Name)} = {ToSqlExpression(table, arg.expr)}"; + yield return $"{table.GetColumnName(arg.member.Name)} = {ToSqlExpression(context, table, arg.expr, lambda)}"; } break; @@ -195,14 +244,15 @@ protected IEnumerable GetUpdates(TableMetadata table, IEnumerable()) { - yield return $"{table.GetColumnName(binding.Member.Name)} = {ToSqlExpression(table, binding.Expression)}"; + yield return $"{table.GetColumnName(binding.Member.Name)} = {ToSqlExpression(context, table, binding.Expression, lambda)}"; } break; } case MemberExpression memberExpr: - yield return $"{table.GetColumnName(memberExpr.Member.Name)} = {ToSqlExpression(table, memberExpr)}"; + yield return $"{table.GetColumnName(memberExpr.Member.Name)} = {ToSqlExpression(context, table, memberExpr, lambda)}"; break; + case ParameterExpression parameterExpr when (parameterExpr.Type == typeof(T)): foreach (var property in columns) { @@ -219,17 +269,44 @@ protected IEnumerable GetUpdates(TableMetadata table, IEnumerable /// Converts an expression to an SQL string. /// + /// DB context /// The DbContext /// The expression, with simple operations + /// Current lambda expression /// Entity type /// An SQL statement /// Thrown when an expression could not be translated. - private string ToSqlExpression(TableMetadata table, Expression expr) + private string ToSqlExpression(DbContext context, TableMetadata table, Expression expr, LambdaExpression? lambda = null) { switch (expr) { + case LambdaExpression memberExpr: + return ToSqlExpression(context, table, memberExpr.Body, memberExpr); + case MemberExpression memberExpr: - return GetExcludedColumnName(table.GetColumnName(memberExpr.Member.Name)); + var columnName = table.GetColumnName(memberExpr.Member.Name); + + // If the member expression is a property of the current lambda + if (lambda is { Parameters.Count: > 1 } && memberExpr.Expression is ParameterExpression paramExpr) + { + if (paramExpr.Name == lambda.Parameters[0].Name) + { + return GetInsertedColumnName(columnName); + } + + if (paramExpr.Name == lambda.Parameters[1].Name) + { + return GetExcludedColumnName(columnName); + } + } + + return GetExcludedColumnName(columnName); + + case ConditionalExpression condExpr: + var test = ToSqlExpression(context, table, condExpr.Test, lambda); + var ifTrue = ToSqlExpression(context, table, condExpr.IfTrue, lambda); + var ifFalse = ToSqlExpression(context, table, condExpr.IfFalse, lambda); + return $"CASE WHEN {test} THEN {ifTrue} ELSE {ifFalse} END"; case BinaryExpression binaryExpr: { @@ -251,8 +328,8 @@ private string ToSqlExpression(TableMetadata table, Expression expr) _ => throw new NotSupportedException($"Unsupported operator: {binaryExpr.NodeType}") }; - var lhs = ToSqlExpression(table, binaryExpr.Left); - var rhs = ToSqlExpression(table, binaryExpr.Right); + var lhs = ToSqlExpression(context, table, binaryExpr.Left, lambda); + var rhs = ToSqlExpression(context, table, binaryExpr.Right, lambda); return $"({lhs} {op} {rhs})"; } @@ -279,17 +356,17 @@ private string ToSqlExpression(TableMetadata table, Expression expr) case UnaryExpression unaryExpr: if (unaryExpr.NodeType == ExpressionType.Convert) { - return ToSqlExpression(table, unaryExpr.Operand); + return ToSqlExpression(context, table, unaryExpr.Operand, lambda); } if (unaryExpr.NodeType == ExpressionType.Not) { - return $"NOT ({ToSqlExpression(table, unaryExpr.Operand)})"; + return $"NOT ({ToSqlExpression(context, table, unaryExpr.Operand, lambda)})"; } throw new NotSupportedException($"Unary operator not supported: {unaryExpr.NodeType}"); case MethodCallExpression methodExpr: { - var lhs = methodExpr.Object != null ? ToSqlExpression(table, methodExpr.Object) : null; + var lhs = methodExpr.Object != null ? ToSqlExpression(context, table, methodExpr.Object, lambda) : ""; switch (methodExpr.Method.Name) { @@ -298,13 +375,13 @@ private string ToSqlExpression(TableMetadata table, Expression expr) case "ToUpper": return $"UPPER({lhs})"; case "Trim": - return $"BTRIM({lhs})"; + return Trim(lhs); case "Contains" when methodExpr is { Object: not null, Arguments.Count: 1 }: - return $"{lhs} LIKE '%' || {ToSqlExpression(table, methodExpr.Arguments[0])} || '%'"; + return $"{lhs} LIKE '%' {ConcatOperator} {ToSqlExpression(context, table, methodExpr.Arguments[0], lambda)} {ConcatOperator} '%'"; case "EndsWith" when methodExpr is { Object: not null, Arguments.Count: 1 }: - return $"{lhs} LIKE '%' || {ToSqlExpression(table, methodExpr.Arguments[0])}"; + return $"{lhs} LIKE '%' {ConcatOperator} {ToSqlExpression(context, table, methodExpr.Arguments[0], lambda)}"; case "StartsWith" when methodExpr is { Object: not null, Arguments.Count: 1 }: - return $"{lhs} LIKE {ToSqlExpression(table, methodExpr.Arguments[0])} || '%'"; + return $"{lhs} LIKE {ToSqlExpression(context, table, methodExpr.Arguments[0], lambda)} {ConcatOperator} '%'"; default: throw new NotSupportedException($"Method not supported: {methodExpr.Method.Name}"); } diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/EnumerableDataReader.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/EnumerableDataReader.cs index 6650388..69407be 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/EnumerableDataReader.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/EnumerableDataReader.cs @@ -1,6 +1,5 @@ using System.Data; -using PhenX.EntityFrameworkCore.BulkInsert.Abstractions; using PhenX.EntityFrameworkCore.BulkInsert.Metadata; using PhenX.EntityFrameworkCore.BulkInsert.Options; diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/ColumnMetadata.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/ColumnMetadata.cs index 3d51957..30c9fe6 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/ColumnMetadata.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Metadata/ColumnMetadata.cs @@ -1,7 +1,6 @@ using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore.Metadata; -using PhenX.EntityFrameworkCore.BulkInsert.Abstractions; using PhenX.EntityFrameworkCore.BulkInsert.Dialect; using PhenX.EntityFrameworkCore.BulkInsert.Options; diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Options/OnConflictOptions.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Options/OnConflictOptions.cs index c425b1e..5ef0c4f 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/Options/OnConflictOptions.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Options/OnConflictOptions.cs @@ -8,9 +8,19 @@ namespace PhenX.EntityFrameworkCore.BulkInsert.Options; public abstract class OnConflictOptions { /// - /// Optional condition to apply on conflict. + /// Raw SQL condition delegate to match on conflict. /// - public string? Condition {get; set; } + public delegate string RawWhereDelegate(string insertedTable, string excludedTable); + + /// + /// Optional condition to apply on conflict, in raw SQL. + /// The table names provided as parameters can be used to reference data : + /// + /// insertedTable: refers to the data already in the target table. + /// excludedTable: refers to the new data, being in conflict. + /// + /// + public RawWhereDelegate? RawWhere { get; set; } } /// @@ -21,11 +31,25 @@ public class OnConflictOptions : OnConflictOptions { /// /// Columns to match on conflict. + /// + /// Match = (inserted) => new { inserted.Id } // Match on the Id column + /// /// public Expression>? Match { get; set; } /// /// Updates to apply on conflict. + /// + /// Update = (inserted, excluded) => new { inserted.Quantity = excluded.Quantity } // Update the Quantity column + /// + /// + public Expression>? Update { get; set; } + + /// + /// Condition to apply on conflict, with an expression. + /// + /// Where = (inserted, excluded) => inserted.Price > excluded.Price + /// /// - public Expression>? Update { get; set; } + public Expression>? Where { get; set; } } diff --git a/tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Merge/MergeTestsBase.cs b/tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Merge/MergeTestsBase.cs index 6433f8e..cbfd721 100644 --- a/tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Merge/MergeTestsBase.cs +++ b/tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Merge/MergeTestsBase.cs @@ -52,7 +52,7 @@ public async Task InsertEntities_MultipleTimes(InsertStrategy strategy) var insertedEntities = await _context.InsertWithStrategyAsync(strategy, entities, onConflict: new OnConflictOptions { - Update = e => e, + Update = (inserted, excluded) => inserted, }); // Assert @@ -82,7 +82,7 @@ public async Task InsertEntities_MultipleTimes_WithGuidId(InsertStrategy strateg var insertedEntities = await _context.InsertWithStrategyAsync(strategy, entities, onConflict: new OnConflictOptions { - Update = e => e, + Update = (inserted, excluded) => inserted, }); // Assert @@ -112,7 +112,7 @@ public async Task InsertEntities_MultipleTimes_With_Conflict_On_Id(InsertStrateg o => o.CopyGeneratedColumns = true, onConflict: new OnConflictOptions { - Update = e => e, + Update = (inserted, excluded) => inserted, }); // Assert @@ -139,18 +139,16 @@ public async Task InsertEntities_WithConflict_SingleColumn(InsertStrategy strate }; // Act - var insertedEntities = await _context.InsertWithStrategyAsync(strategy, entities, o => - { - o.MoveRows = true; - }, new OnConflictOptions + var insertedEntities = await _context.InsertWithStrategyAsync(strategy, entities, _ => + {}, new OnConflictOptions { Match = e => new { e.Name, }, - Update = e => new TestEntity + Update = (inserted, excluded) => new TestEntity { - Name = e.Name + " - Conflict", + Name = inserted.Name + " - Conflict", }, }); @@ -179,10 +177,7 @@ public async Task InsertEntities_WithConflict_DoNothing(InsertStrategy strategy) }; // Act - var insertedEntities = await _context.InsertWithStrategyAsync(strategy, entities, o => - { - o.MoveRows = true; - }, new OnConflictOptions + var insertedEntities = await _context.InsertWithStrategyAsync(strategy, entities, _ => {}, new OnConflictOptions { Match = e => new { e.Name } }); @@ -196,7 +191,7 @@ public async Task InsertEntities_WithConflict_DoNothing(InsertStrategy strategy) [SkippableTheory] [InlineData(InsertStrategy.InsertReturn)] [InlineData(InsertStrategy.InsertReturnAsync)] - public async Task InsertEntities_WithConflict_Condition(InsertStrategy strategy) + public async Task InsertEntities_WithConflict_RawCondition(InsertStrategy strategy) { Skip.If(_context.IsProvider(ProviderType.MySql)); @@ -208,23 +203,106 @@ public async Task InsertEntities_WithConflict_Condition(InsertStrategy strategy) var entities = new List { new TestEntity { TestRun = _run, Name = $"{_run}_Entity1", Price = 20 }, - new TestEntity { TestRun = _run, Name = $"{_run}_Entity2", Price = 30 }, + new TestEntity { TestRun = _run, Name = $"{_run}_Entity2", Price = 600 }, + }; + + await _context.ExecuteBulkInsertAsync(entities, onConflict: new OnConflictOptions + { + Match = e => new + { + e.Name, + } + }); + + // Act + var insertedEntities = await _context.InsertWithStrategyAsync(strategy, entities, _ => {}, new OnConflictOptions + { + + Match = e => new { e.Name }, + Update = (inserted, excluded) => new TestEntity + { + Price = excluded.Price + inserted.Price, + }, + RawWhere = (insertedTable, excludedTable) => $"{excludedTable}.some_price != {insertedTable}.some_price", + }); + + // Assert + Assert.Single(insertedEntities); + Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity1" && e.Price == 30); + } + + [SkippableTheory] + [InlineData(InsertStrategy.InsertReturn)] + [InlineData(InsertStrategy.InsertReturnAsync)] + public async Task InsertEntities_WithConflict_ExpressionCondition(InsertStrategy strategy) + { + Skip.If(_context.IsProvider(ProviderType.MySql)); + + // Arrange + _context.TestEntities.Add(new TestEntity { TestRun = _run, Name = $"{_run}_Entity1", Price = 10 }); + _context.SaveChanges(); + _context.ChangeTracker.Clear(); + + var entities = new List + { + new TestEntity { TestRun = _run, Name = $"{_run}_Entity1", Price = 20 }, + new TestEntity { TestRun = _run, Name = $"{_run}_Entity2", Price = 600 }, }; + await _context.ExecuteBulkInsertAsync(entities, onConflict: new OnConflictOptions + { + Match = e => new + { + e.Name, + } + }); + // Act - var insertedEntities = await _context.InsertWithStrategyAsync(strategy, entities, o => + var insertedEntities = await _context.InsertWithStrategyAsync(strategy, entities, _ => {}, new OnConflictOptions { - o.MoveRows = true; - }, new OnConflictOptions + + Match = e => new { e.Name }, + Update = (inserted, excluded) => new TestEntity + { + Price = excluded.Price + inserted.Price, + }, + Where = (inserted, excluded) => excluded.Price != inserted.Price, + }); + + // Assert + Assert.Single(insertedEntities); + Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity1" && e.Price == 30); + } + + [SkippableTheory] + [InlineData(InsertStrategy.InsertReturn)] + [InlineData(InsertStrategy.InsertReturnAsync)] + public async Task InsertEntities_WithConflict_ComplexExpressionCondition(InsertStrategy strategy) + { + Skip.If(_context.IsProvider(ProviderType.MySql)); + + // Arrange + _context.TestEntities.Add(new TestEntity { TestRun = _run, Name = $"{_run}_Entity1", Price = 10 }); + _context.SaveChanges(); + _context.ChangeTracker.Clear(); + + var entities = new List + { + new TestEntity { TestRun = _run, Name = $"{_run}_Entity1", Price = 20 }, + new TestEntity { TestRun = _run, Name = $"{_run}_Entity2", Price = 30 }, + }; + + // Act + var insertedEntities = await _context.InsertWithStrategyAsync(strategy, entities, _ => {}, new OnConflictOptions { Match = e => new { e.Name }, - Update = e => new TestEntity { Price = e.Price }, - Condition = "EXCLUDED.some_price > test_entity.some_price" + Update = (inserted, excluded) => new TestEntity { Price = (excluded.Price > 15 ? 15 : 10) }, + Where = (inserted, excluded) => excluded.Price > inserted.Price && inserted.Name.Trim().Contains("Entity1"), }); // Assert Assert.Equal(2, insertedEntities.Count); - Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity1" && e.Price == 20); + Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity1" && e.Price == 15); Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity2" && e.Price == 30); } @@ -247,13 +325,10 @@ public async Task InsertEntities_WithConflict_MultipleColumns(InsertStrategy str }; // Act - var insertedEntities = await _context.InsertWithStrategyAsync(strategy, entities, o => - { - o.MoveRows = true; - }, new OnConflictOptions + var insertedEntities = await _context.InsertWithStrategyAsync(strategy, entities, _ => {}, new OnConflictOptions { Match = e => new { e.Name }, - Update = e => new TestEntity { Name = e.Name + " - Conflict", Price = 0 } + Update = (inserted, excluded) => new TestEntity { Name = inserted.Name + " - Conflict", Price = 0 } }); // Assert