Skip to content

Commit fef0db3

Browse files
author
fabien.menager
committed
Restore original PKs on failure
1 parent 97e239e commit fef0db3

3 files changed

Lines changed: 129 additions & 2 deletions

File tree

src/PhenX.EntityFrameworkCore.BulkInsert/Graph/GraphBulkInsertOrchestrator.cs

Lines changed: 83 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ public async Task<GraphInsertResult<T>> InsertGraph<T>(
7575

7676
var connection = await _context.GetConnection(sync, ctk);
7777

78+
// Track original primary key values for rollback
79+
var originalPkValues = new Dictionary<object, Dictionary<string, object?>>();
80+
7881
try
7982
{
8083
// 2. Insert in dependency order (parents first)
@@ -86,6 +89,12 @@ public async Task<GraphInsertResult<T>> InsertGraph<T>(
8689
continue;
8790
}
8891

92+
// Save original PK values before any modifications
93+
if (options.RestoreOriginalPrimaryKeysOnGraphInsertFailure)
94+
{
95+
SaveOriginalPrimaryKeyValues(entitiesToInsert, entityType, graphMetadata, originalPkValues);
96+
}
97+
8998
// Propagate FK values from already-inserted parents
9099
PropagateParentForeignKeys(entitiesToInsert, entityType, graphMetadata);
91100

@@ -117,6 +126,16 @@ await InsertEntitiesOfType(sync, _context, entityType, entitiesToInsert, options
117126
TotalInsertedCount = totalInserted,
118127
};
119128
}
129+
catch
130+
{
131+
// Restore original PK values on rollback
132+
if (options.RestoreOriginalPrimaryKeysOnGraphInsertFailure)
133+
{
134+
RestoreOriginalPrimaryKeyValues(originalPkValues, graphMetadata);
135+
}
136+
137+
throw;
138+
}
120139
finally
121140
{
122141
await connection.Close(sync, ctk);
@@ -436,14 +455,14 @@ private async Task InsertJoinEntities(
436455

437456
// Use reflection to call the generic BulkInsert method with correctly typed entities
438457
var method = typeof(GraphBulkInsertOrchestrator)
439-
.GetMethod(nameof(InsertJoinEntitiesGenericAsync), BindingFlags.NonPublic | BindingFlags.Instance)!
458+
.GetMethod(nameof(InsertJoinEntitiesGeneric), BindingFlags.NonPublic | BindingFlags.Instance)!
440459
.MakeGenericMethod(joinEntityType);
441460

442461
var task = (Task)method.Invoke(this, [sync, context, tableInfo, joinEntities, options, provider, ctk])!;
443462
await task;
444463
}
445464

446-
private async Task InsertJoinEntitiesGenericAsync<TJoin>(
465+
private static async Task InsertJoinEntitiesGeneric<TJoin>(
447466
bool sync,
448467
DbContext context,
449468
TableMetadata tableInfo,
@@ -456,4 +475,66 @@ private async Task InsertJoinEntitiesGenericAsync<TJoin>(
456475
var typedEntities = joinEntities.Cast<TJoin>().ToList();
457476
await provider.BulkInsert(sync, context, tableInfo, typedEntities, options, null, ctk);
458477
}
478+
479+
private static void SaveOriginalPrimaryKeyValues(
480+
List<object> entities,
481+
Type entityType,
482+
GraphMetadata graphMetadata,
483+
Dictionary<object, Dictionary<string, object?>> originalPkValues)
484+
{
485+
var entityMetadata = graphMetadata.GetEntityMetadata(entityType);
486+
if (entityMetadata == null)
487+
{
488+
return;
489+
}
490+
491+
var efEntityType = graphMetadata.GetEntityType(entityType);
492+
493+
var pkProperties = efEntityType?.FindPrimaryKey()?.Properties;
494+
if (pkProperties == null || !pkProperties.Any())
495+
{
496+
return;
497+
}
498+
499+
// Only save values for database-generated keys
500+
var generatedPkProps = pkProperties
501+
.Where(p => p.ValueGenerated != Microsoft.EntityFrameworkCore.Metadata.ValueGenerated.Never)
502+
.ToList();
503+
504+
if (generatedPkProps.Count == 0)
505+
{
506+
return;
507+
}
508+
509+
foreach (var entity in entities)
510+
{
511+
var pkValues = new Dictionary<string, object?>();
512+
foreach (var pkProp in generatedPkProps)
513+
{
514+
var value = entityMetadata.GetPropertyValue(entity, pkProp.Name);
515+
pkValues[pkProp.Name] = value;
516+
}
517+
originalPkValues[entity] = pkValues;
518+
}
519+
}
520+
521+
private static void RestoreOriginalPrimaryKeyValues(
522+
Dictionary<object, Dictionary<string, object?>> originalPkValues,
523+
GraphMetadata graphMetadata)
524+
{
525+
foreach (var (entity, pkValues) in originalPkValues)
526+
{
527+
var entityType = entity.GetType();
528+
var entityMetadata = graphMetadata.GetEntityMetadata(entityType);
529+
if (entityMetadata == null)
530+
{
531+
continue;
532+
}
533+
534+
foreach (var (propertyName, originalValue) in pkValues)
535+
{
536+
entityMetadata.SetPropertyValue(entity, propertyName, originalValue);
537+
}
538+
}
539+
}
459540
}

src/PhenX.EntityFrameworkCore.BulkInsert/Options/BulkInsertOptions.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,14 @@ public class BulkInsertOptions
9898
/// </summary>
9999
public HashSet<string>? ExcludeNavigations { get; set; }
100100

101+
/// <summary>
102+
/// When enabled, if a graph insert operation fails, the original primary key values of the entities will be restored.
103+
/// This ensures that entities in memory remain consistent with the database state after a transaction rollback.
104+
/// Can add a little overhead, so it is disabled by default. Enable this option if you need to access the primary
105+
/// key values of entities after a failed graph insert operation.
106+
/// </summary>
107+
public bool RestoreOriginalPrimaryKeysOnGraphInsertFailure { get; set; }
108+
101109
internal int GetCopyTimeoutInSeconds()
102110
{
103111
return Math.Max(0, (int)CopyTimeout.TotalSeconds);

tests/PhenX.EntityFrameworkCore.BulkInsert.Tests/Tests/Graph/GraphTestsBase.cs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,6 +751,7 @@ public async Task InsertGraph_FailureMidRun_TransactionRolledBack()
751751
var act = async () => await _context.ExecuteBulkInsertAsync(blogs, options =>
752752
{
753753
options.IncludeGraph = true;
754+
options.RestoreOriginalPrimaryKeysOnGraphInsertFailure = true; // Ensure original entities are restored on failure
754755
});
755756

756757
await act.Should().ThrowAsync<Exception>("Insert should fail due to NULL constraint violation");
@@ -768,5 +769,42 @@ public async Task InsertGraph_FailureMidRun_TransactionRolledBack()
768769
// Verify original entities do NOT have IDs populated (rollback means no database-generated values)
769770
validBlog.Id.Should().Be(0, "Valid blog should not have ID after rollback");
770771
invalidBlog.Id.Should().Be(0, "Invalid blog should not have ID after rollback");
772+
773+
// Act 2 - Fix the invalid data and retry insertion with the same entities
774+
// This verifies that entities are properly restored and can be reused after rollback
775+
invalidBlog.Posts.First().Title = $"{_run}_FixedPost";
776+
777+
// Should succeed this time
778+
await _context.ExecuteBulkInsertAsync(blogs, options =>
779+
{
780+
options.IncludeGraph = true;
781+
});
782+
783+
// Assert 2 - Verify that ALL entities are now inserted successfully
784+
var insertedBlogsAfterFix = _context.Blogs.Where(b => b.TestRun == _run).ToList();
785+
insertedBlogsAfterFix.Should().HaveCount(2, "Both blogs should be inserted after fixing the data");
786+
787+
var insertedPostsAfterFix = _context.Posts.Where(p => p.TestRun == _run).ToList();
788+
insertedPostsAfterFix.Should().HaveCount(2, "Both posts should be inserted after fixing the data");
789+
790+
var insertedSettingsAfterFix = _context.BlogSettings.Where(s => s.TestRun == _run).ToList();
791+
insertedSettingsAfterFix.Should().HaveCount(1, "Blog settings should be inserted after fixing the data");
792+
793+
// Verify that the original entity references now have IDs populated
794+
validBlog.Id.Should().BeGreaterThan(0, "Valid blog should have ID after successful insert");
795+
invalidBlog.Id.Should().BeGreaterThan(0, "Fixed blog should have ID after successful insert");
796+
validBlog.Posts.First().Id.Should().BeGreaterThan(0, "Valid post should have ID after successful insert");
797+
invalidBlog.Posts.First().Id.Should().BeGreaterThan(0, "Fixed post should have ID after successful insert");
798+
validBlog.Settings!.Id.Should().BeGreaterThan(0, "Settings should have ID after successful insert");
799+
800+
// Verify FK relationships are correct
801+
validBlog.Posts.First().BlogId.Should().Be(validBlog.Id, "Valid post FK should reference its blog");
802+
invalidBlog.Posts.First().BlogId.Should().Be(invalidBlog.Id, "Fixed post FK should reference its blog");
803+
validBlog.Settings.BlogId.Should().Be(validBlog.Id, "Settings FK should reference its blog");
804+
805+
// Verify the corrected title is in the database
806+
var fixedPostInDb = _context.Posts.FirstOrDefault(p => p.Id == invalidBlog.Posts.First().Id);
807+
fixedPostInDb.Should().NotBeNull();
808+
fixedPostInDb!.Title.Should().Be($"{_run}_FixedPost", "Fixed post should have the corrected title");
771809
}
772810
}

0 commit comments

Comments
 (0)