diff --git a/src/PhenX.EntityFrameworkCore.BulkInsert/Dialect/SqlDialectBuilder.cs b/src/PhenX.EntityFrameworkCore.BulkInsert/Dialect/SqlDialectBuilder.cs index 4dd19cb..cde3d67 100644 --- a/src/PhenX.EntityFrameworkCore.BulkInsert/Dialect/SqlDialectBuilder.cs +++ b/src/PhenX.EntityFrameworkCore.BulkInsert/Dialect/SqlDialectBuilder.cs @@ -253,9 +253,9 @@ protected IEnumerable GetUpdates(DbContext context, TableMetadata tab } case MemberInitExpression memberInit: { - foreach (var binding in memberInit.Bindings.OfType()) + foreach (var updateSql in GetUpdatesFromMemberInit(context, table, memberInit, lambda)) { - yield return $"{table.GetQuotedColumnName(binding.Member.Name)} = {ToSqlExpression(context, table, binding.Expression, lambda)}"; + yield return updateSql; } break; @@ -297,15 +297,20 @@ private string ToSqlExpression(DbContext context, TableMetadata table, case MemberExpression memberExpr: var columnName = table.GetColumnName(memberExpr.Member.Name); + // Traverse up the expression chain to find the root parameter + // This handles both simple properties (e.g., excluded.Name) and + // complex properties (e.g., excluded.ComplexObject.Property) + var rootParam = GetRootParameter(memberExpr); + // If the member expression is a property of the current lambda - if (lambda is { Parameters.Count: > 1 } && memberExpr.Expression is ParameterExpression paramExpr) + if (lambda is { Parameters.Count: > 1 } && rootParam != null) { - if (paramExpr.Name == lambda.Parameters[0].Name) + if (rootParam.Name == lambda.Parameters[0].Name) { return GetInsertedColumnName(columnName); } - if (paramExpr.Name == lambda.Parameters[1].Name) + if (rootParam.Name == lambda.Parameters[1].Name) { return GetExcludedColumnName(columnName); } @@ -405,4 +410,64 @@ private string ToSqlExpression(DbContext context, TableMetadata table, throw new NotSupportedException($"Expression not supported: {expr.NodeType}"); } } + + /// + /// Extracts update SQL statements from a MemberInitExpression, handling both simple properties + /// and nested complex property initializations recursively. + /// + /// DB context + /// Table metadata + /// The member initialization expression + /// Current lambda expression + /// Entity type + /// SQL update statements for each property assignment + private IEnumerable GetUpdatesFromMemberInit(DbContext context, TableMetadata table, MemberInitExpression memberInit, LambdaExpression lambda) + { + foreach (var binding in memberInit.Bindings.OfType()) + { + // Check if the binding expression is a nested MemberInitExpression (complex property assignment) + if (binding.Expression is MemberInitExpression nestedMemberInit) + { + // Recursively process nested complex property assignments to handle arbitrary nesting levels + foreach (var update in GetUpdatesFromMemberInit(context, table, nestedMemberInit, lambda)) + { + yield return update; + } + } + else + { + // Simple property assignment - the column name is the property name + yield return $"{table.GetQuotedColumnName(binding.Member.Name)} = {ToSqlExpression(context, table, binding.Expression, lambda)}"; + } + } + } + + /// + /// Traverses up a member expression chain to find the root parameter expression. + /// This handles both simple properties (e.g., excluded.Name) and complex properties (e.g., excluded.ComplexObject.Property). + /// + /// The member expression to traverse. + /// The root parameter expression if found; otherwise, null if the expression chain doesn't contain a parameter. + private static ParameterExpression? GetRootParameter(MemberExpression memberExpr) + { + Expression? current = memberExpr.Expression; + while (current != null) + { + if (current is ParameterExpression param) + { + return param; + } + + if (current is MemberExpression nested) + { + current = nested.Expression; + } + else + { + break; + } + } + + return null; + } } diff --git a/tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/DbContext/TestSmartEnum.cs b/tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/DbContext/TestSmartEnum.cs index 02aad51..4d3c4a8 100644 --- a/tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/DbContext/TestSmartEnum.cs +++ b/tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/DbContext/TestSmartEnum.cs @@ -8,5 +8,5 @@ private TestSmartEnum(string name, int value) : base(name, value) { } - public static readonly TestSmartEnum Value = new TestSmartEnum("test", 1); + public static readonly TestSmartEnum Test = new TestSmartEnum("test", 1); } diff --git a/tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Merge/MergeTestsBase.cs b/tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Merge/MergeTestsBase.cs index 57076b4..49f02ec 100644 --- a/tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Merge/MergeTestsBase.cs +++ b/tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Merge/MergeTestsBase.cs @@ -348,4 +348,170 @@ public async Task InsertEntities_WithConflict_MultipleColumns(InsertStrategy str Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity2"); Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity1 - Conflict" && e.Price == 0); } + + [SkippableTheory] + [InlineData(InsertStrategy.InsertReturn)] + [InlineData(InsertStrategy.InsertReturnAsync)] + public async Task InsertEntities_WithComplexType_UpdateAll(InsertStrategy strategy) + { + Skip.If(_context.IsProvider(ProviderType.MySql)); + // Oracle MERGE does not support returning entities + Skip.If(_context.IsProvider(ProviderType.Oracle)); + + // Arrange + var entities = new List + { + new TestEntityWithComplexType + { + TestRun = _run, + OwnedComplexType = new OwnedObject { Code = 1, Name = "Name1" } + }, + new TestEntityWithComplexType + { + TestRun = _run, + OwnedComplexType = new OwnedObject { Code = 2, Name = "Name2" } + } + }; + + // Act - First insert (without CopyGeneratedColumns - returns generated IDs via RETURNING) + var insertedEntities = await _context.InsertWithStrategyAsync(strategy, entities); + + // Update the complex properties + foreach (var entity in insertedEntities) + { + entity.OwnedComplexType = new OwnedObject + { + Code = entity.OwnedComplexType.Code + 100, + Name = $"Updated_{entity.OwnedComplexType.Name}" + }; + } + + // Act - Second insert with update on conflict + // The ParameterExpression case in GetUpdates generates UPDATE statements for all columns + var updatedEntities = await _context.InsertWithStrategyAsync(strategy, insertedEntities, o => o.CopyGeneratedColumns = true, + onConflict: new OnConflictOptions + { + Update = (inserted, excluded) => inserted, + }); + + // Assert - complex properties should be updated + Assert.Equal(2, updatedEntities.Count); + Assert.All(updatedEntities, e => + { + Assert.StartsWith("Updated_", e.OwnedComplexType.Name); + Assert.True(e.OwnedComplexType.Code > 100); + }); + } + + [SkippableTheory] + [InlineData(InsertStrategy.InsertReturn)] + [InlineData(InsertStrategy.InsertReturnAsync)] + public async Task InsertEntities_WithComplexType_UpdateWithWhere(InsertStrategy strategy) + { + Skip.If(_context.IsProvider(ProviderType.MySql)); + // Oracle MERGE does not support returning entities + Skip.If(_context.IsProvider(ProviderType.Oracle)); + + // Arrange - initial Code values are 10 and 20 + var entities = new List + { + new TestEntityWithComplexType + { + TestRun = _run, + OwnedComplexType = new OwnedObject { Code = 10, Name = "Original1" } + }, + new TestEntityWithComplexType + { + TestRun = _run, + OwnedComplexType = new OwnedObject { Code = 20, Name = "Original2" } + } + }; + + // Act - First insert (without CopyGeneratedColumns - returns generated IDs via RETURNING) + var insertedEntities = await _context.InsertWithStrategyAsync(strategy, entities); + + // Update the complex property - new Code values will be original + 100 (110 and 120) + foreach (var entity in insertedEntities) + { + entity.OwnedComplexType.Name = $"Changed_{entity.OwnedComplexType.Name}"; + entity.OwnedComplexType.Code = entity.OwnedComplexType.Code + 100; + } + + // Act - Second insert updating complex properties with a WHERE condition + // This tests that complex property access works correctly in the Where clause + var updatedEntities = await _context.InsertWithStrategyAsync(strategy, insertedEntities, o => o.CopyGeneratedColumns = true, + onConflict: new OnConflictOptions + { + Update = (inserted, excluded) => inserted, + Where = (inserted, excluded) => excluded.OwnedComplexType.Code > inserted.OwnedComplexType.Code + }); + + // Assert - entities should be updated because the new Code values (110, 120) + // are greater than the existing values in the database (10, 20) + Assert.Equal(2, updatedEntities.Count); + Assert.All(updatedEntities, e => + { + Assert.StartsWith("Changed_", e.OwnedComplexType.Name); + Assert.True(e.OwnedComplexType.Code > 100); + }); + } + + [SkippableTheory] + [InlineData(InsertStrategy.InsertReturn)] + [InlineData(InsertStrategy.InsertReturnAsync)] + public async Task InsertEntities_WithComplexType_UpdateComplexPropertyConditionally(InsertStrategy strategy) + { + Skip.If(_context.IsProvider(ProviderType.MySql)); + // Oracle MERGE does not support returning entities + Skip.If(_context.IsProvider(ProviderType.Oracle)); + + // Arrange - Create entities with different Code values + var entities = new List + { + new TestEntityWithComplexType + { + TestRun = _run, + OwnedComplexType = new OwnedObject { Code = 50, Name = "LowCode" } + }, + new TestEntityWithComplexType + { + TestRun = _run, + OwnedComplexType = new OwnedObject { Code = 150, Name = "HighCode" } + } + }; + + // Act - First insert (returns entities with generated IDs) + var insertedEntities = await _context.InsertWithStrategyAsync(strategy, entities); + + // Update both entities with new values + foreach (var entity in insertedEntities) + { + entity.OwnedComplexType.Code = entity.OwnedComplexType.Code + 10; + entity.OwnedComplexType.Name = $"Modified_{entity.OwnedComplexType.Name}"; + } + + // Act - Update using nested MemberInitExpression for complex property assignment + // Note: entities with Code >= 100 (original value) will not be updated due to WHERE clause + var updatedEntities = await _context.InsertWithStrategyAsync(strategy, insertedEntities, + o => o.CopyGeneratedColumns = true, + onConflict: new OnConflictOptions + { + Update = (inserted, excluded) => new TestEntityWithComplexType + { + OwnedComplexType = new OwnedObject + { + Code = excluded.OwnedComplexType.Code, + Name = excluded.OwnedComplexType.Name + } + }, + Where = (inserted, excluded) => inserted.OwnedComplexType.Code < 100 + }); + + // Assert - Only the entity with original Code < 100 should be updated (Code was 50, now 60) + // The one with original Code >= 100 is not updated but is also not returned by RETURNING clause + Assert.Single(updatedEntities); + var updatedEntity = updatedEntities.Single(); + Assert.Equal(60, updatedEntity.OwnedComplexType.Code); + Assert.Equal("Modified_LowCode", updatedEntity.OwnedComplexType.Name); + } } diff --git a/tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Various/VariousTestsBase.cs b/tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Various/VariousTestsBase.cs index 748e447..ca5724e 100644 --- a/tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Various/VariousTestsBase.cs +++ b/tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Various/VariousTestsBase.cs @@ -34,8 +34,8 @@ public async Task InsertSmartEnumEntities(InsertStrategy strategy) // Arrange var entities = new List { - new TestEntityWithSmartEnum { TestRun = _run, Enum = TestSmartEnum.Value}, - new TestEntityWithSmartEnum { TestRun = _run, Enum = TestSmartEnum.Value} + new TestEntityWithSmartEnum { TestRun = _run, Enum = TestSmartEnum.Test}, + new TestEntityWithSmartEnum { TestRun = _run, Enum = TestSmartEnum.Test} }; // Act