Skip to content

Fix rolled prompts for contrastive disc in ARC training#246

Open
cpvlordelo wants to merge 4 commits into
Stability-AI:mainfrom
cpvlordelo:fix/arc-rolled-metadata
Open

Fix rolled prompts for contrastive disc in ARC training#246
cpvlordelo wants to merge 4 commits into
Stability-AI:mainfrom
cpvlordelo:fix/arc-rolled-metadata

Conversation

@cpvlordelo
Copy link
Copy Markdown

The fix uses dict-unpacking to build fresh dicts — the current training_step pattern appends references to the original metadata dicts and mutates them in place, which (A) leaks side effects to subsequent code paths that still hold references to metadata, and (b) makes the very last iteration read metadata[0]["prompt"] after iteration 0 already overwrote it, so the last element of the rolled batch ends up with the wrong prompt.

Problem (A):

Not affecting anywhere today since metadata[i] is not used downstream inside training_step, so even though we are changing it in place, this is not affecting anything in downstream code. In the future, if any code uses metadata[i] after the in-place modification is done, they would also receive rolled prompts, which we don't want. The fix will prevent this from happening.

Problem (B)

Walkthrough with prompts [A, B, C, D] (n=4)

The intended rolled_metadata should be [B, C, D, A] — each item gets the
next item's prompt, wrapping at the end.

iter i reads metadata[(i+1) % 4]["prompt"] writes back to metadata[i]["prompt"] state of metadata prompts after this iteration
0 metadata[1]B (original) metadata[0]B [B, B, C, D]
1 metadata[2]C (original) metadata[1]C [B, C, C, D]
2 metadata[3]D (original) metadata[2]D [B, C, D, D]
3 metadata[0]Balready mutated in iter 0! Should have been A. metadata[3]B [B, C, D, B]

rolled_metadata ends up as [B, C, D, B] instead of [B, C, D, A] — the
last element is wrong.

Where the bug lives

For any n ≥ 2, only the last iteration reads a value that an earlier iteration already overwrote. That's because (n-1 + 1) % n == 0, and index 0 is the one iteration 0 wrote to. All earlier reads (metadata[i+1] for i = 0..n-2) hit indices that haven't been touched yet.

There's also a degenerate case: with n = 2 ([A, B]), iteration 0 sets metadata[0] ← B, iteration 1 reads metadata[0] (already B) and writes metadata[1] ← B rolled is [B, B]. The "wrong" element is still the last one, but it's a more visible 50%-of-the-batch corruption in that case.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant