Skip to content

Commit 23f5b4c

Browse files
Fix conflicts.
1 parent ad0d19c commit 23f5b4c

6 files changed

Lines changed: 166 additions & 20 deletions

File tree

src/PhenX.EntityFrameworkCore.BulkInsert.MySql/MySqlDialectBuilder.cs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1+
using System.Text;
2+
3+
using Microsoft.EntityFrameworkCore;
4+
using Microsoft.EntityFrameworkCore.Metadata;
5+
16
using PhenX.EntityFrameworkCore.BulkInsert.Dialect;
7+
using PhenX.EntityFrameworkCore.BulkInsert.Options;
28

39
namespace PhenX.EntityFrameworkCore.BulkInsert.MySql;
410

@@ -11,4 +17,43 @@ internal class MySqlServerDialectBuilder : SqlDialectBuilder
1117
protected override bool SupportsMoveRows => false;
1218

1319
public override bool SupportsReturning => false;
20+
21+
protected override void AppendConflictCondition<T>(StringBuilder sql, OnConflictOptions<T> onConflictTyped)
22+
{
23+
throw new NotSupportedException("Conflict conditions are not supported in MYSQL");
24+
}
25+
26+
protected override void AppendOnConflictUpdate(StringBuilder sql, IEnumerable<string> updates)
27+
{
28+
sql.AppendLine("UPDATE");
29+
30+
var i = 0;
31+
foreach (var update in updates)
32+
{
33+
if (i > 0)
34+
{
35+
sql.Append(", ");
36+
}
37+
38+
sql.Append(update);
39+
i++;
40+
}
41+
}
42+
43+
protected override void AppendOnConflictStatement(StringBuilder sql)
44+
{
45+
sql.Append("ON DUPLICATE KEY");
46+
}
47+
48+
protected override void AppendDoNothing(StringBuilder sql, IProperty[] insertedProperties)
49+
{
50+
var columnName = insertedProperties[0].GetColumnName();
51+
52+
sql.Append($"UPDATE {Quote(columnName)} = {GetExcludedColumnName(columnName)}");
53+
}
54+
55+
protected override string GetExcludedColumnName(string columnName)
56+
{
57+
return $"VALUES({Quote(columnName)})";
58+
}
1459
}

src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerDialectBuilder.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using System.Linq.Expressions;
1+
using System.Linq.Expressions;
22
using System.Text;
33

44
using Microsoft.EntityFrameworkCore;
@@ -39,7 +39,7 @@ public override string BuildMoveDataSql<T>(DbContext context, string source,
3939
matchColumns.Select(col => $"TARGET.{col} = SOURCE.{col}"));
4040

4141
var updateSet = onConflictTyped.Update != null
42-
? string.Join(", ", GetUpdates(context, onConflictTyped.Update))
42+
? string.Join(", ", GetUpdates(context, insertedProperties, onConflictTyped.Update))
4343
: null;
4444

4545
q.AppendLine($"MERGE INTO {target} AS TARGET");
@@ -82,8 +82,8 @@ public override string BuildMoveDataSql<T>(DbContext context, string source,
8282
return q.ToString();
8383
}
8484

85-
protected override string GetExcludedColumnName<TEntity>(DbContext context, MemberExpression member)
85+
protected override string GetExcludedColumnName(string columnName)
8686
{
87-
return $"SOURCE.{GetColumnName<TEntity>(context, member.Member.Name)}";
87+
return $"SOURCE.{Quote(columnName)}";
8888
}
8989
}

src/PhenX.EntityFrameworkCore.BulkInsert/Dialect/SqlDialectBuilder.cs

Lines changed: 76 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
using Microsoft.EntityFrameworkCore;
55
using Microsoft.EntityFrameworkCore.Metadata;
6+
using Microsoft.EntityFrameworkCore.Metadata.Internal;
67

78
using PhenX.EntityFrameworkCore.BulkInsert.Options;
89

@@ -88,28 +89,32 @@ WHERE TRUE
8889

8990
if (onConflict is OnConflictOptions<T> onConflictTyped)
9091
{
91-
q.AppendLine("ON CONFLICT");
92+
AppendOnConflictStatement(q);
9293

9394
if (onConflictTyped.Update != null)
9495
{
9596
if (onConflictTyped.Match != null)
9697
{
97-
q.AppendLine($"({string.Join(", ", GetColumns(context, onConflictTyped.Match))})");
98+
q.Append(' ');
99+
AppendConflictMatch(q, GetColumns(context, onConflictTyped.Match));
98100
}
99101

100102
if (onConflictTyped.Update != null)
101103
{
102-
q.AppendLine($"DO UPDATE SET {string.Join(", ", GetUpdates(context, onConflictTyped.Update))}");
104+
q.Append(' ');
105+
AppendOnConflictUpdate(q, GetUpdates(context, insertedProperties, onConflictTyped.Update));
103106
}
104107

105108
if (onConflictTyped.Condition != null)
106109
{
107-
q.AppendLine($"WHERE {onConflictTyped.Condition}");
110+
q.Append(' ');
111+
AppendConflictCondition(q, onConflictTyped);
108112
}
109113
}
110114
else
111115
{
112-
q.AppendLine("DO NOTHING");
116+
q.Append(' ');
117+
AppendDoNothing(q, insertedProperties);
113118
}
114119
}
115120

@@ -123,6 +128,57 @@ WHERE TRUE
123128
return q.ToString();
124129
}
125130

131+
protected virtual void AppendDoNothing(StringBuilder sql, IProperty[] insertedProperties)
132+
{
133+
sql.AppendLine("DO NOTHING");
134+
}
135+
136+
protected virtual void AppendOnConflictUpdate(StringBuilder sql, IEnumerable<string> updates)
137+
{
138+
sql.AppendLine("DO UPDATE SET");
139+
140+
var i = 0;
141+
foreach (var update in updates)
142+
{
143+
if (i > 0)
144+
{
145+
sql.Append(", ");
146+
}
147+
148+
sql.Append(update);
149+
i++;
150+
};
151+
}
152+
153+
protected virtual void AppendConflictMatch(StringBuilder sql, IEnumerable<string> columns)
154+
{
155+
sql.AppendLine("(");
156+
157+
var i = 0;
158+
foreach (var column in columns)
159+
{
160+
if (i > 0)
161+
{
162+
sql.Append(", ");
163+
}
164+
165+
sql.Append(column);
166+
i++;
167+
}
168+
169+
sql.AppendLine(")");
170+
}
171+
172+
protected virtual void AppendOnConflictStatement(StringBuilder sql)
173+
{
174+
sql.AppendLine("ON CONFLICT");
175+
}
176+
177+
protected virtual void AppendConflictCondition<T>(StringBuilder sql, OnConflictOptions<T> onConflictTyped)
178+
{
179+
sql.AppendLine($"WHERE {onConflictTyped.Condition}");
180+
}
181+
126182
/// <summary>
127183
/// Builds the SQL for selecting data from one table.
128184
/// </summary>
@@ -153,9 +209,9 @@ WHERE TRUE
153209
/// <summary>
154210
/// Get the name of the excluded column for the ON CONFLICT clause.
155211
/// </summary>
156-
protected virtual string GetExcludedColumnName<TEntity>(DbContext context, MemberExpression member)
212+
protected virtual string GetExcludedColumnName(string columnName)
157213
{
158-
return $"EXCLUDED.{GetColumnName<TEntity>(context, member.Member.Name)}";
214+
return $"EXCLUDED.{Quote(columnName)}";
159215
}
160216

161217
/// <summary>
@@ -201,7 +257,7 @@ protected string[] GetColumns<T>(DbContext context, Expression<Func<T, object>>
201257
/// var updates = GetUpdates(context, e => e.Prop1);
202258
/// </code>
203259
/// </example>
204-
protected IEnumerable<string> GetUpdates<T>(DbContext context, Expression<Func<T, object>> update)
260+
protected IEnumerable<string> GetUpdates<T>(DbContext context, IProperty[] properties, Expression<Func<T, object>> update)
205261
{
206262
switch (update.Body)
207263
{
@@ -226,8 +282,18 @@ protected IEnumerable<string> GetUpdates<T>(DbContext context, Expression<Func<T
226282
case MemberExpression memberExpr:
227283
yield return $"{GetColumnName<T>(context, memberExpr.Member.Name)} = {ToSqlExpression<T>(context, memberExpr)}";
228284
break;
285+
case ParameterExpression parameterExpr when (parameterExpr.Type == typeof(T)):
286+
foreach (var property in properties)
287+
{
288+
var columName = property.GetColumnName();
289+
290+
yield return $"{Quote(columName)} = {GetExcludedColumnName(columName)}";
291+
}
292+
293+
break;
294+
229295
default:
230-
throw new NotSupportedException("Unsupported expression type for update");
296+
throw new NotSupportedException($"Unsupported expression type {update.Body.GetType()} for update");
231297
}
232298
}
233299

@@ -244,7 +310,7 @@ private string ToSqlExpression<TEntity>(DbContext context, Expression expr)
244310
switch (expr)
245311
{
246312
case MemberExpression m:
247-
return GetExcludedColumnName<TEntity>(context, m);
313+
return GetExcludedColumnName(GetColumnName<TEntity>(context, m.Member.Name));
248314

249315
case BinaryExpression b:
250316
var left = ToSqlExpression<TEntity>(context, b.Left);

src/PhenX.EntityFrameworkCore.BulkInsert/Extensions/DbContextExtensions.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using System.Data;
1+
using System.Data;
22
using System.Data.Common;
33

44
using Microsoft.EntityFrameworkCore;

src/PhenX.EntityFrameworkCore.BulkInsert/Options/BulkInsertOptions.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
namespace PhenX.EntityFrameworkCore.BulkInsert.Options;
1+
namespace PhenX.EntityFrameworkCore.BulkInsert.Options;
22

33
/// <summary>
44
/// Bulk insert general options.

tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Basic/BasicTestsBase.cs

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,37 @@ public void InsertsEntitiesSuccessfully_Sync()
6464
Assert.Contains(insertedEntities, e => e.Name == $"{_prefix}_Entity2");
6565
}
6666

67+
[Fact]
68+
public async Task InsertsEntities_MultipleTimes()
69+
{
70+
// Arrange
71+
var entities = new List<TestEntity>
72+
{
73+
new TestEntity { Id = GetId(), Name = $"{_prefix}_Entity1" },
74+
new TestEntity { Id = GetId(), Name = $"{_prefix}_Entity2" }
75+
};
76+
77+
// Act
78+
await DbContainer.DbContext.ExecuteBulkInsertReturnEntitiesAsync(entities);
79+
80+
foreach (var entity in entities)
81+
{
82+
entity.NumericEnumValue = NumericEnum.Second;
83+
}
84+
85+
await DbContainer.DbContext.ExecuteBulkInsertReturnEntitiesAsync(entities,
86+
onConflict: new OnConflictOptions<TestEntity>
87+
{
88+
Update = e => e,
89+
});
90+
91+
// Assert
92+
var insertedEntities = DbContainer.DbContext.TestEntities.ToList();
93+
Assert.Equal(2, insertedEntities.Count);
94+
Assert.Contains(insertedEntities, e => e.NumericEnumValue == NumericEnum.Second);
95+
Assert.Contains(insertedEntities, e => e.NumericEnumValue == NumericEnum.Second);
96+
}
97+
6798
[Fact]
6899
public async Task InsertsEntitiesMoveRowsSuccessfully()
69100
{
@@ -87,9 +118,11 @@ await DbContainer.DbContext.ExecuteBulkInsertReturnEntitiesAsync(entities, o =>
87118
Assert.Contains(insertedEntities, e => e.Name == $"{_prefix}_Entity2");
88119
}
89120

90-
[Fact]
121+
[SkippableFact]
91122
public async Task InsertsEntitiesWithConflict_SingleColumn()
92123
{
124+
Skip.If(DbContainer.DbContext.Database.ProviderName!.Contains("Mysql", StringComparison.InvariantCultureIgnoreCase));
125+
93126
DbContainer.DbContext.TestEntities.Add(new TestEntity { Name = $"{_prefix}_Entity1" });
94127
await DbContainer.DbContext.SaveChangesAsync();
95128
DbContainer.DbContext.ChangeTracker.Clear();
@@ -155,7 +188,7 @@ await DbContainer.DbContext.ExecuteBulkInsertAsync(entities, o =>
155188
[SkippableFact]
156189
public async Task InsertsEntitiesWithConflict_Condition()
157190
{
158-
// Skip.If(DbContainer.DbContext.Database.ProviderName!.Contains("Npgsql", StringComparison.InvariantCultureIgnoreCase));
191+
Skip.If(DbContainer.DbContext.Database.ProviderName!.Contains("Mysql", StringComparison.InvariantCultureIgnoreCase));
159192

160193
DbContainer.DbContext.TestEntities.Add(new TestEntity { Name = $"{_prefix}_Entity1", Price = 10 });
161194
await DbContainer.DbContext.SaveChangesAsync();
@@ -183,9 +216,11 @@ await DbContainer.DbContext.ExecuteBulkInsertAsync(entities, o =>
183216
Assert.Contains(insertedEntities, e => e.Name == $"{_prefix}_Entity2" && e.Price == 30);
184217
}
185218

186-
[Fact]
219+
[SkippableFact]
187220
public async Task InsertsEntitiesWithConflict_MultipleColumns()
188221
{
222+
Skip.If(DbContainer.DbContext.Database.ProviderName!.Contains("Mysql", StringComparison.InvariantCultureIgnoreCase));
223+
189224
DbContainer.DbContext.TestEntities.Add(new TestEntity { Name = $"{_prefix}_Entity1", Price = 10 });
190225
await DbContainer.DbContext.SaveChangesAsync();
191226
DbContainer.DbContext.ChangeTracker.Clear();
@@ -255,7 +290,7 @@ await DbContainer.DbContext.ExecuteBulkInsertAsync(entities, o =>
255290
// Assert
256291
var insertedEntities = DbContainer.DbContext.TestEntities.ToList();
257292
Assert.Equal(count, insertedEntities.Count);
258-
Assert.Contains(insertedEntities, e => e.Name == $"{_prefix}_Entity1");
293+
Assert.Contains(insertedEntities, e => e.Name == "Entity1");
259294
Assert.Contains(insertedEntities, e => e.Name == "Entity" + count);
260295
}
261296

0 commit comments

Comments
 (0)