Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,9 @@ protected IEnumerable<string> GetUpdates<T>(DbContext context, TableMetadata tab
}
case MemberInitExpression memberInit:
{
foreach (var binding in memberInit.Bindings.OfType<MemberAssignment>())
foreach (var updateSql in GetUpdatesFromMemberInit<T>(context, table, memberInit, lambda))
{
yield return $"{table.GetQuotedColumnName(binding.Member.Name)} = {ToSqlExpression<T>(context, table, binding.Expression, lambda)}";
yield return updateSql;
}

break;
Expand Down Expand Up @@ -297,15 +297,20 @@ private string ToSqlExpression<TEntity>(DbContext context, TableMetadata table,
case MemberExpression memberExpr:
var columnName = table.GetColumnName(memberExpr.Member.Name);

// Traverse up the expression chain to find the root parameter
// This handles both simple properties (e.g., excluded.Name) and
// complex properties (e.g., excluded.ComplexObject.Property)
var rootParam = GetRootParameter(memberExpr);

// If the member expression is a property of the current lambda
if (lambda is { Parameters.Count: > 1 } && memberExpr.Expression is ParameterExpression paramExpr)
if (lambda is { Parameters.Count: > 1 } && rootParam != null)
{
if (paramExpr.Name == lambda.Parameters[0].Name)
if (rootParam.Name == lambda.Parameters[0].Name)
{
return GetInsertedColumnName(columnName);
}

if (paramExpr.Name == lambda.Parameters[1].Name)
if (rootParam.Name == lambda.Parameters[1].Name)
{
return GetExcludedColumnName(columnName);
}
Expand Down Expand Up @@ -405,4 +410,64 @@ private string ToSqlExpression<TEntity>(DbContext context, TableMetadata table,
throw new NotSupportedException($"Expression not supported: {expr.NodeType}");
}
}

/// <summary>
/// Extracts update SQL statements from a MemberInitExpression, handling both simple properties
/// and nested complex property initializations recursively.
/// </summary>
/// <param name="context">DB context</param>
/// <param name="table">Table metadata</param>
/// <param name="memberInit">The member initialization expression</param>
/// <param name="lambda">Current lambda expression</param>
/// <typeparam name="T">Entity type</typeparam>
/// <returns>SQL update statements for each property assignment</returns>
private IEnumerable<string> GetUpdatesFromMemberInit<T>(DbContext context, TableMetadata table, MemberInitExpression memberInit, LambdaExpression lambda)
{
foreach (var binding in memberInit.Bindings.OfType<MemberAssignment>())
{
// Check if the binding expression is a nested MemberInitExpression (complex property assignment)
if (binding.Expression is MemberInitExpression nestedMemberInit)
{
// Recursively process nested complex property assignments to handle arbitrary nesting levels
foreach (var update in GetUpdatesFromMemberInit<T>(context, table, nestedMemberInit, lambda))
{
yield return update;
}
}
else
{
// Simple property assignment - the column name is the property name
yield return $"{table.GetQuotedColumnName(binding.Member.Name)} = {ToSqlExpression<T>(context, table, binding.Expression, lambda)}";
}
}
}

/// <summary>
/// Traverses up a member expression chain to find the root parameter expression.
/// This handles both simple properties (e.g., excluded.Name) and complex properties (e.g., excluded.ComplexObject.Property).
/// </summary>
/// <param name="memberExpr">The member expression to traverse.</param>
/// <returns>The root parameter expression if found; otherwise, null if the expression chain doesn't contain a parameter.</returns>
private static ParameterExpression? GetRootParameter(MemberExpression memberExpr)
{
Expression? current = memberExpr.Expression;
while (current != null)
{
if (current is ParameterExpression param)
{
return param;
}

if (current is MemberExpression nested)
{
current = nested.Expression;
}
else
{
break;
}
}

return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ private TestSmartEnum(string name, int value) : base(name, value)
{
}

public static readonly TestSmartEnum Value = new TestSmartEnum("test", 1);
public static readonly TestSmartEnum Test = new TestSmartEnum("test", 1);
}
Original file line number Diff line number Diff line change
Expand Up @@ -348,4 +348,170 @@ public async Task InsertEntities_WithConflict_MultipleColumns(InsertStrategy str
Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity2");
Assert.Contains(insertedEntities, e => e.Name == $"{_run}_Entity1 - Conflict" && e.Price == 0);
}

[SkippableTheory]
[InlineData(InsertStrategy.InsertReturn)]
[InlineData(InsertStrategy.InsertReturnAsync)]
public async Task InsertEntities_WithComplexType_UpdateAll(InsertStrategy strategy)
{
Skip.If(_context.IsProvider(ProviderType.MySql));
// Oracle MERGE does not support returning entities
Skip.If(_context.IsProvider(ProviderType.Oracle));

// Arrange
var entities = new List<TestEntityWithComplexType>
{
new TestEntityWithComplexType
{
TestRun = _run,
OwnedComplexType = new OwnedObject { Code = 1, Name = "Name1" }
},
new TestEntityWithComplexType
{
TestRun = _run,
OwnedComplexType = new OwnedObject { Code = 2, Name = "Name2" }
}
};

// Act - First insert (without CopyGeneratedColumns - returns generated IDs via RETURNING)
var insertedEntities = await _context.InsertWithStrategyAsync(strategy, entities);

// Update the complex properties
foreach (var entity in insertedEntities)
{
entity.OwnedComplexType = new OwnedObject
{
Code = entity.OwnedComplexType.Code + 100,
Name = $"Updated_{entity.OwnedComplexType.Name}"
};
}

// Act - Second insert with update on conflict
// The ParameterExpression case in GetUpdates generates UPDATE statements for all columns
var updatedEntities = await _context.InsertWithStrategyAsync(strategy, insertedEntities, o => o.CopyGeneratedColumns = true,
onConflict: new OnConflictOptions<TestEntityWithComplexType>
{
Update = (inserted, excluded) => inserted,
});

// Assert - complex properties should be updated
Assert.Equal(2, updatedEntities.Count);
Assert.All(updatedEntities, e =>
{
Assert.StartsWith("Updated_", e.OwnedComplexType.Name);
Assert.True(e.OwnedComplexType.Code > 100);
});
}

[SkippableTheory]
[InlineData(InsertStrategy.InsertReturn)]
[InlineData(InsertStrategy.InsertReturnAsync)]
public async Task InsertEntities_WithComplexType_UpdateWithWhere(InsertStrategy strategy)
{
Skip.If(_context.IsProvider(ProviderType.MySql));
// Oracle MERGE does not support returning entities
Skip.If(_context.IsProvider(ProviderType.Oracle));

// Arrange - initial Code values are 10 and 20
var entities = new List<TestEntityWithComplexType>
{
new TestEntityWithComplexType
{
TestRun = _run,
OwnedComplexType = new OwnedObject { Code = 10, Name = "Original1" }
},
new TestEntityWithComplexType
{
TestRun = _run,
OwnedComplexType = new OwnedObject { Code = 20, Name = "Original2" }
}
};

// Act - First insert (without CopyGeneratedColumns - returns generated IDs via RETURNING)
var insertedEntities = await _context.InsertWithStrategyAsync(strategy, entities);

// Update the complex property - new Code values will be original + 100 (110 and 120)
foreach (var entity in insertedEntities)
{
entity.OwnedComplexType.Name = $"Changed_{entity.OwnedComplexType.Name}";
entity.OwnedComplexType.Code = entity.OwnedComplexType.Code + 100;
}

// Act - Second insert updating complex properties with a WHERE condition
// This tests that complex property access works correctly in the Where clause
var updatedEntities = await _context.InsertWithStrategyAsync(strategy, insertedEntities, o => o.CopyGeneratedColumns = true,
onConflict: new OnConflictOptions<TestEntityWithComplexType>
{
Update = (inserted, excluded) => inserted,
Where = (inserted, excluded) => excluded.OwnedComplexType.Code > inserted.OwnedComplexType.Code
});

// Assert - entities should be updated because the new Code values (110, 120)
// are greater than the existing values in the database (10, 20)
Assert.Equal(2, updatedEntities.Count);
Assert.All(updatedEntities, e =>
{
Assert.StartsWith("Changed_", e.OwnedComplexType.Name);
Assert.True(e.OwnedComplexType.Code > 100);
});
}

[SkippableTheory]
[InlineData(InsertStrategy.InsertReturn)]
[InlineData(InsertStrategy.InsertReturnAsync)]
public async Task InsertEntities_WithComplexType_UpdateComplexPropertyConditionally(InsertStrategy strategy)
{
Skip.If(_context.IsProvider(ProviderType.MySql));
// Oracle MERGE does not support returning entities
Skip.If(_context.IsProvider(ProviderType.Oracle));

// Arrange - Create entities with different Code values
var entities = new List<TestEntityWithComplexType>
{
new TestEntityWithComplexType
{
TestRun = _run,
OwnedComplexType = new OwnedObject { Code = 50, Name = "LowCode" }
},
new TestEntityWithComplexType
{
TestRun = _run,
OwnedComplexType = new OwnedObject { Code = 150, Name = "HighCode" }
}
};

// Act - First insert (returns entities with generated IDs)
var insertedEntities = await _context.InsertWithStrategyAsync(strategy, entities);

// Update both entities with new values
foreach (var entity in insertedEntities)
{
entity.OwnedComplexType.Code = entity.OwnedComplexType.Code + 10;
entity.OwnedComplexType.Name = $"Modified_{entity.OwnedComplexType.Name}";
}

// Act - Update using nested MemberInitExpression for complex property assignment
// Note: entities with Code >= 100 (original value) will not be updated due to WHERE clause
var updatedEntities = await _context.InsertWithStrategyAsync(strategy, insertedEntities,
o => o.CopyGeneratedColumns = true,
onConflict: new OnConflictOptions<TestEntityWithComplexType>
{
Update = (inserted, excluded) => new TestEntityWithComplexType
{
OwnedComplexType = new OwnedObject
{
Code = excluded.OwnedComplexType.Code,
Name = excluded.OwnedComplexType.Name
}
},
Where = (inserted, excluded) => inserted.OwnedComplexType.Code < 100
});

// Assert - Only the entity with original Code < 100 should be updated (Code was 50, now 60)
// The one with original Code >= 100 is not updated but is also not returned by RETURNING clause
Assert.Single(updatedEntities);
var updatedEntity = updatedEntities.Single();
Assert.Equal(60, updatedEntity.OwnedComplexType.Code);
Assert.Equal("Modified_LowCode", updatedEntity.OwnedComplexType.Name);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ public async Task InsertSmartEnumEntities(InsertStrategy strategy)
// Arrange
var entities = new List<TestEntityWithSmartEnum>
{
new TestEntityWithSmartEnum { TestRun = _run, Enum = TestSmartEnum.Value},
new TestEntityWithSmartEnum { TestRun = _run, Enum = TestSmartEnum.Value}
new TestEntityWithSmartEnum { TestRun = _run, Enum = TestSmartEnum.Test},
new TestEntityWithSmartEnum { TestRun = _run, Enum = TestSmartEnum.Test}
};

// Act
Expand Down
Loading