Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions example/common/utils.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#include "example/common/utils.h"

#include "gflags/gflags.h"
#include "glog/logging.h"

namespace infini_train {

float ConvertBF16ToFloat(void *ptr) {
Expand Down Expand Up @@ -61,4 +64,11 @@ void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t s
ifs.seekg(base + std::streamoff(len * sizeof(float)));
}

void ValidateDistributedOptimizerFlags(bool use_distributed_optimizer) {
gflags::CommandLineFlagInfo zero_stage_info;
CHECK(gflags::GetCommandLineFlagInfo("zero_stage", &zero_stage_info));
CHECK(use_distributed_optimizer || zero_stage_info.is_default)
<< "--zero_stage requires --use_distributed_optimizer=true.";
}

} // namespace infini_train
2 changes: 2 additions & 0 deletions example/common/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,6 @@ void ReadVectorAllFloat(std::ifstream &ifs, float *dst, int64_t len);

void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t start, int64_t cnt);

void ValidateDistributedOptimizerFlags(bool use_distributed_optimizer);

} // namespace infini_train
11 changes: 8 additions & 3 deletions example/gpt2/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

#include "example/common/tiny_shakespeare_dataset.h"
#include "example/common/tokenizer.h"
#include "example/common/utils.h"
#include "example/gpt2/checkpoint_loader.h"
#include "example/gpt2/config.h"

Expand All @@ -58,6 +59,7 @@ DEFINE_uint32(text_length, 64, "the length of the generated text");
// optimization
DEFINE_double(learning_rate, 1e-4, "learning rate warmup iterations");
DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)");
DEFINE_int32(zero_stage, 1, "ZeRO stage (1/2/3), default 1 (only take effects when use_distributed_optimizer=true)");
// evaluation
DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
Expand Down Expand Up @@ -114,6 +116,7 @@ const std::unordered_map<std::string, nn::TransformerConfig> kModelToConfigs = {
DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); });
DEFINE_validator(device,
[](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; });
DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 1 && value <= 3; });

void Train(const nn::parallel::Rank &rank) {
using namespace nn::parallel;
Expand Down Expand Up @@ -252,8 +255,8 @@ void Train(const nn::parallel::Rank &rank) {
model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, num_micro_batches, shapes,
pp_rank, device, model_config.GetChunkSize());
if (ddp_world_size > 1) {
auto ddp_config
= DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
auto ddp_config = DistributedDataParallelConfig{
.use_distributed_optimizer = FLAGS_use_distributed_optimizer, .zero_stage = FLAGS_zero_stage};
auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
(*mutable_chunks)[chunk_id]
Expand All @@ -265,7 +268,8 @@ void Train(const nn::parallel::Rank &rank) {
// before wrapping the model with DistributedDataParallel (DDP).
// Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
// are created during the conversion.
auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer,
.zero_stage = FLAGS_zero_stage};
model = std::make_shared<DistributedDataParallel>(model, rank, ddp_config);
}

Expand Down Expand Up @@ -447,6 +451,7 @@ void Train(const nn::parallel::Rank &rank) {
int main(int argc, char *argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
ValidateDistributedOptimizerFlags(FLAGS_use_distributed_optimizer);

auto precision_config = utils::PrecisionCheckConfig::Parse(FLAGS_precision_check);
nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel,
Expand Down
11 changes: 8 additions & 3 deletions example/llama3/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

#include "example/common/tiny_shakespeare_dataset.h"
#include "example/common/tokenizer.h"
#include "example/common/utils.h"
#include "example/llama3/checkpoint_loader.h"
#include "example/llama3/config.h"

Expand All @@ -57,6 +58,7 @@ DEFINE_uint32(text_length, 64, "the length of the generated text");
// optimization
DEFINE_double(learning_rate, 1e-5, "learning rate warmup iterations");
DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)");
DEFINE_int32(zero_stage, 1, "ZeRO stage (1/2/3), default 1 (only take effects when use_distributed_optimizer=true)");
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

那应该加个检查,如果没开 use_distributed_optimizer但设置了zero_stage直接报错,而不是静默

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

确实是,我加上

// evaluation
DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
Expand Down Expand Up @@ -100,6 +102,7 @@ constexpr char kDtypeBF16[] = "bfloat16";
DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); });
DEFINE_validator(device,
[](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; });
DEFINE_validator(zero_stage, [](const char *, int32_t value) { return value >= 1 && value <= 3; });

void Train(const nn::parallel::Rank &rank) {
using namespace nn::parallel;
Expand Down Expand Up @@ -222,8 +225,8 @@ void Train(const nn::parallel::Rank &rank) {
model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, num_micro_batches, shapes,
pp_rank, device, model_config.GetChunkSize());
if (ddp_world_size > 1) {
auto ddp_config
= DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
auto ddp_config = DistributedDataParallelConfig{
.use_distributed_optimizer = FLAGS_use_distributed_optimizer, .zero_stage = FLAGS_zero_stage};
auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
(*mutable_chunks)[chunk_id]
Expand All @@ -236,7 +239,8 @@ void Train(const nn::parallel::Rank &rank) {
// Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
// are created during the conversion.

auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer,
.zero_stage = FLAGS_zero_stage};
model = std::make_shared<DistributedDataParallel>(model, rank, ddp_config);
}

Expand Down Expand Up @@ -422,6 +426,7 @@ void Train(const nn::parallel::Rank &rank) {
int main(int argc, char *argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
ValidateDistributedOptimizerFlags(FLAGS_use_distributed_optimizer);

auto precision_config = utils::PrecisionCheckConfig::Parse(FLAGS_precision_check);
nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel,
Expand Down
6 changes: 6 additions & 0 deletions infini_train/include/autograd/function_hook.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ namespace infini_train::autograd {
class PostAccumulateGradHook {
public:
virtual void operator()(const std::shared_ptr<Tensor> &tensor) = 0;

// ZeRO-2: Use this function to take over AccumulateGrad::Backward
virtual bool TryBypassAccumulate(const std::shared_ptr<Tensor> &, const std::shared_ptr<Tensor> &, bool, float) {
return false;
}

virtual ~PostAccumulateGradHook() = default;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ class DistributedDataParallelConfig {
// In this case, grad reduce is triggered immediately when a grad is ready or till all grads are ready.
bool overlap_grad_reduce = true;

// ZeRO-DP Stage for memory optimization (Only take effects when use_distributed_optimizer=true)
// ZeRO-1: Optimizer states partitioning, by default
// ZeRO-2: Gradients partitioning
// ZeRO-3: Parameters partitioning
int zero_stage = 1;

// Whether to overlap parameter all-gather with forward compute.
bool overlap_param_gather = true;

Expand All @@ -59,7 +65,7 @@ class DistributedDataParallelConfig {
// Maximum number of parameters in each ParamAndGradBucket.
// NOTE(zbl): This is distinct from DDP Reducer's MB-based bucket caps.
// TODO(zbl): To unify the definition of bucket_size argument for users
size_t bucket_size_in_elements = 40000000;
size_t bucket_size_in_elements = 1000000;

// Whether to pad bucket sizes to improve NCCL bus bandwidth utilization.
bool pad_buckets_for_high_nccl_busbw = false;
Expand Down
25 changes: 23 additions & 2 deletions infini_train/include/nn/parallel/ddp/param_and_grad_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ namespace infini_train::nn::parallel {
class ParamAndGradBucket {
public:
ParamAndGradBucket(const std::vector<std::shared_ptr<Tensor>> &params, const std::shared_ptr<Tensor> &param_data,
const std::shared_ptr<Tensor> &grad_data, size_t offset, size_t num_elements_unpadded,
float gradient_scaling_factor, size_t bucket_id);
DataType param_dtype, const std::shared_ptr<Tensor> &grad_data, DataType grad_dtype,
size_t offset, size_t num_elements_unpadded, float gradient_scaling_factor, size_t bucket_id);

size_t bucket_id() const { return bucket_id_; }

Expand All @@ -33,6 +33,10 @@ class ParamAndGradBucket {

const std::shared_ptr<Tensor> &grad_data() const { return grad_data_; }

DataType param_dtype() const { return param_dtype_; }

DataType grad_dtype() const { return grad_dtype_; }

size_t offset() const { return offset_; }

size_t num_elements_unpadded() const { return num_elements_unpadded_; }
Expand All @@ -49,6 +53,8 @@ class ParamAndGradBucket {
std::vector<std::shared_ptr<Tensor>> params_;
std::shared_ptr<Tensor> param_data_;
std::shared_ptr<Tensor> grad_data_;
DataType param_dtype_;
DataType grad_dtype_;

size_t offset_ = 0;
size_t num_elements_unpadded_ = 0;
Expand All @@ -73,6 +79,11 @@ class ParamAndGradBucketGroup {
// Start grad reduce
void StartGradSync();

// Accumulate a parameter grad into bucket buffer
// ZeRO-2: Use this funtion to take over autograd::AccumulateGrad::Backward
void AccumulateParamGrad(const std::shared_ptr<Tensor> &parameter, const std::shared_ptr<Tensor> &grad,
bool overwrite, float learning_rate);

// Wait for gradient reduce to complete
void FinishGradSync();

Expand All @@ -87,6 +98,9 @@ class ParamAndGradBucketGroup {

const std::vector<std::shared_ptr<ParamAndGradBucket>> &buckets() const { return buckets_; }

// ZeRO-2: Get a bucket's local grad shard buffer
std::shared_ptr<Tensor> GetLocalGradShardBuffer(size_t bucket_idx) const;

const DistributedDataParallelConfig &config() const { return ddp_config_; }

private:
Expand All @@ -98,12 +112,19 @@ class ParamAndGradBucketGroup {

std::unordered_set<Tensor *> params_;
std::unordered_set<Tensor *> params_with_grad_;
// Tensor -> (Bucket, Bucket Index)
std::unordered_map<Tensor *, std::pair<std::shared_ptr<ParamAndGradBucket>, size_t>> param_to_bucket_;

// TODO(zbl): Implement CoalescedWork for aggregate works
// According to Megatron-LM's _coalescing_manager
std::vector<std::shared_ptr<Work>> grad_reduce_work_list_;
std::vector<size_t> grad_reduce_bucket_indices_;
std::vector<std::shared_ptr<Work>> param_gather_work_list_;

// ZeRO-2: persistent grad shard buffers and temporary full grad buffers
std::vector<std::shared_ptr<Tensor>> grad_shard_buffer_list_;
std::vector<std::shared_ptr<Tensor>> temp_full_grad_buffer_list_;

std::shared_ptr<ParamAndGradBucketGroup> next_param_gather_bucket_group_ = nullptr;

std::vector<std::vector<std::shared_ptr<Tensor>>> param_buffer_shard_list_;
Expand Down
10 changes: 8 additions & 2 deletions infini_train/src/autograd/accumulate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,15 @@ AccumulateGrad::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_output
"running before autograd). The grad is not cast and will be used as-is.";
}

const bool overwrite = tensor_->ConsumeGradOverwriteFlag();
auto hook = tensor_->post_accumulate_grad_hook();
if (hook && hook->TryBypassAccumulate(tensor_, grad_output, overwrite, learning_rate_)) {
tensor_->ResetAccumulator();
return {};
}

if (grad) {
if (tensor_->ConsumeGradOverwriteFlag()) {
if (overwrite) {
// If the tensor is marked to overrite its current grad on next grad update
// See notes in `infini_train::nn::parallel::Reducer::PrepareForBackward()`
// NOTE(zbl): must copy, cannot change grad buffer address
Expand All @@ -48,7 +55,6 @@ AccumulateGrad::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_output
auto new_grad = std::make_shared<Tensor>(*grad_output.get(), 0, grad_output->Dims());
tensor_->set_grad(new_grad);
}
auto hook = tensor_->post_accumulate_grad_hook();
if (hook != nullptr) {
(*hook)(tensor_->grad());
}
Expand Down
54 changes: 53 additions & 1 deletion infini_train/src/nn/parallel/ddp/distributed_data_parallel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ DistributedDataParallel::DistributedDataParallel(std::shared_ptr<nn::Module> mod
const DistributedDataParallelConfig ddp_config)
: ddp_config_(ddp_config),
ddp_pg_(ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(rank.GlobalRank()))) {
CHECK(ddp_config_.zero_stage >= 1 && ddp_config_.zero_stage <= 3)
<< "DistributedDataParallel: zero_stage must be in 1/2/3.";
if (ddp_config_.zero_stage >= 3) {
LOG(FATAL) << "DistributedDataParallel: ZeRO-3 is not implemented yet.";
}
if (!ddp_config_.use_distributed_optimizer && ddp_config_.zero_stage >= 1) {
LOG(WARNING) << "DistributedDataParallel: zero_stage is ignored because "
"use_distributed_optimizer is false.";
ddp_config_.zero_stage = 1;
}
for (auto &param : module->Parameters()) {
if (!param->requires_grad()) {
continue;
Expand Down Expand Up @@ -83,6 +93,7 @@ void DistributedDataParallel::BuildParamAndGradBuffers() {
continue;
}

// At the point, zero_stage is already aligned with use_distributed_optimizer.
auto buffer = std::make_shared<ParamAndGradBuffer>(param_list, param_dtype, grad_dtype, ddp_pg_, ddp_config_);

param_grad_buffers_.push_back(buffer);
Expand Down Expand Up @@ -116,6 +127,47 @@ void DistributedDataParallel::BuildParamAndGradBuffers() {
}

void DistributedDataParallel::RegisterBackwardHooks() {
if (ddp_config_.zero_stage >= 2) {
// NOTE(zbl): ZeRO-2 bypasses Tensor::grad accumulation: stash grads in the bucket group's
// temporary full-grad buffer, then mark the bucket ready for reduce-scatter.
class Zero2AccumulateGradHook final : public autograd::PostAccumulateGradHook {
public:
explicit Zero2AccumulateGradHook(std::weak_ptr<ParamAndGradBucketGroup> group) : group_(std::move(group)) {}

bool TryBypassAccumulate(const std::shared_ptr<Tensor> &param, const std::shared_ptr<Tensor> &grad_output,
bool overwrite, float learning_rate) override {
if (auto group = group_.lock(); group) {
group->AccumulateParamGrad(param, grad_output, overwrite, learning_rate);
if (group->config().overlap_grad_reduce) {
group->RegisterGradReady(param);
}
return true;
}
return false;
}

void operator()(const std::shared_ptr<Tensor> &) override {}

private:
std::weak_ptr<ParamAndGradBucketGroup> group_;
};

auto &module = modules_.at(kModuleName);
for (auto &param : module->Parameters()) {
if (!param->requires_grad()) {
continue;
}
auto it = param_to_bucket_group_.find(param.get());
if (it == param_to_bucket_group_.end()) {
continue;
}
std::weak_ptr<ParamAndGradBucketGroup> weak_group = it->second;
auto hook = std::make_unique<Zero2AccumulateGradHook>(weak_group);
param->RegisterPostAccumulateGradHook(std::move(hook));
}
return;
}

class DDPPostAccumulateHook final : public autograd::PostAccumulateGradHook {
public:
DDPPostAccumulateHook(DistributedDataParallel *ddp, const std::weak_ptr<Tensor> param)
Expand Down Expand Up @@ -147,7 +199,7 @@ void DistributedDataParallel::OnGradReady(const std::shared_ptr<Tensor> &param)
auto it = param_to_bucket_group_.find(param.get());
if (it != param_to_bucket_group_.end()) {
CHECK(param->requires_grad());
if (ddp_config_.overlap_grad_reduce) {
if (ddp_config_.overlap_grad_reduce && (ddp_config_.zero_stage < 2)) {
CHECK(param->grad()) << "param.grad being None is not safe when overlap_grad_reduce is True";
}

Expand Down
Loading
Loading