From 5b965d6212c4f7cdf37b444f102d4fa504b2d734 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 12 Jun 2026 11:34:06 +0800 Subject: [PATCH 1/3] update dataset --- swift/arguments/base_args/base_args.py | 18 ++++++++++++++++++ swift/pipelines/train/sft.py | 17 ++--------------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/swift/arguments/base_args/base_args.py b/swift/arguments/base_args/base_args.py index 423b400587..158726e3cb 100644 --- a/swift/arguments/base_args/base_args.py +++ b/swift/arguments/base_args/base_args.py @@ -12,6 +12,7 @@ from swift.model import get_ckpt_dir, get_model_processor, load_by_unsloth from swift.ray_utils import RayArguments from swift.template import Template, get_template +from swift.dataset import load_dataset from swift.tuner_plugin import tuners_map from swift.utils import (Processor, check_json_format, get_dist_setting, get_logger, import_external_file, is_dist, is_master, json_parse_to_dict, safe_snapshot_download, set_device, use_hf_hub) @@ -342,3 +343,20 @@ def get_model_processor(self, res['num_labels'] = num_labels or self.num_labels return get_model_processor(**res) + + def load_dataset(self): + dataset_kwargs = self.get_dataset_kwargs() + train_dataset, val_dataset = None, None + if self.dataset: + train_dataset, val_dataset = load_dataset( + self.dataset, + split_dataset_ratio=self.split_dataset_ratio, + shuffle=self.dataset_shuffle, + **dataset_kwargs) + if len(self.val_dataset) > 0: + # Loading val dataset + dataset_kwargs.pop('interleave_prob', None) + _, val_dataset = load_dataset( + self.val_dataset, split_dataset_ratio=1.0, shuffle=self.val_dataset_shuffle, **dataset_kwargs) + assert self.split_dataset_ratio == 0. + return train_dataset, val_dataset diff --git a/swift/pipelines/train/sft.py b/swift/pipelines/train/sft.py index 35c3d407ff..ef8e1f75c7 100644 --- a/swift/pipelines/train/sft.py +++ b/swift/pipelines/train/sft.py @@ -5,7 +5,7 @@ from swift.arguments import SftArguments from swift.dataset import (AddLengthPreprocessor, DatasetLoader, EncodePreprocessor, IterablePackingDataset, - LazyLLMDataset, PackingDataset, load_dataset) + LazyLLMDataset, PackingDataset) from swift.infer_engine import prepare_generation_config from swift.ray_utils import RayHelper from swift.sequence_parallel import sequence_parallel @@ -78,20 +78,7 @@ def _prepare_template(self) -> None: def _get_dataset(self): # The random shuffling of the training set occurs in the dataloader of the trainer. args = self.args - dataset_kwargs = args.get_dataset_kwargs() - train_dataset, val_dataset = None, None - if args.dataset: - train_dataset, val_dataset = load_dataset( - args.dataset, - split_dataset_ratio=args.split_dataset_ratio, - shuffle=args.dataset_shuffle, - **dataset_kwargs) - if len(args.val_dataset) > 0: - # Loading val dataset - dataset_kwargs.pop('interleave_prob', None) - _, val_dataset = load_dataset( - args.val_dataset, split_dataset_ratio=1.0, shuffle=args.val_dataset_shuffle, **dataset_kwargs) - assert args.split_dataset_ratio == 0. + train_dataset, val_dataset = args.load_dataset() if args.truncation_strategy == 'split': logger.info(f'train_dataset: {train_dataset}') logger.info(f'val_dataset: {val_dataset}') From 4a9c3053a5edb0656cc3509c134a1f153a27e772 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 12 Jun 2026 11:39:35 +0800 Subject: [PATCH 2/3] update --- swift/arguments/base_args/base_args.py | 2 +- swift/ray/megatron/driver_utils.py | 57 ++++---------------------- 2 files changed, 8 insertions(+), 51 deletions(-) diff --git a/swift/arguments/base_args/base_args.py b/swift/arguments/base_args/base_args.py index 158726e3cb..ee430dcea4 100644 --- a/swift/arguments/base_args/base_args.py +++ b/swift/arguments/base_args/base_args.py @@ -8,11 +8,11 @@ from typing import Any, Dict, List, Literal, Optional, Union import swift +from swift.dataset import load_dataset from swift.hub import get_hub from swift.model import get_ckpt_dir, get_model_processor, load_by_unsloth from swift.ray_utils import RayArguments from swift.template import Template, get_template -from swift.dataset import load_dataset from swift.tuner_plugin import tuners_map from swift.utils import (Processor, check_json_format, get_dist_setting, get_logger, import_external_file, is_dist, is_master, json_parse_to_dict, safe_snapshot_download, set_device, use_hf_hub) diff --git a/swift/ray/megatron/driver_utils.py b/swift/ray/megatron/driver_utils.py index b82369fba7..a7f5ddc89a 100644 --- a/swift/ray/megatron/driver_utils.py +++ b/swift/ray/megatron/driver_utils.py @@ -12,9 +12,10 @@ import json import torch from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List -from swift.utils.logger import get_logger +from swift.arguments import BaseArguments +from swift.utils import get_logger, parse_args, seed_everything, to_abspath logger = get_logger() @@ -29,8 +30,6 @@ def parse_args_from_dict(class_type, cfg: Dict[str, Any]): """Construct a dataclass from a config dict via HfArgumentParser.""" - from swift.utils import parse_args - argv = _dict_to_argv(cfg) args, remaining_args = parse_args(class_type, argv) if remaining_args: @@ -181,8 +180,6 @@ def build_dataset_from_dict(cfg: Dict[str, Any]): from swift.megatron.arguments import MegatronRLHFArguments from swift.rlhf_trainers.utils import identity_data_collator - from swift.utils import seed_everything, to_abspath - cfg = dict(cfg) cfg['skip_megatron_init'] = True args = parse_args_from_dict(MegatronRLHFArguments, cfg) @@ -202,7 +199,7 @@ def build_dataset_from_dict(cfg: Dict[str, Any]): _, processor = args.get_model_processor(load_model=False, download_model=args.mcore_model is None) template = _prepare_template(args, processor) - train_dataset, val_dataset = _prepare_dataset(args, template) + train_dataset, val_dataset = _prepare_dataset(args) data_collator = identity_data_collator if rlhf_type in ('grpo', 'gkd') else template.data_collator @@ -233,10 +230,8 @@ def _prepare_template(args, processor): return template -def _prepare_dataset(args, template): +def _prepare_dataset(args: BaseArguments): """Load and optionally encode dataset — no pipeline object needed.""" - from swift.dataset import DatasetLoader, load_dataset - # Ray pipeline has no validation/eval loop yet if args.split_dataset_ratio and args.split_dataset_ratio > 0: logger.warning( @@ -247,46 +242,8 @@ def _prepare_dataset(args, template): logger.warning('Ray pipeline has no validation loop yet; ignoring val_dataset=%s.', args.val_dataset) args.val_dataset = [] - pre_process = args.rlhf_type not in ('grpo', 'gkd') - train_datasets, val_datasets = [], [] - - if args.dataset or args.val_dataset: - dataset_kwargs = args.get_dataset_kwargs() - train_dataset, val_dataset = None, None - if args.dataset: - train_dataset, val_dataset = load_dataset( - args.dataset, - split_dataset_ratio=args.split_dataset_ratio, - shuffle=args.dataset_shuffle, - **dataset_kwargs) - if args.val_dataset: - _, val_dataset = load_dataset( - args.val_dataset, split_dataset_ratio=1.0, shuffle=args.val_dataset_shuffle, **dataset_kwargs) - - if not pre_process: - return train_dataset, val_dataset - - from swift.dataset import AddLengthPreprocessor - for i, ds in enumerate([train_dataset, val_dataset]): - if ds is None: - continue - if not args.lazy_tokenize and not args.streaming: - preprocessor = AddLengthPreprocessor(template=template) - batch_size = 100 if args.model_meta.is_multimodal else 1000 - ds = preprocessor( - ds, - num_proc=args.dataset_num_proc, - load_from_cache_file=args.load_from_cache_file, - strict=args.strict, - batch_size=batch_size) - if i == 0: - train_datasets.append(ds) - else: - val_datasets.append(ds) - - train_dataset = DatasetLoader.concat_datasets(train_datasets) if train_datasets else None - val_dataset = DatasetLoader.concat_datasets(val_datasets) if val_datasets else None - return train_dataset, val_dataset + assert args.rlhf_type in ('grpo', 'gkd') + return args.load_dataset() def compute_iter_params(data_info: Dict[str, Any], dp_size: int) -> Dict[str, Any]: From 92cbe2e8ce42c3e6534e87d4ad1f485a13192132 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 12 Jun 2026 13:54:37 +0800 Subject: [PATCH 3/3] update --- swift/ray/megatron/driver_utils.py | 3 +-- swift/utils/__init__.py | 6 +++--- swift/utils/utils.py | 6 ++++++ 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/swift/ray/megatron/driver_utils.py b/swift/ray/megatron/driver_utils.py index a7f5ddc89a..b4a7aa0ef3 100644 --- a/swift/ray/megatron/driver_utils.py +++ b/swift/ray/megatron/driver_utils.py @@ -10,6 +10,7 @@ * extracting the canonical iteration from worker results """ import json +import os import torch from dataclasses import dataclass, field from typing import Any, Dict, List @@ -176,8 +177,6 @@ def build_dataset_from_dict(cfg: Dict[str, Any]): ``get_template``, ``get_dataset_kwargs``) so no distributed init or Megatron-specific ``__post_init__`` logic is triggered. """ - import os - from swift.megatron.arguments import MegatronRLHFArguments from swift.rlhf_trainers.utils import identity_data_collator cfg = dict(cfg) diff --git a/swift/utils/__init__.py b/swift/utils/__init__.py index fe19504c5c..37e0a608f1 100644 --- a/swift/utils/__init__.py +++ b/swift/utils/__init__.py @@ -28,6 +28,6 @@ unwrap_model_for_generation) from .utils import (add_version_to_work_dir, check_json_format, copy_files_by_pattern, deep_getattr, find_free_port, find_node_ip, format_time, get_env_args, import_external_file, json_parse_to_dict, lower_bound, - parse_args, patch_getattr, read_multi_line, remove_response, retry_decorator, seed_everything, - shutdown_event_loop_in_daemon, split_list, start_event_loop_in_daemon, subprocess_run, test_time, - to_abspath, upper_bound) + parse_args, parse_args_from_dict, patch_getattr, read_multi_line, remove_response, retry_decorator, + seed_everything, shutdown_event_loop_in_daemon, split_list, start_event_loop_in_daemon, + subprocess_run, test_time, to_abspath, upper_bound) diff --git a/swift/utils/utils.py b/swift/utils/utils.py index 7f2e5c66af..f39309cb27 100644 --- a/swift/utils/utils.py +++ b/swift/utils/utils.py @@ -183,6 +183,12 @@ def parse_args(class_type: Type[_T], argv: Optional[List[str]] = None) -> Tuple[ return args, remaining_args +def parse_args_from_dict(class_type: Type[_T], args: Dict[str, Any]) -> _T: + with _patch_get_type_hints(): + parser = HfArgumentParser([class_type]) + return parser.parse_dict(args, allow_extra_keys=True)[0] + + def lower_bound(lo: int, hi: int, cond: Callable[[int], bool]) -> int: # The lower bound satisfying the condition "cond". while lo < hi: