Skip to content

Commit 53e2ce3

Browse files
author
fabien.menager
committed
Implement MERGE statement support for SQL Server bulk insert operations with conflict resolution
1 parent 8016eed commit 53e2ce3

3 files changed

Lines changed: 189 additions & 155 deletions

File tree

src/EntityFrameworkCore.ExecuteInsert.PostgreSql/PostgreSqlBulkInsertProvider.cs

Lines changed: 0 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -97,136 +97,6 @@ DELETE FROM {tableName}
9797
return q.ToString();
9898
}
9999

100-
private IEnumerable<string> GetUpdates<T>(Expression<Func<T, object>> update)
101-
{
102-
if (update.Body is NewExpression { Members: not null } newExpr)
103-
{
104-
foreach (var arg in newExpr.Arguments.Zip(newExpr.Members, (expr, member) => (expr, member)))
105-
{
106-
yield return $"{Escape(arg.member.Name)} = {ToSqlExpression(arg.expr)}";
107-
}
108-
}
109-
else if (update.Body is MemberInitExpression memberInit)
110-
{
111-
foreach (var binding in memberInit.Bindings.OfType<MemberAssignment>())
112-
{
113-
yield return $"{Escape(binding.Member.Name)} = {ToSqlExpression(binding.Expression)}";
114-
}
115-
}
116-
else if (update.Body is MemberExpression memberExpr)
117-
{
118-
yield return $"{Escape(memberExpr.Member.Name)} = {ToSqlExpression(memberExpr)}";
119-
}
120-
else
121-
{
122-
throw new NotSupportedException("Unsupported expression type for update");
123-
}
124-
}
125-
126-
private string ToSqlExpression(Expression expr)
127-
{
128-
switch (expr)
129-
{
130-
case MemberExpression m:
131-
var prefix = "EXCLUDED";
132-
return $"{prefix}.{Escape(m.Member.Name)}";
133-
134-
case BinaryExpression b:
135-
var left = ToSqlExpression(b.Left);
136-
var right = ToSqlExpression(b.Right);
137-
var op = b.NodeType switch
138-
{
139-
ExpressionType.Add => b.Type == typeof(string) ? "||" : "+",
140-
ExpressionType.Subtract => "-",
141-
ExpressionType.Multiply => "*",
142-
ExpressionType.Divide => "/",
143-
ExpressionType.Modulo => "%",
144-
ExpressionType.AndAlso => "AND",
145-
ExpressionType.OrElse => "OR",
146-
ExpressionType.Equal => "=",
147-
ExpressionType.NotEqual => "<>",
148-
ExpressionType.LessThan => "<",
149-
ExpressionType.LessThanOrEqual => "<=",
150-
ExpressionType.GreaterThan => ">",
151-
ExpressionType.GreaterThanOrEqual => ">=",
152-
_ => throw new NotSupportedException($"Opérateur non supporté: {b.NodeType}")
153-
};
154-
return $"({left} {op} {right})";
155-
156-
case ConstantExpression c:
157-
if (c.Type == typeof(RawSqlValue) && c.Value != null)
158-
{
159-
return ((RawSqlValue)c.Value!).Sql;
160-
}
161-
162-
if (c.Type == typeof(string) ||
163-
c.Type == typeof(Guid))
164-
{
165-
return $"'{c.Value}'";
166-
}
167-
168-
if (c.Type == typeof(bool))
169-
{
170-
return (bool)c.Value! ? "TRUE" : "FALSE";
171-
}
172-
173-
return c.Value?.ToString() ?? "NULL";
174-
175-
case UnaryExpression u:
176-
if (u.NodeType == ExpressionType.Convert)
177-
{
178-
return ToSqlExpression(u.Operand);
179-
}
180-
if (u.NodeType == ExpressionType.Not)
181-
{
182-
return $"NOT ({ToSqlExpression(u.Operand)})";
183-
}
184-
throw new NotSupportedException($"Unary operator not supported: {u.NodeType}");
185-
186-
case MethodCallExpression mce:
187-
// Supporte quelques méthodes courantes (ToLower, ToUpper, Trim, etc.)
188-
var objSql = mce.Object != null ? ToSqlExpression(mce.Object) : null;
189-
var argsSql = mce.Arguments.Select(ToSqlExpression).ToArray();
190-
switch (mce.Method.Name)
191-
{
192-
case "ToLower":
193-
return $"LOWER({objSql})";
194-
case "ToUpper":
195-
return $"UPPER({objSql})";
196-
case "Trim":
197-
return $"BTRIM({objSql})";
198-
case "Contains" when mce is { Object: not null, Arguments.Count: 1 }:
199-
return $"{objSql} LIKE '%' || {argsSql[0]} || '%'";
200-
case "StartsWith" when mce is { Object: not null, Arguments.Count: 1 }:
201-
return $"{objSql} LIKE {argsSql[0]} || '%'";
202-
case "EndsWith" when mce is { Object: not null, Arguments.Count: 1 }:
203-
return $"{objSql} LIKE '%' || {argsSql[0]}";
204-
default:
205-
throw new NotSupportedException($"Method not supported: {mce.Method.Name}");
206-
}
207-
208-
case ParameterExpression p:
209-
return Escape(p.Name ?? "param");
210-
211-
default:
212-
throw new NotSupportedException($"Expression not supported: {expr.NodeType}");
213-
}
214-
}
215-
216-
private string[] GetColumns<T>(Expression<Func<T, object>> columns)
217-
{
218-
return columns.Body switch
219-
{
220-
NewExpression newExpression => newExpression.Arguments.OfType<MemberExpression>()
221-
.Select(m => m.Member.Name)
222-
.ToArray(),
223-
MemberExpression memberExpression => [
224-
memberExpression.Member.Name
225-
],
226-
_ => throw new NotSupportedException("Unsupported expression type")
227-
};
228-
}
229-
230100
protected override async Task BulkImport<T>(DbContext context, DbConnection connection, IEnumerable<T> entities,
231101
string tableName, PropertyAccessor[] properties, CancellationToken ctk) where T : class
232102
{

src/EntityFrameworkCore.ExecuteInsert.SqlServer/SqlServerBulkInsertProvider.cs

Lines changed: 47 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
using System.Data;
21
using System.Data.Common;
2+
using System.Linq.Expressions;
33
using System.Text;
44

55
using EntityFrameworkCore.ExecuteInsert.Helpers;
@@ -68,40 +68,62 @@ protected override string BuildInsertSelectQuery<T>(string tableName,
6868
// """);
6969
// }
7070

71-
q.AppendLine($"INSERT INTO {targetTableName} ({insertedColumnList})");
72-
73-
if (columnList.Length != 0)
71+
// Merge handling
72+
if (onConflict is OnConflictOptions<T> onConflictTyped && onConflictTyped.Match != null)
7473
{
75-
q.AppendLine($"OUTPUT {columnList}");
74+
var matchColumns = GetColumns(onConflictTyped.Match);
75+
var matchOn = string.Join(" AND ",
76+
matchColumns.Select(col => $"TARGET.{Escape(col)} = SOURCE.{Escape(col)}"));
77+
78+
var updateSet = onConflictTyped.Update != null
79+
? string.Join(", ", GetUpdates(onConflictTyped.Update))
80+
: null;
81+
82+
q.AppendLine($"MERGE INTO {targetTableName} AS TARGET");
83+
q.AppendLine(
84+
$"USING (SELECT {string.Join(", ", insertedColumns)} FROM {tableName}) AS SOURCE ({insertedColumnList})");
85+
q.AppendLine($"ON {matchOn}");
86+
87+
if (updateSet != null)
88+
{
89+
q.AppendLine($"WHEN MATCHED THEN UPDATE SET {updateSet}");
90+
}
91+
92+
q.AppendLine(
93+
$"WHEN NOT MATCHED THEN INSERT ({insertedColumnList}) VALUES ({string.Join(", ", insertedColumns.Select(c => $"SOURCE.{c}"))})");
94+
95+
if (columnList.Length != 0)
96+
{
97+
q.AppendLine($"OUTPUT {columnList}");
98+
}
7699
}
77100

78-
q.AppendLine($"""
79-
SELECT {insertedColumnList}
80-
FROM {tableName}
81-
""");
101+
// No conflict handling
102+
else
103+
{
104+
q.AppendLine($"INSERT INTO {targetTableName} ({insertedColumnList})");
105+
106+
if (columnList.Length != 0)
107+
{
108+
q.AppendLine($"OUTPUT {columnList}");
109+
}
82110

83-
// SQL Server ne supporte pas ON CONFLICT DO NOTHING, mais on garde la signature pour homogénéité
84-
// if (options.OnConflictIgnore) { ... }
111+
q.AppendLine($"""
112+
SELECT {insertedColumnList}
113+
FROM {tableName}
114+
""");
115+
}
85116

86117
q.AppendLine(";");
87118

88119
return q.ToString();
89120
}
90121

91-
private DataTable ConvertToDataTable<T>(PropertyAccessor[] properties)
122+
protected override string GetExcludedColumnName(MemberExpression member)
92123
{
93-
var dataTable = new DataTable(typeof(T).Name);
94-
95-
if (properties.Length == 0)
96-
{
97-
throw new InvalidOperationException($"No properties found for type {typeof(T).Name}");
98-
}
99-
100-
foreach (var prop in properties)
101-
{
102-
dataTable.Columns.Add(prop.Name, prop.ProviderClrType);
103-
}
104-
105-
return dataTable;
124+
var prefix = "SOURCE";
125+
return $"{prefix}.{Escape(member.Member.Name)}";
106126
}
127+
128+
protected override string ConcatOperator => "+";
107129
}

src/EntityFrameworkCore.ExecuteInsert/BulkInsertProviderBase.cs

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System.Collections;
22
using System.Data;
33
using System.Data.Common;
4+
using System.Linq.Expressions;
45
using System.Reflection;
56

67
using EntityFrameworkCore.ExecuteInsert.Abstractions;
@@ -367,4 +368,145 @@ protected string GetEscapedTableName(DbContext context, Type entityType)
367368
{
368369
return DatabaseHelper.GetEscapedTableName(context, entityType, OpenDelimiter, CloseDelimiter);
369370
}
371+
372+
protected IEnumerable<string> GetUpdates<T>(Expression<Func<T, object>> update)
373+
{
374+
switch (update.Body)
375+
{
376+
case NewExpression { Members: not null } newExpr:
377+
{
378+
foreach (var arg in newExpr.Arguments.Zip(newExpr.Members, (expr, member) => (expr, member)))
379+
{
380+
yield return $"{Escape(arg.member.Name)} = {ToSqlExpression(arg.expr)}";
381+
}
382+
383+
break;
384+
}
385+
case MemberInitExpression memberInit:
386+
{
387+
foreach (var binding in memberInit.Bindings.OfType<MemberAssignment>())
388+
{
389+
yield return $"{Escape(binding.Member.Name)} = {ToSqlExpression(binding.Expression)}";
390+
}
391+
392+
break;
393+
}
394+
case MemberExpression memberExpr:
395+
yield return $"{Escape(memberExpr.Member.Name)} = {ToSqlExpression(memberExpr)}";
396+
break;
397+
default:
398+
throw new NotSupportedException("Unsupported expression type for update");
399+
}
400+
}
401+
402+
protected virtual string ConcatOperator => "||";
403+
404+
protected virtual string GetExcludedColumnName(MemberExpression member)
405+
{
406+
var prefix = "EXCLUDED";
407+
return $"{prefix}.{Escape(member.Member.Name)}";
408+
}
409+
410+
private string ToSqlExpression(Expression expr)
411+
{
412+
switch (expr)
413+
{
414+
case MemberExpression m:
415+
return GetExcludedColumnName(m);
416+
417+
case BinaryExpression b:
418+
var left = ToSqlExpression(b.Left);
419+
var right = ToSqlExpression(b.Right);
420+
var op = b.NodeType switch
421+
{
422+
ExpressionType.Add => b.Type == typeof(string) ? ConcatOperator : "+",
423+
ExpressionType.Subtract => "-",
424+
ExpressionType.Multiply => "*",
425+
ExpressionType.Divide => "/",
426+
ExpressionType.Modulo => "%",
427+
ExpressionType.AndAlso => "AND",
428+
ExpressionType.OrElse => "OR",
429+
ExpressionType.Equal => "=",
430+
ExpressionType.NotEqual => "<>",
431+
ExpressionType.LessThan => "<",
432+
ExpressionType.LessThanOrEqual => "<=",
433+
ExpressionType.GreaterThan => ">",
434+
ExpressionType.GreaterThanOrEqual => ">=",
435+
_ => throw new NotSupportedException($"Unsupported operator: {b.NodeType}")
436+
};
437+
return $"({left} {op} {right})";
438+
439+
case ConstantExpression c:
440+
if (c.Type == typeof(RawSqlValue) && c.Value != null)
441+
{
442+
return ((RawSqlValue)c.Value!).Sql;
443+
}
444+
445+
if (c.Type == typeof(string) ||
446+
c.Type == typeof(Guid))
447+
{
448+
return $"'{c.Value}'";
449+
}
450+
451+
if (c.Type == typeof(bool))
452+
{
453+
return (bool)c.Value! ? "TRUE" : "FALSE";
454+
}
455+
456+
return c.Value?.ToString() ?? "NULL";
457+
458+
case UnaryExpression u:
459+
if (u.NodeType == ExpressionType.Convert)
460+
{
461+
return ToSqlExpression(u.Operand);
462+
}
463+
if (u.NodeType == ExpressionType.Not)
464+
{
465+
return $"NOT ({ToSqlExpression(u.Operand)})";
466+
}
467+
throw new NotSupportedException($"Unary operator not supported: {u.NodeType}");
468+
469+
case MethodCallExpression mce:
470+
// Supporte quelques méthodes courantes (ToLower, ToUpper, Trim, etc.)
471+
var objSql = mce.Object != null ? ToSqlExpression(mce.Object) : null;
472+
var argsSql = mce.Arguments.Select(ToSqlExpression).ToArray();
473+
switch (mce.Method.Name)
474+
{
475+
case "ToLower":
476+
return $"LOWER({objSql})";
477+
case "ToUpper":
478+
return $"UPPER({objSql})";
479+
case "Trim":
480+
return $"BTRIM({objSql})";
481+
case "Contains" when mce is { Object: not null, Arguments.Count: 1 }:
482+
return $"{objSql} LIKE '%' || {argsSql[0]} || '%'";
483+
case "StartsWith" when mce is { Object: not null, Arguments.Count: 1 }:
484+
return $"{objSql} LIKE {argsSql[0]} || '%'";
485+
case "EndsWith" when mce is { Object: not null, Arguments.Count: 1 }:
486+
return $"{objSql} LIKE '%' || {argsSql[0]}";
487+
default:
488+
throw new NotSupportedException($"Method not supported: {mce.Method.Name}");
489+
}
490+
491+
case ParameterExpression p:
492+
return Escape(p.Name ?? "param");
493+
494+
default:
495+
throw new NotSupportedException($"Expression not supported: {expr.NodeType}");
496+
}
497+
}
498+
499+
protected string[] GetColumns<T>(Expression<Func<T, object>> columns)
500+
{
501+
return columns.Body switch
502+
{
503+
NewExpression newExpression => newExpression.Arguments.OfType<MemberExpression>()
504+
.Select(m => m.Member.Name)
505+
.ToArray(),
506+
MemberExpression memberExpression => [
507+
memberExpression.Member.Name
508+
],
509+
_ => throw new NotSupportedException("Unsupported expression type")
510+
};
511+
}
370512
}

0 commit comments

Comments
 (0)