Add generic RSA (Recursive Self-Aggregation) algorithm framework #124

Open
shashankk42 wants to merge 13 commits into FLASK-LLNL:main from shashankk42:feature/rsa-generic-algorithm

shashankk42 commented Apr 12, 2026


Add a generic Recursive Self-Aggregation (RSA) algorithm as a Task subclass. It is domain-agnostic by default; subclass it to customize for any prompt-and-aggregate problem.

Changes

  • Add charge/algorithms/rsa.py with RSATask(Task) exposing the N-K-T loop via an async run_rsa() method
  • Hook methods on RSATask for domain customization (defaults work without subclassing):
    • format_candidates(subset): how candidates are rendered for aggregation
    • validate_proposal(result): accept/reject a proposal
    • build_proposal_task() / build_aggregation_task(text, step): per-stage Task construction
  • 3-part prompt structure carried on RSATask:
    • system_prompt: domain expert (proposals and aggregations)
    • proposal_prompt: instructions for generating a single solution
    • aggregation_prompt: template for synthesizing K-subsets ({original_prompt}, {candidates}, {step}, {total_steps})
  • Module-level helpers exported for one-off Task construction without the loop:
    • create_default_proposal_task(...), create_default_aggregation_task(...), default_format_candidates(...), GenericRSAOutput
  • Default prompts in charge/algorithms/prompts/ (proposal system, aggregation template)
  • README rewritten with a minimal example + cross-domain example (non-retrosynthesis) demonstrating that no subclassing is
    required
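The aggregation template placeholders listed above can be illustrated with plain `str.format`. This is a minimal sketch, not the code in this PR: the template text, `render_candidates`, and `build_aggregation_prompt` are hypothetical stand-ins, and only the four placeholder names come from the description.

```python
# Hypothetical template text. The real default lives in charge/algorithms/prompts/
# and is not reproduced here; only the placeholder names are taken from the PR.
AGGREGATION_TEMPLATE = (
    "Original task:\n{original_prompt}\n\n"
    "Candidate solutions (step {step} of {total_steps}):\n{candidates}\n\n"
    "Synthesize a single improved solution."
)


def render_candidates(subset):
    """Render a K-subset as a numbered list (stand-in for format_candidates)."""
    return "\n".join(f"Candidate {i}: {c}" for i, c in enumerate(subset, start=1))


def build_aggregation_prompt(original_prompt, subset, step, total_steps):
    """Fill the four placeholders for one aggregation call."""
    return AGGREGATION_TEMPLATE.format(
        original_prompt=original_prompt,
        candidates=render_candidates(subset),
        step=step,
        total_steps=total_steps,
    )


prompt = build_aggregation_prompt(
    "Sort a list.", ["use sorted()", "use list.sort()"], step=1, total_steps=2
)
print(prompt)
```

Keeping the rendering of candidates separate from the template is what lets a subclass change one without touching the other.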

API

task = RSATask(n=4, k=2, t=2, user_prompt="...")
output, result = await task.run_rsa(runner, runner_factory=...)

Subclass for domain customization:
class MyDomainRSATask(RSATask):
    def format_candidates(self, subset): ...
    def validate_proposal(self, result): ...
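For readers unfamiliar with the N-K-T loop, the following toy sketch shows its general shape without importing ChARGe. Everything here is an assumption for illustration: `propose`, `aggregate`, and `rsa_loop` are hypothetical names, the LLM calls are replaced by string stubs, and treating step 1 as the proposal round (so that T steps give T-1 aggregation rounds) is inferred from the commit note that T=1 would skip aggregation.

```python
import asyncio
import random


async def propose(prompt, index):
    # Stand-in for one LLM proposal call.
    return f"proposal-{index} for {prompt!r}"


async def aggregate(prompt, subset, step, total_steps):
    # Stand-in for one LLM aggregation call over a K-subset.
    return f"step{step}(" + " + ".join(subset) + ")"


async def rsa_loop(prompt, n, k, t, validate=lambda r: True, seed=0):
    rng = random.Random(seed)
    # Step 1: N independent proposals, run concurrently.
    proposals = await asyncio.gather(*(propose(prompt, i) for i in range(n)))
    population = [p for p in proposals if validate(p)]
    if not population:
        raise RuntimeError("no valid proposals")
    # Steps 2..T: each round builds N new candidates from random K-subsets.
    for step in range(2, t + 1):
        k_eff = min(k, len(population))  # shrink K if proposals were rejected
        subsets = [rng.sample(population, k_eff) for _ in range(n)]
        population = list(
            await asyncio.gather(*(aggregate(prompt, s, step, t) for s in subsets))
        )
    return population


final = asyncio.run(rsa_loop("example task", n=4, k=2, t=2))
print(len(final))
```

In the PR itself, the roles of `validate` and the candidate rendering are played by the `validate_proposal` and `format_candidates` hooks on RSATask.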

Testing

  • N/K/T validation rejects values < 2 and K > N with messages naming the violator(s)
  • Default hooks introspect any Pydantic schema (verified with GenericRSAOutput and a custom schema)
  • Bearer token / server_urls / builtin_tools correctly propagated to every per-proposal and per-aggregation Task
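The first check above could look roughly like the sketch below. `validate_nkt` and `error_for` are hypothetical names, and the message wording is an assumption; the PR only states that values below 2 and K > N are rejected with messages naming the offending parameter(s).

```python
def validate_nkt(n, k, t):
    """Reject N, K, T below 2 and K > N, naming the offending parameter(s)."""
    too_small = [name for name, value in (("N", n), ("K", k), ("T", t)) if value < 2]
    if too_small:
        raise ValueError(
            f"{', '.join(too_small)} must be >= 2 (got N={n}, K={k}, T={t})"
        )
    if k > n:
        raise ValueError(f"K must be <= N (got K={k}, N={n})")


def error_for(n, k, t):
    """Return the validation message, or None when the values are accepted."""
    try:
        validate_nkt(n, k, t)
    except ValueError as exc:
        return str(exc)
    return None


print(error_for(4, 2, 2))
print(error_for(1, 2, 1))
print(error_for(2, 3, 2))
```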

Related PRs

shashankk42 marked this pull request as draft April 16, 2026 22:28
shashankk42 force-pushed the feature/rsa-generic-algorithm branch from 83e7a55 to cdc9fa2 April 24, 2026 00:04
shashankk42 marked this pull request as ready for review April 24, 2026 01:50
shashankk42 marked this pull request as draft April 24, 2026 04:51
shashankk42 marked this pull request as ready for review April 24, 2026 05:02
shashankk42 marked this pull request as draft April 24, 2026 06:30
shashankk42 marked this pull request as ready for review April 24, 2026 06:34
bvanessen force-pushed the feature/rsa-generic-algorithm branch from 9b73603 to ffbe801 May 1, 2026 20:47
Comment thread charge/algorithms/rsa.py

tbennun left a comment

Contributor

Some comments on the code. I think this PR needs to be refactored to the Task interface as we discussed. Maybe even have FLASK-Copilot use that interface directly with its prompts.

Comment thread charge/algorithms/rsa.py Outdated
_DEFAULT_AGGREGATION_TEMPLATE = _PROMPTS_DIR / "default_aggregation_template.txt"


# Generic default output schema

Contributor

This code and all the classes seem quite excessive. Why do we need a class for RSAPrompts? If RSA extends a Task we should add those fields there?

Comment thread charge/algorithms/rsa.py Outdated
)


class RSACallbacks:

Contributor

unnecessary to put in a separate class

Comment thread charge/algorithms/rsa.py Outdated
return task


class RSATaskFactories(Generic[T]):

Contributor

Why make a generic RSA task factory? What is the need for the factory design pattern? I thought that there should just be an RSATask that extends Task

Comment thread charge/algorithms/rsa.py Outdated
k = config.k

# Helper function to run a single proposal
async def run_single_proposal(proposal_index: int, proposal_runner: Any):

Contributor

Long nested functions... Hard to understand the flow.

Comment thread charge/algorithms/rsa.py Outdated
await callbacks.logger_info(f"Generated {len(proposals)} valid proposals")

# Helper function to run a single aggregation
async def run_single_aggregation(

Contributor

same, please refactor out

@shashankk42

Author

Pushed deacaa0 after the approval to address review feedback on the consuming PR (FLASK-LLNL/flask-copilot#194, comment from @tbennun): RSATask was extending Task without actually using its API. This adds a generic Task.run(agent, ...) to the base class and an override on RSATask that delegates to the existing run_rsa (preserved for backward compatibility). No behavior change. RSA still runs the same N-K-T loop. Flagging in case you want to take another look before merge.

shashankk42 requested a review from tbennun May 23, 2026 08:18
Shashank Kushwaha and others added 13 commits May 30, 2026 03:21
Create charge/algorithms/ with RSAConfig, RSACallbacks, RSATaskFactories to make RSA reusable for any task type.
Add task-agnostic default prompts that users can swap out for domain-specific needs.
Document how to use default prompts and swap them for domain-specific needs.
Add GenericRSAOutput schema, default_format_candidates, and default task factories so RSA works immediately without customization.
Allows domain-specific prompts to be passed to aggregation tasks,
enabling clean separation between generic RSA and domain customization.
- system_prompt: Domain expert definition (constant across proposal and aggregation)
- proposal_prompt: Task instructions for generating proposals
- aggregation_prompt: Task instructions for evaluating/aggregating

This provides cleaner separation: expertise (system) vs task (user prompt).
- Add minimal working example using ChARGe only
- Document 3-part prompt structure (system, proposal, aggregation)
- Remove stale 2-part prompt references
- Clear progression from basic to advanced usage
- Remove flask-copilot dependencies from examples
- Enforce N, K, T >= 2 (T=1 would skip aggregation)
- Enforce K <= N at construction time in RSAConfig.__post_init__
- Remove redundant K validation in run_rsa_loop
- Improve error messages with actionable guidance
- Add logging when K is adjusted due to proposal failures
Task gains a default async run(agent, *args, **kwargs) that assigns
self to agent.task and dispatches agent.run, so callers can invoke any
Task through a uniform API. RSATask overrides run() to delegate to the
existing run_rsa N-K-T orchestrator, preserving the public run_rsa API
for external consumers.
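The delegation described in the last commit message can be sketched with toy classes. This is an illustration under assumptions, not the ChARGe source: `EchoAgent` is hypothetical, the `run_rsa` body is a stub, and the real `run_rsa` takes further arguments (such as `runner_factory`) that are omitted here.

```python
import asyncio


class Task:
    async def run(self, agent, *args, **kwargs):
        # Default behavior: attach this task to the agent, then dispatch to it.
        agent.task = self
        return await agent.run(*args, **kwargs)


class RSATask(Task):
    async def run_rsa(self, runner, **kwargs):
        # Stub standing in for the N-K-T orchestrator.
        return "rsa-output", "rsa-result"

    async def run(self, agent, *args, **kwargs):
        # Override: delegate to run_rsa so both entry points behave the same.
        return await self.run_rsa(agent, *args, **kwargs)


class EchoAgent:
    # Hypothetical agent used only to make the sketch runnable.
    task = None

    async def run(self, *args, **kwargs):
        return f"ran {type(self.task).__name__}"


plain = asyncio.run(Task().run(EchoAgent()))
rsa = asyncio.run(RSATask().run(EchoAgent()))
print(plain, rsa)
```

With this shape, callers can write `await task.run(agent)` for any Task, while existing callers of `run_rsa` keep working.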
shashankk42 force-pushed the feature/rsa-generic-algorithm branch from deacaa0 to 63f1fd0 June 1, 2026 02:32
Comment thread charge/tasks/task.py

self.constructor_args = {}

async def run(self, agent, *args, **kwargs):

Contributor

@bvanessen FYI, this change means we need to use Task.run() everywhere instead of agent.run(task).

Let's leave that for a follow-up PR.
