Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions swift/arguments/base_args/base_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Any, Dict, List, Literal, Optional, Union

import swift
from swift.dataset import load_dataset
from swift.hub import get_hub
from swift.model import get_ckpt_dir, get_model_processor, load_by_unsloth
from swift.ray_utils import RayArguments
Expand Down Expand Up @@ -342,3 +343,20 @@ def get_model_processor(self,
res['num_labels'] = num_labels or self.num_labels

return get_model_processor(**res)

def load_dataset(self):
    """Load the training and validation datasets configured on these arguments.

    ``self.dataset`` (when non-empty) is loaded with ``self.split_dataset_ratio``
    and ``self.dataset_shuffle``. When ``self.val_dataset`` is non-empty it is
    loaded separately and its result is used as the validation dataset,
    overriding whatever validation split the first load returned.

    Returns:
        A ``(train_dataset, val_dataset)`` tuple. An element is ``None`` when
        the corresponding dataset was not configured.

    Raises:
        AssertionError: If ``self.val_dataset`` is given while
            ``self.split_dataset_ratio`` is non-zero.
    """
    dataset_kwargs = self.get_dataset_kwargs()
    train_dataset, val_dataset = None, None

    # Truthiness instead of ``len(...) > 0`` so that ``val_dataset=None`` is
    # treated the same as an empty list rather than raising ``TypeError``.
    has_val_dataset = bool(self.val_dataset)
    if has_val_dataset:
        # Validate before any loading so a misconfiguration fails fast instead
        # of after both datasets have already been loaded.
        assert self.split_dataset_ratio == 0., (
            'split_dataset_ratio must be 0 when val_dataset is provided, '
            f'got {self.split_dataset_ratio}')

    if self.dataset:
        train_dataset, val_dataset = load_dataset(
            self.dataset,
            split_dataset_ratio=self.split_dataset_ratio,
            shuffle=self.dataset_shuffle,
            **dataset_kwargs)

    if has_val_dataset:
        # Loading val dataset
        # NOTE(review): 'interleave_prob' is dropped for the validation load;
        # presumably it only applies to the training mixture - confirm.
        dataset_kwargs.pop('interleave_prob', None)
        _, val_dataset = load_dataset(
            self.val_dataset,
            split_dataset_ratio=1.0,
            shuffle=self.val_dataset_shuffle,
            **dataset_kwargs)

    return train_dataset, val_dataset
17 changes: 2 additions & 15 deletions swift/pipelines/train/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from swift.arguments import SftArguments
from swift.dataset import (AddLengthPreprocessor, DatasetLoader, EncodePreprocessor, IterablePackingDataset,
LazyLLMDataset, PackingDataset, load_dataset)
LazyLLMDataset, PackingDataset)
from swift.infer_engine import prepare_generation_config
from swift.ray_utils import RayHelper
from swift.sequence_parallel import sequence_parallel
Expand Down Expand Up @@ -78,20 +78,7 @@ def _prepare_template(self) -> None:
def _get_dataset(self):
# The random shuffling of the training set occurs in the dataloader of the trainer.
args = self.args
dataset_kwargs = args.get_dataset_kwargs()
train_dataset, val_dataset = None, None
if args.dataset:
train_dataset, val_dataset = load_dataset(
args.dataset,
split_dataset_ratio=args.split_dataset_ratio,
shuffle=args.dataset_shuffle,
**dataset_kwargs)
if len(args.val_dataset) > 0:
# Loading val dataset
dataset_kwargs.pop('interleave_prob', None)
_, val_dataset = load_dataset(
args.val_dataset, split_dataset_ratio=1.0, shuffle=args.val_dataset_shuffle, **dataset_kwargs)
assert args.split_dataset_ratio == 0.
train_dataset, val_dataset = args.load_dataset()
if args.truncation_strategy == 'split':
logger.info(f'train_dataset: {train_dataset}')
logger.info(f'val_dataset: {val_dataset}')
Expand Down
60 changes: 8 additions & 52 deletions swift/ray/megatron/driver_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
* extracting the canonical iteration from worker results
"""
import json
import os
import torch
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List

from swift.utils.logger import get_logger
from swift.arguments import BaseArguments
from swift.utils import get_logger, parse_args, seed_everything, to_abspath

logger = get_logger()

Expand All @@ -29,8 +31,6 @@

def parse_args_from_dict(class_type, cfg: Dict[str, Any]):
"""Construct a dataclass from a config dict via HfArgumentParser."""
from swift.utils import parse_args

argv = _dict_to_argv(cfg)
args, remaining_args = parse_args(class_type, argv)
if remaining_args:
Expand Down Expand Up @@ -177,12 +177,8 @@ def build_dataset_from_dict(cfg: Dict[str, Any]):
``get_template``, ``get_dataset_kwargs``) so no distributed init
or Megatron-specific ``__post_init__`` logic is triggered.
"""
import os

from swift.megatron.arguments import MegatronRLHFArguments
from swift.rlhf_trainers.utils import identity_data_collator
from swift.utils import seed_everything, to_abspath

cfg = dict(cfg)
cfg['skip_megatron_init'] = True
args = parse_args_from_dict(MegatronRLHFArguments, cfg)
Expand All @@ -202,7 +198,7 @@ def build_dataset_from_dict(cfg: Dict[str, Any]):
_, processor = args.get_model_processor(load_model=False, download_model=args.mcore_model is None)

template = _prepare_template(args, processor)
train_dataset, val_dataset = _prepare_dataset(args, template)
train_dataset, val_dataset = _prepare_dataset(args)
Comment thread
Jintao-Huang marked this conversation as resolved.

data_collator = identity_data_collator if rlhf_type in ('grpo', 'gkd') else template.data_collator

Expand Down Expand Up @@ -233,10 +229,8 @@ def _prepare_template(args, processor):
return template


def _prepare_dataset(args, template):
def _prepare_dataset(args: BaseArguments):
Comment thread
Jintao-Huang marked this conversation as resolved.
"""Load and optionally encode dataset — no pipeline object needed."""
from swift.dataset import DatasetLoader, load_dataset

# Ray pipeline has no validation/eval loop yet
if args.split_dataset_ratio and args.split_dataset_ratio > 0:
logger.warning(
Expand All @@ -247,46 +241,8 @@ def _prepare_dataset(args, template):
logger.warning('Ray pipeline has no validation loop yet; ignoring val_dataset=%s.', args.val_dataset)
args.val_dataset = []

pre_process = args.rlhf_type not in ('grpo', 'gkd')
train_datasets, val_datasets = [], []

if args.dataset or args.val_dataset:
dataset_kwargs = args.get_dataset_kwargs()
train_dataset, val_dataset = None, None
if args.dataset:
train_dataset, val_dataset = load_dataset(
args.dataset,
split_dataset_ratio=args.split_dataset_ratio,
shuffle=args.dataset_shuffle,
**dataset_kwargs)
if args.val_dataset:
_, val_dataset = load_dataset(
args.val_dataset, split_dataset_ratio=1.0, shuffle=args.val_dataset_shuffle, **dataset_kwargs)

if not pre_process:
return train_dataset, val_dataset

from swift.dataset import AddLengthPreprocessor
for i, ds in enumerate([train_dataset, val_dataset]):
if ds is None:
continue
if not args.lazy_tokenize and not args.streaming:
preprocessor = AddLengthPreprocessor(template=template)
batch_size = 100 if args.model_meta.is_multimodal else 1000
ds = preprocessor(
ds,
num_proc=args.dataset_num_proc,
load_from_cache_file=args.load_from_cache_file,
strict=args.strict,
batch_size=batch_size)
if i == 0:
train_datasets.append(ds)
else:
val_datasets.append(ds)

train_dataset = DatasetLoader.concat_datasets(train_datasets) if train_datasets else None
val_dataset = DatasetLoader.concat_datasets(val_datasets) if val_datasets else None
return train_dataset, val_dataset
assert args.rlhf_type in ('grpo', 'gkd')
return args.load_dataset()
Comment thread
Jintao-Huang marked this conversation as resolved.


def compute_iter_params(data_info: Dict[str, Any], dp_size: int) -> Dict[str, Any]:
Expand Down
6 changes: 3 additions & 3 deletions swift/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,6 @@
unwrap_model_for_generation)
from .utils import (add_version_to_work_dir, check_json_format, copy_files_by_pattern, deep_getattr, find_free_port,
find_node_ip, format_time, get_env_args, import_external_file, json_parse_to_dict, lower_bound,
parse_args, patch_getattr, read_multi_line, remove_response, retry_decorator, seed_everything,
shutdown_event_loop_in_daemon, split_list, start_event_loop_in_daemon, subprocess_run, test_time,
to_abspath, upper_bound)
parse_args, parse_args_from_dict, patch_getattr, read_multi_line, remove_response, retry_decorator,
seed_everything, shutdown_event_loop_in_daemon, split_list, start_event_loop_in_daemon,
subprocess_run, test_time, to_abspath, upper_bound)
6 changes: 6 additions & 0 deletions swift/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,12 @@ def parse_args(class_type: Type[_T], argv: Optional[List[str]] = None) -> Tuple[
return args, remaining_args


def parse_args_from_dict(class_type: Type[_T], args: Dict[str, Any]) -> _T:
    """Build a ``class_type`` instance from a plain dict using ``HfArgumentParser``.

    Parsing runs inside ``_patch_get_type_hints()``. ``allow_extra_keys=True``
    is passed to ``parse_dict``, so keys in ``args`` that the parser does not
    consume are not reported as an error.

    Args:
        class_type: The class handed to ``HfArgumentParser``.
        args: Mapping of argument names to values.

    Returns:
        The first (and only requested) object produced by ``parse_dict``.
    """
    with _patch_get_type_hints():
        dict_parser = HfArgumentParser([class_type])
        instance = dict_parser.parse_dict(args, allow_extra_keys=True)[0]
    return instance


def lower_bound(lo: int, hi: int, cond: Callable[[int], bool]) -> int:
# The lower bound satisfying the condition "cond".
while lo < hi:
Expand Down
Loading