Add generic RSA (Recursive Self-Aggregation) algorithm framework#124
Add generic RSA (Recursive Self-Aggregation) algorithm framework#124shashankk42 wants to merge 13 commits into
Conversation
83e7a55 to
cdc9fa2
Compare
9b73603 to
ffbe801
Compare
| _DEFAULT_AGGREGATION_TEMPLATE = _PROMPTS_DIR / "default_aggregation_template.txt" | ||
|
|
||
|
|
||
| # Generic default output schema |
There was a problem hiding this comment.
This code and all the classes seem quite excessive. Why do we need a class for RSAPrompts? If RSA extends a Task we should add those fields there?
| ) | ||
|
|
||
|
|
||
| class RSACallbacks: |
There was a problem hiding this comment.
unnecessary to put in a separate class
| return task | ||
|
|
||
|
|
||
| class RSATaskFactories(Generic[T]): |
There was a problem hiding this comment.
Why make a generic RSA task factory? What is the need for the factory design pattern? I thought that there should just be an RSATask that extends Task
| k = config.k | ||
|
|
||
| # Helper function to run a single proposal | ||
| async def run_single_proposal(proposal_index: int, proposal_runner: Any): |
There was a problem hiding this comment.
Long nested functions... Hard to understand the flow.
| await callbacks.logger_info(f"Generated {len(proposals)} valid proposals") | ||
|
|
||
| # Helper function to run a single aggregation | ||
| async def run_single_aggregation( |
79c870e to
4435b50
Compare
|
Pushed deacaa0 after the approval to address review feedback on the consuming PR (FLASK-LLNL/flask-copilot#194, comment from @tbennun): RSATask was extending Task without actually using its API. This adds a generic Task.run(agent, ...) to the base class and an override on RSATask that delegates to the existing run_rsa (preserved for backward compatibility). No behavior change. RSA still runs the same N-K-T loop. Flagging in case you want to take another look before merge. |
Create charge/algorithms/ with RSAConfig, RSACallbacks, RSATaskFactories to make RSA reusable for any task type.
Add task-agnostic default prompts that users can swap out for domain-specific needs.
Document how to use default prompts and swap them for domain-specific needs.
Add GenericRSAOutput schema, default_format_candidates, and default task factories so RSA works immediately without customization.
Allows domain-specific prompts to be passed to aggregation tasks, enabling clean separation between generic RSA and domain customization.
- system_prompt: Domain expert definition (constant across proposal and aggregation) - proposal_prompt: Task instructions for generating proposals - aggregation_prompt: Task instructions for evaluating/aggregating This provides cleaner separation: expertise (system) vs task (user prompt).
- Add minimal working example using ChARGe only - Document 3-part prompt structure (system, proposal, aggregation) - Remove stale 2-part prompt references - Clear progression from basic to advanced usage - Remove flask-copilot dependencies from examples
- Enforce N, K, T >= 2 (T=1 would skip aggregation) - Enforce K <= N at construction time in RSAConfig.__post_init__ - Remove redundant K validation in run_rsa_loop - Improve error messages with actionable guidance - Add logging when K is adjusted due to proposal failures
Task gains a default async run(agent, *args, **kwargs) that assigns self to agent.task and dispatches agent.run, so callers can invoke any Task through a uniform API. RSATask overrides run() to delegate to the existing run_rsa N-K-T orchestrator, preserving the public run_rsa API for external consumers.
deacaa0 to
63f1fd0
Compare
|
|
||
| self.constructor_args = {} | ||
|
|
||
| async def run(self, agent, *args, **kwargs): |
There was a problem hiding this comment.
@bvanessen FYI, this change means we need to use Task.run() everywhere instead of agent.run(task).
Let's leave that for a follow up PR
Add generic Recursive Self-Aggregation (RSA) algorithm as a
Tasksubclass. Domain-agnostic by default; subclass to customizefor any prompt-and-aggregate problem.
Changes
charge/algorithms/rsa.pywithRSATask(Task)exposing the N-K-T loop via an asyncrun_rsa()methodRSATaskfor domain customization (defaults work without subclassing):format_candidates(subset): how candidates are rendered for aggregationvalidate_proposal(result): accept/reject a proposalbuild_proposal_task()/build_aggregation_task(text, step)— per-stage Task constructionRSATask:system_prompt: domain expert (proposals and aggregations)proposal_prompt: instructions for generating a single solutionaggregation_prompt: template for synthesizing K-subsets ({original_prompt},{candidates},{step},{total_steps})create_default_proposal_task(...),create_default_aggregation_task(...),default_format_candidates(...),GenericRSAOutputcharge/algorithms/prompts/(proposal system, aggregation template)required
API
Testing
Related PRs