@@ -75,6 +75,9 @@ public async Task<GraphInsertResult<T>> InsertGraph<T>(
7575
7676 var connection = await _context . GetConnection ( sync , ctk ) ;
7777
78+ // Track original primary key values for rollback
79+ var originalPkValues = new Dictionary < object , Dictionary < string , object ? > > ( ) ;
80+
7881 try
7982 {
8083 // 2. Insert in dependency order (parents first)
@@ -86,6 +89,12 @@ public async Task<GraphInsertResult<T>> InsertGraph<T>(
8689 continue ;
8790 }
8891
92+ // Save original PK values before any modifications
93+ if ( options . RestoreOriginalPrimaryKeysOnGraphInsertFailure )
94+ {
95+ SaveOriginalPrimaryKeyValues ( entitiesToInsert , entityType , graphMetadata , originalPkValues ) ;
96+ }
97+
8998 // Propagate FK values from already-inserted parents
9099 PropagateParentForeignKeys ( entitiesToInsert , entityType , graphMetadata ) ;
91100
@@ -117,6 +126,16 @@ await InsertEntitiesOfType(sync, _context, entityType, entitiesToInsert, options
117126 TotalInsertedCount = totalInserted ,
118127 } ;
119128 }
129+ catch
130+ {
131+ // Restore original PK values on rollback
132+ if ( options . RestoreOriginalPrimaryKeysOnGraphInsertFailure )
133+ {
134+ RestoreOriginalPrimaryKeyValues ( originalPkValues , graphMetadata ) ;
135+ }
136+
137+ throw ;
138+ }
120139 finally
121140 {
122141 await connection . Close ( sync , ctk ) ;
@@ -436,14 +455,14 @@ private async Task InsertJoinEntities(
436455
437456 // Use reflection to call the generic BulkInsert method with correctly typed entities
438457 var method = typeof ( GraphBulkInsertOrchestrator )
439- . GetMethod ( nameof ( InsertJoinEntitiesGenericAsync ) , BindingFlags . NonPublic | BindingFlags . Instance ) !
458+ . GetMethod ( nameof ( InsertJoinEntitiesGeneric ) , BindingFlags . NonPublic | BindingFlags . Instance ) !
440459 . MakeGenericMethod ( joinEntityType ) ;
441460
442461 var task = ( Task ) method . Invoke ( this , [ sync , context , tableInfo , joinEntities , options , provider , ctk ] ) ! ;
443462 await task ;
444463 }
445464
446- private async Task InsertJoinEntitiesGenericAsync < TJoin > (
465+ private static async Task InsertJoinEntitiesGeneric < TJoin > (
447466 bool sync ,
448467 DbContext context ,
449468 TableMetadata tableInfo ,
@@ -456,4 +475,66 @@ private async Task InsertJoinEntitiesGenericAsync<TJoin>(
456475 var typedEntities = joinEntities . Cast < TJoin > ( ) . ToList ( ) ;
457476 await provider . BulkInsert ( sync , context , tableInfo , typedEntities , options , null , ctk ) ;
458477 }
478+
479+ private static void SaveOriginalPrimaryKeyValues (
480+ List < object > entities ,
481+ Type entityType ,
482+ GraphMetadata graphMetadata ,
483+ Dictionary < object , Dictionary < string , object ? > > originalPkValues )
484+ {
485+ var entityMetadata = graphMetadata . GetEntityMetadata ( entityType ) ;
486+ if ( entityMetadata == null )
487+ {
488+ return ;
489+ }
490+
491+ var efEntityType = graphMetadata . GetEntityType ( entityType ) ;
492+
493+ var pkProperties = efEntityType ? . FindPrimaryKey ( ) ? . Properties ;
494+ if ( pkProperties == null || ! pkProperties . Any ( ) )
495+ {
496+ return ;
497+ }
498+
499+ // Only save values for database-generated keys
500+ var generatedPkProps = pkProperties
501+ . Where ( p => p . ValueGenerated != Microsoft . EntityFrameworkCore . Metadata . ValueGenerated . Never )
502+ . ToList ( ) ;
503+
504+ if ( generatedPkProps . Count == 0 )
505+ {
506+ return ;
507+ }
508+
509+ foreach ( var entity in entities )
510+ {
511+ var pkValues = new Dictionary < string , object ? > ( ) ;
512+ foreach ( var pkProp in generatedPkProps )
513+ {
514+ var value = entityMetadata . GetPropertyValue ( entity , pkProp . Name ) ;
515+ pkValues [ pkProp . Name ] = value ;
516+ }
517+ originalPkValues [ entity ] = pkValues ;
518+ }
519+ }
520+
521+ private static void RestoreOriginalPrimaryKeyValues (
522+ Dictionary < object , Dictionary < string , object ? > > originalPkValues ,
523+ GraphMetadata graphMetadata )
524+ {
525+ foreach ( var ( entity , pkValues ) in originalPkValues )
526+ {
527+ var entityType = entity . GetType ( ) ;
528+ var entityMetadata = graphMetadata . GetEntityMetadata ( entityType ) ;
529+ if ( entityMetadata == null )
530+ {
531+ continue ;
532+ }
533+
534+ foreach ( var ( propertyName , originalValue ) in pkValues )
535+ {
536+ entityMetadata . SetPropertyValue ( entity , propertyName , originalValue ) ;
537+ }
538+ }
539+ }
459540}
0 commit comments