Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion openadapt_ml/cloud/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,22 @@ def get_current_output_dir() -> Path:
return TRAINING_OUTPUT


def resolve_config_path(config: str | Path) -> Path:
"""Resolve a training config path, falling back to packaged configs.

Relative paths like "configs/qwen3vl_capture.yaml" only exist when
running from a repo checkout. When openadapt-ml is pip-installed, the
bundled copies under openadapt_ml/configs/ are used instead.
"""
path = Path(config)
if path.exists():
return path
packaged = Path(__file__).resolve().parent.parent / "configs" / path.name
if packaged.exists():
return packaged
return path


def _regenerate_viewer_if_possible(output_dir: Path) -> bool:
"""Regenerate viewer.html if comparison data exists.

Expand Down Expand Up @@ -324,7 +340,7 @@ def cmd_train(args: argparse.Namespace) -> int:
else:
config = "configs/qwen3vl_capture_4bit.yaml"

config_path = Path(config)
config_path = resolve_config_path(config)
if not config_path.exists():
print(f"Error: Config not found: {config_path}")
return 1
Expand Down
7 changes: 3 additions & 4 deletions openadapt_ml/scripts/demo_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import argparse

from openadapt_ml.datasets.next_action import build_next_action_sft_samples
from openadapt_ml.ingest.synthetic import generate_synthetic_sessions
from openadapt_ml.ingest.synthetic import generate_synthetic_episodes
from openadapt_ml.models.dummy_adapter import DummyAdapter
from openadapt_ml.models.qwen_vl import QwenVLAdapter
from openadapt_ml.models.api_adapter import ApiVLMAdapter
Expand All @@ -20,10 +20,9 @@ def main() -> None:
args = parser.parse_args()

# Use synthetic data to build one SFT-style sample
sessions = generate_synthetic_sessions(
num_sessions=1, seed=99, output_dir="synthetic/demo"
episodes = generate_synthetic_episodes(
num_episodes=1, seed=99, output_dir="synthetic/demo"
)
episodes = [ep for sess in sessions for ep in sess.episodes]
samples = build_next_action_sft_samples(episodes)

# Load first sample and overwrite assistant content so the dummy adapter
Expand Down
2 changes: 1 addition & 1 deletion openadapt_ml/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def _load_capture_episodes(capture_path: str | Path, goal: str | None = None) ->
from openadapt_ml.ingest.capture import capture_to_episode

capture_path = Path(capture_path)
episode = capture_to_episode(capture_path, goal=goal)
episode = capture_to_episode(capture_path, instruction=goal)
return [episode]


Expand Down
3 changes: 1 addition & 2 deletions openadapt_ml/training/grpo/reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,7 @@ def evaluate_milestones_screenshot(

# Only evaluate screenshot-type milestones locally
screenshot_milestones = [
ms for ms in milestones
if getattr(ms.check, "check", None) == "screenshot"
ms for ms in milestones if getattr(ms.check, "check", None) == "screenshot"
]
if not screenshot_milestones:
return 0.0
Expand Down
5 changes: 2 additions & 3 deletions openadapt_ml/training/grpo/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ class BenchmarkAction: # type: ignore[no-redef]
if json_match:
try:
import json as _json

action_data = _json.loads(json_match.group())
atype = action_data.get("action_type", "").lower()
coord = action_data.get("coordinate", action_data.get("coords", []))
Expand All @@ -205,9 +206,7 @@ class BenchmarkAction: # type: ignore[no-redef]
x_val, y_val = x_val * width, y_val * height
return BenchmarkAction(type="click", x=int(x_val), y=int(y_val))
if atype == "type":
return BenchmarkAction(
type="type", text=action_data.get("text", "")
)
return BenchmarkAction(type="type", text=action_data.get("text", ""))
if atype in ("done", "wait"):
return BenchmarkAction(type=atype)
except Exception:
Expand Down
53 changes: 53 additions & 0 deletions openadapt_ml/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,59 @@ def get_current_job_directory(base_dir: str | Path) -> Path | None:
return None


def update_current_symlink_to_latest(
base_dir: str | Path = "training_output",
) -> Path | None:
"""Point the 'current' symlink at the most recent job directory.

Scans base_dir for job directories (any real subdirectory other than
the 'current' symlink itself) and atomically updates the symlink to
the most recently modified one.

Args:
base_dir: Base output directory containing job directories.

Returns:
Path to the latest job directory, or None if none exist.
"""
base_dir = Path(base_dir)
if not base_dir.is_dir():
return None

job_dirs = [
d
for d in base_dir.iterdir()
if d.is_dir() and not d.is_symlink() and not d.name.startswith(".")
]
if not job_dirs:
return None

# Prefer directories that look like training runs over stray dirs
# (e.g. a top-level "checkpoints" directory from the old flat layout).
run_like = [
d
for d in job_dirs
if (d / "training_log.json").exists() or (d / "dashboard.html").exists()
]
candidates = run_like or job_dirs

latest = max(candidates, key=lambda d: d.stat().st_mtime)

current_link = base_dir / "current"
temp_link = base_dir / f".current_temp_{latest.name}"
try:
if temp_link.exists() or temp_link.is_symlink():
temp_link.unlink()
temp_link.symlink_to(latest.name)
temp_link.rename(current_link)
except Exception as e:
if temp_link.exists() or temp_link.is_symlink():
temp_link.unlink()
raise RuntimeError(f"Failed to update current symlink: {e}")

return latest


@dataclass
class TrainingConfig:
# Model / LoRA-related fields are handled elsewhere; this covers loop hyperparams.
Expand Down
10 changes: 10 additions & 0 deletions openadapt_ml/training/trl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,13 @@ def _run_sft_training(
callbacks=[callback],
)
else:
import torch

has_cuda = torch.cuda.is_available()
has_mps = (
getattr(torch.backends, "mps", None) is not None
and torch.backends.mps.is_available()
)
training_args = SFTConfig(
output_dir=config.output_dir,
per_device_train_batch_size=config.batch_size,
Expand All @@ -423,6 +430,9 @@ def _run_sft_training(
save_strategy=config.save_strategy,
max_length=None, # Critical for VLMs
assistant_only_loss=False, # Not supported for VL models yet
use_cpu=not (has_cuda or has_mps),
bf16=has_cuda and torch.cuda.is_bf16_supported(),
fp16=False,
)

trainer = SFTTrainer(
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["openadapt_ml"]

[tool.hatch.build.targets.wheel.force-include]
"configs" = "openadapt_ml/configs"

[tool.uv.sources]
openadapt-capture = { path = "../openadapt-capture", editable = true }

Expand Down
Loading