Skip to content

Commit 8016eed

Browse files
author
fabien.menager
committed
Add support for conflict resolution in bulk insert operations with PG
1 parent 56285ab commit 8016eed

11 files changed

Lines changed: 401 additions & 105 deletions

File tree

src/EntityFrameworkCore.ExecuteInsert.PostgreSql/PostgreSqlBulkInsertProvider.cs

Lines changed: 204 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
using System.Data.Common;
2+
using System.Linq.Expressions;
3+
using System.Text;
4+
5+
using EntityFrameworkCore.ExecuteInsert.OnConflict;
26

37
using Microsoft.EntityFrameworkCore;
8+
using Microsoft.EntityFrameworkCore.Metadata;
49

510
using Npgsql;
611

@@ -15,7 +20,7 @@ public class PostgreSqlBulkInsertProvider : BulkInsertProviderBase
1520
protected override string CreateTableCopySql => "CREATE TEMPORARY TABLE {0} AS TABLE {1} WITH NO DATA;";
1621

1722
//language=sql
18-
protected override string AddTableCopyBulkInsertId => "ALTER TABLE {0} ADD COLUMN _bulk_insert_id SERIAL PRIMARY KEY;";
23+
protected override string AddTableCopyBulkInsertId => $"ALTER TABLE {{0}} ADD COLUMN {BulkInsertId} SERIAL PRIMARY KEY;";
1924

2025
private string GetBinaryImportCommand(DbContext context, Type entityType, string tableName)
2126
{
@@ -24,6 +29,204 @@ private string GetBinaryImportCommand(DbContext context, Type entityType, string
2429
return $"COPY {tableName} ({string.Join(", ", columns)}) FROM STDIN (FORMAT BINARY)";
2530
}
2631

32+
protected override string BuildInsertSelectQuery<T>(string tableName,
33+
string targetTableName,
34+
IProperty[] insertedProperties,
35+
IProperty[] properties,
36+
BulkInsertOptions options, OnConflictOptions? onConflict = null)
37+
{
38+
var insertedColumns = insertedProperties.Select(p => Escape(p.GetColumnName()));
39+
var insertedColumnList = string.Join(", ", insertedColumns);
40+
41+
var returnedColumns = properties.Select(p => $"{Escape(p.GetColumnName())} AS {Escape(p.Name)}");
42+
var columnList = string.Join(", ", returnedColumns);
43+
44+
var q = new StringBuilder();
45+
46+
if (options.MoveRows)
47+
{
48+
q.AppendLine($"""
49+
WITH moved_rows AS (
50+
DELETE FROM {tableName}
51+
RETURNING {insertedColumnList}
52+
)
53+
""");
54+
tableName = "moved_rows";
55+
}
56+
57+
q.AppendLine($"""
58+
INSERT INTO {targetTableName} ({insertedColumnList})
59+
SELECT {insertedColumnList}
60+
FROM {tableName}
61+
""");
62+
63+
if (onConflict is OnConflictOptions<T> onConflictTyped)
64+
{
65+
q.AppendLine("ON CONFLICT");
66+
67+
if (onConflictTyped.Update != null)
68+
{
69+
if (onConflictTyped.Match != null)
70+
{
71+
q.AppendLine($"({string.Join(", ", GetColumns(onConflictTyped.Match).Select(Escape))})");
72+
}
73+
74+
if (onConflictTyped.Update != null)
75+
{
76+
q.AppendLine($"DO UPDATE SET {string.Join(", ", GetUpdates(onConflictTyped.Update))}");
77+
}
78+
79+
if (onConflictTyped.Condition != null)
80+
{
81+
q.AppendLine($"WHERE {onConflictTyped.Condition}");
82+
}
83+
}
84+
else
85+
{
86+
q.AppendLine("DO NOTHING");
87+
}
88+
}
89+
90+
if (columnList.Length != 0)
91+
{
92+
q.AppendLine($"RETURNING {columnList}");
93+
}
94+
95+
q.AppendLine(";");
96+
97+
return q.ToString();
98+
}
99+
100+
private IEnumerable<string> GetUpdates<T>(Expression<Func<T, object>> update)
101+
{
102+
if (update.Body is NewExpression { Members: not null } newExpr)
103+
{
104+
foreach (var arg in newExpr.Arguments.Zip(newExpr.Members, (expr, member) => (expr, member)))
105+
{
106+
yield return $"{Escape(arg.member.Name)} = {ToSqlExpression(arg.expr)}";
107+
}
108+
}
109+
else if (update.Body is MemberInitExpression memberInit)
110+
{
111+
foreach (var binding in memberInit.Bindings.OfType<MemberAssignment>())
112+
{
113+
yield return $"{Escape(binding.Member.Name)} = {ToSqlExpression(binding.Expression)}";
114+
}
115+
}
116+
else if (update.Body is MemberExpression memberExpr)
117+
{
118+
yield return $"{Escape(memberExpr.Member.Name)} = {ToSqlExpression(memberExpr)}";
119+
}
120+
else
121+
{
122+
throw new NotSupportedException("Unsupported expression type for update");
123+
}
124+
}
125+
126+
private string ToSqlExpression(Expression expr)
127+
{
128+
switch (expr)
129+
{
130+
case MemberExpression m:
131+
var prefix = "EXCLUDED";
132+
return $"{prefix}.{Escape(m.Member.Name)}";
133+
134+
case BinaryExpression b:
135+
var left = ToSqlExpression(b.Left);
136+
var right = ToSqlExpression(b.Right);
137+
var op = b.NodeType switch
138+
{
139+
ExpressionType.Add => b.Type == typeof(string) ? "||" : "+",
140+
ExpressionType.Subtract => "-",
141+
ExpressionType.Multiply => "*",
142+
ExpressionType.Divide => "/",
143+
ExpressionType.Modulo => "%",
144+
ExpressionType.AndAlso => "AND",
145+
ExpressionType.OrElse => "OR",
146+
ExpressionType.Equal => "=",
147+
ExpressionType.NotEqual => "<>",
148+
ExpressionType.LessThan => "<",
149+
ExpressionType.LessThanOrEqual => "<=",
150+
ExpressionType.GreaterThan => ">",
151+
ExpressionType.GreaterThanOrEqual => ">=",
152+
_ => throw new NotSupportedException($"Opérateur non supporté: {b.NodeType}")
153+
};
154+
return $"({left} {op} {right})";
155+
156+
case ConstantExpression c:
157+
if (c.Type == typeof(RawSqlValue) && c.Value != null)
158+
{
159+
return ((RawSqlValue)c.Value!).Sql;
160+
}
161+
162+
if (c.Type == typeof(string) ||
163+
c.Type == typeof(Guid))
164+
{
165+
return $"'{c.Value}'";
166+
}
167+
168+
if (c.Type == typeof(bool))
169+
{
170+
return (bool)c.Value! ? "TRUE" : "FALSE";
171+
}
172+
173+
return c.Value?.ToString() ?? "NULL";
174+
175+
case UnaryExpression u:
176+
if (u.NodeType == ExpressionType.Convert)
177+
{
178+
return ToSqlExpression(u.Operand);
179+
}
180+
if (u.NodeType == ExpressionType.Not)
181+
{
182+
return $"NOT ({ToSqlExpression(u.Operand)})";
183+
}
184+
throw new NotSupportedException($"Unary operator not supported: {u.NodeType}");
185+
186+
case MethodCallExpression mce:
187+
// Supporte quelques méthodes courantes (ToLower, ToUpper, Trim, etc.)
188+
var objSql = mce.Object != null ? ToSqlExpression(mce.Object) : null;
189+
var argsSql = mce.Arguments.Select(ToSqlExpression).ToArray();
190+
switch (mce.Method.Name)
191+
{
192+
case "ToLower":
193+
return $"LOWER({objSql})";
194+
case "ToUpper":
195+
return $"UPPER({objSql})";
196+
case "Trim":
197+
return $"BTRIM({objSql})";
198+
case "Contains" when mce is { Object: not null, Arguments.Count: 1 }:
199+
return $"{objSql} LIKE '%' || {argsSql[0]} || '%'";
200+
case "StartsWith" when mce is { Object: not null, Arguments.Count: 1 }:
201+
return $"{objSql} LIKE {argsSql[0]} || '%'";
202+
case "EndsWith" when mce is { Object: not null, Arguments.Count: 1 }:
203+
return $"{objSql} LIKE '%' || {argsSql[0]}";
204+
default:
205+
throw new NotSupportedException($"Method not supported: {mce.Method.Name}");
206+
}
207+
208+
case ParameterExpression p:
209+
return Escape(p.Name ?? "param");
210+
211+
default:
212+
throw new NotSupportedException($"Expression not supported: {expr.NodeType}");
213+
}
214+
}
215+
216+
private string[] GetColumns<T>(Expression<Func<T, object>> columns)
217+
{
218+
return columns.Body switch
219+
{
220+
NewExpression newExpression => newExpression.Arguments.OfType<MemberExpression>()
221+
.Select(m => m.Member.Name)
222+
.ToArray(),
223+
MemberExpression memberExpression => [
224+
memberExpression.Member.Name
225+
],
226+
_ => throw new NotSupportedException("Unsupported expression type")
227+
};
228+
}
229+
27230
protected override async Task BulkImport<T>(DbContext context, DbConnection connection, IEnumerable<T> entities,
28231
string tableName, PropertyAccessor[] properties, CancellationToken ctk) where T : class
29232
{

src/EntityFrameworkCore.ExecuteInsert.SqlServer/SqlServerBulkInsertProvider.cs

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
using System.Data;
22
using System.Data.Common;
3+
using System.Text;
34

45
using EntityFrameworkCore.ExecuteInsert.Helpers;
6+
using EntityFrameworkCore.ExecuteInsert.OnConflict;
57

68
using Microsoft.Data.SqlClient;
79
using Microsoft.EntityFrameworkCore;
@@ -18,12 +20,9 @@ public class SqlServerBulkInsertProvider : BulkInsertProviderBase
1820
protected override string CreateTableCopySql => "SELECT {2} INTO {0} FROM {1} WHERE 1 = 0;";
1921

2022
//language=sql
21-
protected override string AddTableCopyBulkInsertId => "ALTER TABLE {0} ADD _bulk_insert_id INT IDENTITY PRIMARY KEY;";
23+
protected override string AddTableCopyBulkInsertId => $"ALTER TABLE {{0}} ADD {BulkInsertId} INT IDENTITY PRIMARY KEY;";
2224

23-
protected override string GetTempTableName<T>(string tableName) where T : class
24-
{
25-
return $"#_temp_bulk_insert_{tableName}";
26-
}
25+
protected override string GetTempTableName(string tableName) => $"#_temp_bulk_insert_{tableName}";
2726

2827
protected override async Task BulkImport<T>(DbContext context, DbConnection connection, IEnumerable<T> entities, string tableName,
2928
PropertyAccessor[] properties, CancellationToken ctk)
@@ -45,35 +44,48 @@ protected override async Task BulkImport<T>(DbContext context, DbConnection conn
4544
await t.CommitAsync(ctk);
4645
}
4746

48-
protected override string BuildInsertSelectQuery(string tempTableName, string targetTableName,
49-
IProperty[] insertedProperties, IProperty[] properties, bool moveRows)
47+
protected override string BuildInsertSelectQuery<T>(string tableName,
48+
string targetTableName,
49+
IProperty[] insertedProperties,
50+
IProperty[] properties,
51+
BulkInsertOptions options, OnConflictOptions? onConflict = null)
5052
{
51-
var insertedColumns = insertedProperties.Select(p => Escape(p.GetColumnName()));
53+
var insertedColumns = insertedProperties.Select(p => Escape(p.GetColumnName())).ToArray();
5254
var insertedColumnList = string.Join(", ", insertedColumns);
53-
var columnList = string.Join(", ", properties.Select(p => $"INSERTED.{p.GetColumnName()}"));
5455

55-
// if (moveRows)
56+
var returnedColumns = properties.Select(p => $"INSERTED.{p.GetColumnName()} AS [{p.Name}]");
57+
var columnList = string.Join(", ", returnedColumns);
58+
59+
var q = new StringBuilder();
60+
61+
// if (options.MoveRows)
5662
// {
57-
// //language=sql
58-
// return $"""
59-
// WITH moved_rows AS (
60-
// DELETE FROM {tempTableName}
61-
// OUTPUT {insertedColumnList}
62-
// )
63-
// INSERT INTO {targetTableName} ({insertedColumnList})
64-
// SELECT {insertedColumnList}
65-
// FROM moved_rows
66-
// RETURNING {columnList};
67-
// """;
63+
// var deletedColumnList = string.Join(", ", insertedColumns.Select(c => $"DELETED.{c}"));
64+
//
65+
// q.AppendLine($"""
66+
// DELETE FROM {tableName}
67+
// OUTPUT {deletedColumnList}
68+
// """);
6869
// }
6970

70-
//language=sql
71-
return $"""
72-
INSERT INTO {targetTableName} ({insertedColumnList})
73-
OUTPUT {columnList}
74-
SELECT {insertedColumnList}
75-
FROM {tempTableName};
76-
""";
71+
q.AppendLine($"INSERT INTO {targetTableName} ({insertedColumnList})");
72+
73+
if (columnList.Length != 0)
74+
{
75+
q.AppendLine($"OUTPUT {columnList}");
76+
}
77+
78+
q.AppendLine($"""
79+
SELECT {insertedColumnList}
80+
FROM {tableName}
81+
""");
82+
83+
// SQL Server ne supporte pas ON CONFLICT DO NOTHING, mais on garde la signature pour homogénéité
84+
// if (options.OnConflictIgnore) { ... }
85+
86+
q.AppendLine(";");
87+
88+
return q.ToString();
7789
}
7890

7991
private DataTable ConvertToDataTable<T>(PropertyAccessor[] properties)

src/EntityFrameworkCore.ExecuteInsert/Abstractions/BulkInsertExtensions.cs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
using Microsoft.EntityFrameworkCore;
1+
using EntityFrameworkCore.ExecuteInsert.OnConflict;
2+
3+
using Microsoft.EntityFrameworkCore;
24
using Microsoft.EntityFrameworkCore.Infrastructure;
35

46
namespace EntityFrameworkCore.ExecuteInsert.Abstractions;
@@ -9,23 +11,24 @@ public static async Task<List<T>> ExecuteInsertWithIdentityAsync<T>(
911
this DbSet<T> dbSet,
1012
IEnumerable<T> entities,
1113
Action<BulkInsertOptions>? configure = null,
14+
OnConflictOptions? onConflict = null,
1215
CancellationToken ctk = default
1316
) where T : class
1417
{
1518
var provider = InitProvider(dbSet, configure, out var context, out var options);
1619

17-
return await provider.BulkInsertWithIdentityAsync(context, entities, options, ctk);
20+
return await provider.BulkInsertWithIdentityAsync(context, entities, options, onConflict, ctk);
1821
}
1922

20-
public static async Task ExecuteInsertWithIdentityAsync<T>(this DbContext dbContext, IEnumerable<T> entities, Action<BulkInsertOptions>? configure = null, CancellationToken cancellationToken = default) where T : class
23+
public static async Task ExecuteInsertWithIdentityAsync<T>(this DbContext dbContext, IEnumerable<T> entities, Action<BulkInsertOptions>? configure = null, OnConflictOptions? onConflict = null, CancellationToken cancellationToken = default) where T : class
2124
{
2225
var dbSet = dbContext.Set<T>();
2326
if (dbSet == null)
2427
{
2528
throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");
2629
}
2730

28-
await dbSet.ExecuteInsertWithIdentityAsync(entities, configure, cancellationToken);
31+
await dbSet.ExecuteInsertWithIdentityAsync(entities, configure, onConflict, cancellationToken);
2932
}
3033

3134
// public static async Task<List<object>> ExecuteInsertWithPrimaryKeyAsync<T>(
@@ -55,23 +58,24 @@ public static async Task ExecuteInsertAsync<T>(
5558
this DbSet<T> dbSet,
5659
IEnumerable<T> entities,
5760
Action<BulkInsertOptions>? configure = null,
61+
OnConflictOptions? onConflict = null,
5862
CancellationToken ctk = default
5963
) where T : class
6064
{
6165
var provider = InitProvider(dbSet, configure, out var context, out var options);
6266

63-
await provider.BulkInsertWithoutReturnAsync(context, entities, options, ctk);
67+
await provider.BulkInsertWithoutReturnAsync(context, entities, options, onConflict, ctk);
6468
}
6569

66-
public static async Task ExecuteInsertAsync<T>(this DbContext dbContext, IEnumerable<T> entities, Action<BulkInsertOptions>? configure = null, CancellationToken cancellationToken = default) where T : class
70+
public static async Task ExecuteInsertAsync<T>(this DbContext dbContext, IEnumerable<T> entities, Action<BulkInsertOptions>? configure = null, OnConflictOptions? onConflict = null, CancellationToken cancellationToken = default) where T : class
6771
{
6872
var dbSet = dbContext.Set<T>();
6973
if (dbSet == null)
7074
{
7175
throw new InvalidOperationException($"DbSet of type {typeof(T).Name} not found in DbContext.");
7276
}
7377

74-
await dbSet.ExecuteInsertAsync(entities, configure, cancellationToken);
78+
await dbSet.ExecuteInsertAsync(entities, configure, onConflict, cancellationToken);
7579
}
7680

7781
private static DbContext GetDbContext<T>(this DbSet<T> dbSet) where T : class

0 commit comments

Comments
 (0)