Skip to content

Commit 7154c6e

Browse files
author
fabien.menager
committed
Fix upsert for SQL server and add "excluded" to the update lambda
1 parent 960b33c commit 7154c6e

6 files changed

Lines changed: 92 additions & 37 deletions

File tree

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,10 @@ await dbContext.ExecuteBulkInsertAsync(entities, onConflict: new OnConflictOptio
128128
},
129129

130130
// Optional: specify the update action, if not specified, the default action is to do nothing
131-
Update = e => new TestEntity
131+
// Excluded is the row being inserted which is in conflict, and Inserted is the row already in the database.
132+
Update = (inserted, excluded) => new TestEntity
132133
{
133-
Price = e.Price // Update the Price column with the new value
134+
Price = inserted.Price // Update the Price column with the new value
134135
},
135136

136137
// Optional: specify the condition for the update action

src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerDialectBuilder.cs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ public override string CreateTableCopySql(string tempTableName, TableMetadata ta
3636
return q.ToString();
3737
}
3838

39+
protected override string Trim(string lhs) => $"TRIM({lhs})";
40+
3941
public override string BuildMoveDataSql<T>(
4042
DbContext context,
4143
TableMetadata target,
@@ -87,9 +89,15 @@ public override string BuildMoveDataSql<T>(
8789

8890
q.AppendLine("WHEN MATCHED ");
8991

90-
if (!string.IsNullOrEmpty(onConflictTyped.RawWhere))
92+
if (onConflictTyped.RawWhere != null || onConflictTyped.Where != null)
9193
{
92-
q.Append($"AND {onConflictTyped.RawWhere} ");
94+
if (onConflictTyped is { RawWhere: not null, Where: not null })
95+
{
96+
throw new ArgumentException("Cannot specify both RawWhere and Where in OnConflictOptions.");
97+
}
98+
99+
q.Append("AND ");
100+
AppendConflictCondition(q, target, context, onConflictTyped);
93101
}
94102

95103
q.AppendLine("THEN UPDATE SET ");
@@ -141,7 +149,6 @@ public override string BuildMoveDataSql<T>(
141149
q.AppendLine($"SET IDENTITY_INSERT {target.QuotedTableName} OFF;");
142150
}
143151

144-
var result = q.ToString();
145-
return result;
152+
return q.ToString();
146153
}
147154
}

src/PhenX.EntityFrameworkCore.BulkInsert.Sqlite/SqliteDialectBuilder.cs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,5 @@ public override string CreateTableCopySql(string tempNameName, TableMetadata tab
1616
return $"CREATE TEMP TABLE {tempNameName} AS SELECT * FROM {tableInfo.QuotedTableName} WHERE 0;";
1717
}
1818

19-
protected override string Trim(string lhs)
20-
{
21-
return $"TRIM({lhs})";
22-
}
19+
protected override string Trim(string lhs) => $"TRIM({lhs})";
2320
}

src/PhenX.EntityFrameworkCore.BulkInsert/Dialect/SqlDialectBuilder.cs

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ public virtual string BuildMoveDataSql<T>(
8888
throw new ArgumentException("Cannot specify both RawWhere and Where in OnConflictOptions.");
8989
}
9090

91-
q.Append(' ');
91+
q.Append(" WHERE ");
9292
AppendConflictCondition(q, target, context, onConflictTyped);
9393
}
9494
}
@@ -150,7 +150,7 @@ protected virtual void AppendConflictCondition<T>(StringBuilder sql, TableMetada
150150
condition = ToSqlExpression<T>(context, target, onConflictTyped.Where);
151151
}
152152

153-
sql.AppendLine($"WHERE {condition}");
153+
sql.Append(condition).AppendLine();
154154
}
155155

156156
/// <summary>
@@ -206,15 +206,20 @@ protected static string[] GetColumns<T>(TableMetadata table, Expression<Func<T,
206206
/// var updates = GetUpdates(context, e => e.Prop1);
207207
/// </code>
208208
/// </example>
209-
protected IEnumerable<string> GetUpdates<T>(DbContext context, TableMetadata table, IEnumerable<ColumnMetadata> columns, Expression<Func<T, object>> update)
209+
protected IEnumerable<string> GetUpdates<T>(DbContext context, TableMetadata table, IEnumerable<ColumnMetadata> columns, Expression<Func<T, T, object>> update)
210210
{
211+
if (update is not LambdaExpression lambda)
212+
{
213+
throw new ArgumentException("Update expression must be a lambda expression.");
214+
}
215+
211216
switch (update.Body)
212217
{
213218
case NewExpression { Members: not null } newExpr:
214219
{
215220
foreach (var arg in newExpr.Arguments.Zip(newExpr.Members, (expr, member) => (expr, member)))
216221
{
217-
yield return $"{table.GetColumnName(arg.member.Name)} = {ToSqlExpression<T>(context, table, arg.expr)}";
222+
yield return $"{table.GetColumnName(arg.member.Name)} = {ToSqlExpression<T>(context, table, arg.expr, lambda)}";
218223
}
219224

220225
break;
@@ -223,14 +228,15 @@ protected IEnumerable<string> GetUpdates<T>(DbContext context, TableMetadata tab
223228
{
224229
foreach (var binding in memberInit.Bindings.OfType<MemberAssignment>())
225230
{
226-
yield return $"{table.GetColumnName(binding.Member.Name)} = {ToSqlExpression<T>(context, table, binding.Expression)}";
231+
yield return $"{table.GetColumnName(binding.Member.Name)} = {ToSqlExpression<T>(context, table, binding.Expression, lambda)}";
227232
}
228233

229234
break;
230235
}
231236
case MemberExpression memberExpr:
232-
yield return $"{table.GetColumnName(memberExpr.Member.Name)} = {ToSqlExpression<T>(context, table, memberExpr)}";
237+
yield return $"{table.GetColumnName(memberExpr.Member.Name)} = {ToSqlExpression<T>(context, table, memberExpr, lambda)}";
233238
break;
239+
234240
case ParameterExpression parameterExpr when (parameterExpr.Type == typeof(T)):
235241
foreach (var property in columns)
236242
{
@@ -355,11 +361,11 @@ private string ToSqlExpression<TEntity>(DbContext context, TableMetadata table,
355361
case "Trim":
356362
return Trim(lhs);
357363
case "Contains" when methodExpr is { Object: not null, Arguments.Count: 1 }:
358-
return $"{lhs} LIKE '%' || {ToSqlExpression<TEntity>(context, table, methodExpr.Arguments[0], lambda)} || '%'";
364+
return $"{lhs} LIKE '%' {ConcatOperator} {ToSqlExpression<TEntity>(context, table, methodExpr.Arguments[0], lambda)} {ConcatOperator} '%'";
359365
case "EndsWith" when methodExpr is { Object: not null, Arguments.Count: 1 }:
360-
return $"{lhs} LIKE '%' || {ToSqlExpression<TEntity>(context, table, methodExpr.Arguments[0], lambda)}";
366+
return $"{lhs} LIKE '%' {ConcatOperator} {ToSqlExpression<TEntity>(context, table, methodExpr.Arguments[0], lambda)}";
361367
case "StartsWith" when methodExpr is { Object: not null, Arguments.Count: 1 }:
362-
return $"{lhs} LIKE {ToSqlExpression<TEntity>(context, table, methodExpr.Arguments[0], lambda)} || '%'";
368+
return $"{lhs} LIKE {ToSqlExpression<TEntity>(context, table, methodExpr.Arguments[0], lambda)} {ConcatOperator} '%'";
363369
default:
364370
throw new NotSupportedException($"Method not supported: {methodExpr.Method.Name}");
365371
}

src/PhenX.EntityFrameworkCore.BulkInsert/Options/OnConflictOptions.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@ public class OnConflictOptions<T> : OnConflictOptions
3333
/// <summary>
3434
/// Updates to apply on conflict.
3535
/// <example><code>
36-
/// Update = (inserted) => new { inserted.Quantity + 1 } // Increment the quantity by 1 on conflict
36+
/// Update = (inserted, excluded) => new { inserted.Quantity = excluded.Quantity } // Update the Quantity column
3737
/// </code></example>
3838
/// </summary>
39-
public Expression<Func<T, object>>? Update { get; set; }
39+
public Expression<Func<T, T, object>>? Update { get; set; }
4040

4141
/// <summary>
4242
/// Condition to apply on conflict, with an expression.

tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Merge/MergeTestsBase.cs

Lines changed: 60 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public async Task InsertEntities_MultipleTimes(InsertStrategy strategy)
5252
var insertedEntities = await _context.InsertWithStrategyAsync(strategy, entities,
5353
onConflict: new OnConflictOptions<TestEntity>
5454
{
55-
Update = e => e,
55+
Update = (inserted, excluded) => inserted,
5656
});
5757

5858
// Assert
@@ -82,7 +82,7 @@ public async Task InsertEntities_MultipleTimes_WithGuidId(InsertStrategy strateg
8282
var insertedEntities = await _context.InsertWithStrategyAsync(strategy, entities,
8383
onConflict: new OnConflictOptions<TestEntityWithGuidId>
8484
{
85-
Update = e => e,
85+
Update = (inserted, excluded) => inserted,
8686
});
8787

8888
// Assert
@@ -112,7 +112,7 @@ public async Task InsertEntities_MultipleTimes_With_Conflict_On_Id(InsertStrateg
112112
o => o.CopyGeneratedColumns = true,
113113
onConflict: new OnConflictOptions<TestEntity>
114114
{
115-
Update = e => e,
115+
Update = (inserted, excluded) => inserted,
116116
});
117117

118118
// Assert
@@ -146,9 +146,9 @@ public async Task InsertEntities_WithConflict_SingleColumn(InsertStrategy strate
146146
{
147147
e.Name,
148148
},
149-
Update = e => new TestEntity
149+
Update = (inserted, excluded) => new TestEntity
150150
{
151-
Name = e.Name + " - Conflict",
151+
Name = inserted.Name + " - Conflict",
152152
},
153153
});
154154

@@ -203,15 +203,14 @@ public async Task InsertEntities_WithConflict_RawCondition(InsertStrategy strate
203203
var entities = new List<TestEntity>
204204
{
205205
new TestEntity { TestRun = _run, Name = $"{_run}_Entity1", Price = 20 },
206-
new TestEntity { TestRun = _run, Name = $"{_run}_Entity2", Price = 30 },
206+
new TestEntity { TestRun = _run, Name = $"{_run}_Entity2", Price = 600 },
207207
};
208208

209209
await _context.ExecuteBulkInsertAsync(entities, onConflict: new OnConflictOptions<TestEntity>
210210
{
211211
Match = e => new
212212
{
213213
e.Name,
214-
// ...other columns to match on
215214
}
216215
});
217216

@@ -220,14 +219,16 @@ public async Task InsertEntities_WithConflict_RawCondition(InsertStrategy strate
220219
{
221220

222221
Match = e => new { e.Name },
223-
Update = e => new TestEntity { Price = e.Price },
224-
RawWhere = "EXCLUDED.some_price > INSERTED.some_price"
222+
Update = (inserted, excluded) => new TestEntity
223+
{
224+
Price = excluded.Price + inserted.Price,
225+
},
226+
RawWhere = "EXCLUDED.some_price != INSERTED.some_price"
225227
});
226228

227229
// Assert
228-
Assert.Equal(2, insertedEntities.Count);
229-
Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity1" && e.Price == 20);
230-
Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity2" && e.Price == 30);
230+
Assert.Single(insertedEntities);
231+
Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity1" && e.Price == 30);
231232
}
232233

233234
[SkippableTheory]
@@ -242,6 +243,49 @@ public async Task InsertEntities_WithConflict_ExpressionCondition(InsertStrategy
242243
_context.SaveChanges();
243244
_context.ChangeTracker.Clear();
244245

246+
var entities = new List<TestEntity>
247+
{
248+
new TestEntity { TestRun = _run, Name = $"{_run}_Entity1", Price = 20 },
249+
new TestEntity { TestRun = _run, Name = $"{_run}_Entity2", Price = 600 },
250+
};
251+
252+
await _context.ExecuteBulkInsertAsync(entities, onConflict: new OnConflictOptions<TestEntity>
253+
{
254+
Match = e => new
255+
{
256+
e.Name,
257+
}
258+
});
259+
260+
// Act
261+
var insertedEntities = await _context.InsertWithStrategyAsync(strategy, entities, _ => {}, new OnConflictOptions<TestEntity>
262+
{
263+
264+
Match = e => new { e.Name },
265+
Update = (inserted, excluded) => new TestEntity
266+
{
267+
Price = excluded.Price + inserted.Price,
268+
},
269+
Where = (inserted, excluded) => excluded.Price != inserted.Price,
270+
});
271+
272+
// Assert
273+
Assert.Single(insertedEntities);
274+
Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity1" && e.Price == 30);
275+
}
276+
277+
[SkippableTheory]
278+
[InlineData(InsertStrategy.InsertReturn)]
279+
[InlineData(InsertStrategy.InsertReturnAsync)]
280+
public async Task InsertEntities_WithConflict_ComplexExpressionCondition(InsertStrategy strategy)
281+
{
282+
Skip.If(_context.IsProvider(ProviderType.MySql));
283+
284+
// Arrange
285+
_context.TestEntities.Add(new TestEntity { TestRun = _run, Name = $"{_run}_Entity1", Price = 10 });
286+
_context.SaveChanges();
287+
_context.ChangeTracker.Clear();
288+
245289
var entities = new List<TestEntity>
246290
{
247291
new TestEntity { TestRun = _run, Name = $"{_run}_Entity1", Price = 20 },
@@ -252,13 +296,13 @@ public async Task InsertEntities_WithConflict_ExpressionCondition(InsertStrategy
252296
var insertedEntities = await _context.InsertWithStrategyAsync(strategy, entities, _ => {}, new OnConflictOptions<TestEntity>
253297
{
254298
Match = e => new { e.Name },
255-
Update = e => new TestEntity { Price = e.Price },
256-
Where = (inserted, excluded) => excluded.Price > inserted.Price && excluded.Price > 15 ? inserted.Name.Trim().Contains("Entity1") : inserted.Name.Trim().Contains("Entity2"),
299+
Update = (inserted, excluded) => new TestEntity { Price = (excluded.Price > 15 ? 15 : 10) },
300+
Where = (inserted, excluded) => excluded.Price > inserted.Price && inserted.Name.Trim().Contains("Entity1"),
257301
});
258302

259303
// Assert
260304
Assert.Equal(2, insertedEntities.Count);
261-
Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity1" && e.Price == 20);
305+
Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity1" && e.Price == 15);
262306
Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity2" && e.Price == 30);
263307
}
264308

@@ -284,7 +328,7 @@ public async Task InsertEntities_WithConflict_MultipleColumns(InsertStrategy str
284328
var insertedEntities = await _context.InsertWithStrategyAsync(strategy, entities, _ => {}, new OnConflictOptions<TestEntity>
285329
{
286330
Match = e => new { e.Name },
287-
Update = e => new TestEntity { Name = e.Name + " - Conflict", Price = 0 }
331+
Update = (inserted, excluded) => new TestEntity { Name = inserted.Name + " - Conflict", Price = 0 }
288332
});
289333

290334
// Assert

0 commit comments

Comments
 (0)