2. deepspeed - Main Entry Point

The top-level entry point of DeepSpeed is deepspeed.__init__::initialize:

def initialize(args=None,
           model: torch.nn.Module = None,
           optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]] = None,
           model_parameters: Optional[torch.nn.Module] = None,
           training_data: Optional[torch.utils.data.Dataset] = None,
           lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]] = None,
           mpu=None,
           dist_init_required: Optional[bool] = None,
           collate_fn=None,
           config=None,
           config_params=None):
"""Initialize the DeepSpeed Engine.

Arguments:
    args: an object containing local_rank and deepspeed_config fields.
        This is optional if `config` is passed.

    model: Required: nn.module class before apply any wrappers

    optimizer: Optional: a user defined Optimizer or Callable that returns an Optimizer object.
        This overrides any optimizer definition in the DeepSpeed json config.

    model_parameters: Optional: An iterable of torch.Tensors or dicts.
        Specifies what Tensors should be optimized.

    training_data: Optional: Dataset of type torch.utils.data.Dataset

    lr_scheduler: Optional: Learning Rate Scheduler Object or a Callable that takes an Optimizer and returns a Scheduler object.
        The scheduler object should define a get_lr(), step(), state_dict(), and load_state_dict() methods

    mpu: Optional: A model parallelism unit object that implements
        get_{model,data}_parallel_{rank,group,world_size}()

    dist_init_required: Optional: None will auto-initialize torch distributed if needed,
        otherwise the user can force it to be initialized or not via boolean.

    collate_fn: Optional: Merges a list of samples to form a
        mini-batch of Tensor(s).  Used when using batched loading from a
        map-style dataset.

    config: Optional: Instead of requiring args.deepspeed_config you can pass your deepspeed config
        as an argument instead, as a path or a dictionary.

    config_params: Optional: Same as `config`, kept for backwards compatibility.

Returns:
    A tuple of ``engine``, ``optimizer``, ``training_dataloader``, ``lr_scheduler``

    * ``engine``: DeepSpeed runtime engine which wraps the client model for distributed training.

    * ``optimizer``: Wrapped optimizer if a user defined ``optimizer`` is supplied, or if
      optimizer is specified in json config else ``None``.

    * ``training_dataloader``: DeepSpeed dataloader if ``training_data`` was supplied,
      otherwise ``None``.

    * ``lr_scheduler``: Wrapped lr scheduler if user ``lr_scheduler`` is passed, or
      if ``lr_scheduler`` specified in JSON configuration. Otherwise ``None``.
"""
log_dist("DeepSpeed info: version={}, git-hash={}, git-branch={}".format(__version__, __git_hash__,
                                                                         __git_branch__),
         ranks=[0])

# Disable zero.Init context if it's currently enabled
zero.partition_parameters.shutdown_init_context()

assert model is not None, "deepspeed.initialize requires a model"

global dist
from deepspeed import comm as dist
# The underlying distributed communication backend offers several choices:
# ['cuda', 'cpu', 'xpu', 'npu', 'mps'], though in most cases we pick cuda.
# It can be selected via the DS_ACCELERATOR environment variable.
dist_backend = get_accelerator().communication_backend_name()
dist.init_distributed(dist_backend=dist_backend, dist_init_required=dist_init_required)

# Set config using config_params for backwards compat
if config is None and config_params is not None:
    config = config_params

# Check for deepscale_config for backwards compat
if hasattr(args, "deepscale_config") and args.deepscale_config is not None:
    logger.warning("************ --deepscale_config is deprecated, please use --deepspeed_config ************")
    if hasattr(args, "deepspeed_config"):
        assert (args.deepspeed_config is
                None), "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config"
    args.deepspeed_config = args.deepscale_config
    args.deepscale_config = None

# Check that we have only one config passed
if hasattr(args, "deepspeed_config") and args.deepspeed_config is not None:
    assert config is None, "Not sure how to proceed, we were given deepspeed configs in the deepspeed arguments and deepspeed.initialize() function call"
    config = args.deepspeed_config
assert config is not None, "DeepSpeed requires --deepspeed_config to specify configuration file"
if not isinstance(model, PipelineModule):
    # Case 1: the model is not a PipelineModule
    config_class = DeepSpeedConfig(config, mpu)
    if config_class.hybrid_engine.enabled:
        # Hybrid Engine: it uses the original DeepSpeed engine for high-throughput
        # training, while seamlessly applying the DeepSpeed inference engine for
        # generation/evaluation, providing a significantly faster training system
        # for stage-3 RLHF training.
        engine = DeepSpeedHybridEngine(args=args,
                                       model=model,
                                       optimizer=optimizer,
                                       model_parameters=model_parameters,
                                       training_data=training_data,
                                       lr_scheduler=lr_scheduler,
                                       mpu=mpu,
                                       dist_init_required=dist_init_required,
                                       collate_fn=collate_fn,
                                       config=config,
                                       config_class=config_class)
    else:
        # The common case ends up here; this is the path we focus on.
        engine = DeepSpeedEngine(args=args,
                                 model=model,
                                 optimizer=optimizer,
                                 model_parameters=model_parameters,
                                 training_data=training_data,
                                 lr_scheduler=lr_scheduler,
                                 mpu=mpu,
                                 dist_init_required=dist_init_required,
                                 collate_fn=collate_fn,
                                 config=config,
                                 config_class=config_class)
else:
    assert mpu is None, "mpu must be None with pipeline parallelism"
    mpu = model.mpu()
    config_class = DeepSpeedConfig(config, mpu)
    # For pipeline parallelism, use PipelineEngine instead
    engine = PipelineEngine(args=args,
                            model=model,
                            optimizer=optimizer,
                            model_parameters=model_parameters,
                            training_data=training_data,
                            lr_scheduler=lr_scheduler,
                            mpu=mpu,
                            dist_init_required=dist_init_required,
                            collate_fn=collate_fn,
                            config=config,
                            config_class=config_class)

# Restore zero.Init context if necessary
zero.partition_parameters.restore_init_context()

return_items = [engine, engine.optimizer, engine.training_dataloader, engine.lr_scheduler]
return tuple(return_items)

This entry point is just a dispatcher that picks one of three engines depending on the situation:

  • Pipeline engine (PipelineEngine): we set this mode aside for now and may analyze it in a later installment.

  • Hybrid engine (DeepSpeedHybridEngine): supports training and inference at the same time, tailored for RLHF training.

  • General mode (DeepSpeedEngine): the most basic mode, the distributed training engine, and the one this analysis focuses on.

DeepSpeedEngine is implemented in deepspeed.runtime.engine. It is itself a subclass of torch.nn.Module, i.e. a wrapper around the client model. A large amount of initialization happens in DeepSpeedEngine.__init__; the most important part is the initialization of the optimizer, since the core ZeRO features are all implemented inside the optimizer.
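For orientation, here is a minimal sketch of a typical call to deepspeed.initialize; the model, batch size and optimizer settings are invented purely for illustration:

import torch
import deepspeed

# Hypothetical model and config, purely illustrative
model = torch.nn.Linear(1024, 1024)

ds_config = {
    "train_batch_size": 16,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
    "fp16": {"enabled": True},
    "zero_optimization": {"stage": 2},
}

# With this config, initialize() dispatches to DeepSpeedEngine (the model is not a
# PipelineModule and hybrid_engine is not enabled), wraps the model, and builds the
# optimizer from the "optimizer" section of the config.
engine, optimizer, _, lr_scheduler = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)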

**DeepSpeedHybridEngine**

The Hybrid Engine uses the original DeepSpeed engine for high-throughput training, while seamlessly applying the DeepSpeed inference engine for generation/evaluation, providing a significantly faster training system for stage-3 RLHF training.
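Which engine gets built is driven by the config: if the hybrid_engine section enables it (the config_class.hybrid_engine.enabled check in initialize() above), a DeepSpeedHybridEngine is constructed instead of a plain DeepSpeedEngine. An illustrative fragment (values invented):

ds_config = {
    "train_batch_size": 16,
    # Makes config_class.hybrid_engine.enabled evaluate to True in initialize(),
    # so DeepSpeedHybridEngine is constructed instead of DeepSpeedEngine.
    "hybrid_engine": {"enabled": True},
}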

2.1. Optimizer Initialization

class DeepSpeedEngine(Module):
r"""DeepSpeed engine for training."""

    def __init__(
            self,
            args,
            model,
            optimizer=None,
            model_parameters=None,
            training_data=None,
            lr_scheduler=None,
            mpu=None,
            dist_init_required=None,
            collate_fn=None,
            config=None,
            config_class=None,
            dont_change_device=False,
    ):
        ...
        # Optimizer initialization
        if has_optimizer:  # an optimizer was passed in, or one is specified in the config file
            self._configure_optimizer(optimizer, model_parameters)
            self._configure_lr_scheduler(lr_scheduler)
            self._report_progress(0)
        elif self.zero_optimization():  # ZeRO is enabled, i.e. zero_optimization_stage > 0
            # no optim selected but zero is enabled
            self.optimizer = self._configure_zero_optimizer(optimizer=None)
        elif self.bfloat16_enabled():  # bf16 mode
            self.optimizer = self._configure_bf16_optimizer(optimizer=None)

        ...

Note that these three branches are mutually exclusive.
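To make the three branches concrete, here are hedged, illustrative config fragments and the branch each would take (the key names follow the DeepSpeed config schema; the values are arbitrary):

# 1) has_optimizer: an optimizer is passed to deepspeed.initialize(optimizer=...) or the
#    config has an "optimizer" section -> self._configure_optimizer(...) runs.
cfg_with_optimizer = {"train_batch_size": 16,
                      "optimizer": {"type": "AdamW", "params": {"lr": 1e-4}}}

# 2) No optimizer, but ZeRO is enabled (zero_optimization_stage > 0)
#    -> self._configure_zero_optimizer(optimizer=None); with stage 3 this ends up
#    building a DeepSpeedZeRoOffload (see 2.1.2).
cfg_zero_only = {"train_batch_size": 16, "zero_optimization": {"stage": 3}}

# 3) No optimizer and no ZeRO, but bf16 is enabled
#    -> self._configure_bf16_optimizer(optimizer=None).
cfg_bf16_only = {"train_batch_size": 16, "bf16": {"enabled": True}}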

2.1.1. The Basic Optimizer

If an optimizer was passed in or specified in the config file, we enter self._configure_optimizer(optimizer, model_parameters). This method first creates the basic optimizer from the argument or the config file, and then wraps it in a higher-level optimizer.

def _configure_optimizer(self, client_optimizer, model_parameters):

    # First, create and initialize the basic optimizer from the argument or the config file
    if client_optimizer is not None:
        # An optimizer was passed in
        if isinstance(client_optimizer, tuple(self._supported_optims())):
            client_optimizer.param_groups[:] = [
                pg for pg in client_optimizer.param_groups if len(pg["params"]) != 0
            ]
            log_dist("Removing param_group that has no 'params' in the client Optimizer", ranks=[0])

            basic_optimizer = client_optimizer
            log_dist('Using client Optimizer as basic optimizer', ranks=[0])
        else:
            basic_optimizer = client_optimizer(model_parameters)
            log_dist('Using client callable to create basic optimizer', ranks=[0])

        if self.zero_use_cpu_optimizer() and not isinstance(basic_optimizer, deepspeed.ops.adam.DeepSpeedCPUAdam):
            if self.zero_force_ds_cpu_optimizer():
                msg = f'You are using ZeRO-Offload with a client provided optimizer ({type(basic_optimizer)}) which in most cases will yield poor performance. Please either use deepspeed.ops.adam.DeepSpeedCPUAdam or set an optimizer in your ds-config (https://www.deepspeed.ai/docs/config-json/#optimizer-parameters). If you really want to use a custom optimizer w. ZeRO-Offload and understand the performance impacts you can also set <"zero_force_ds_cpu_optimizer": false> in your configuration file.'
                raise ZeRORuntimeException(msg)
    else:
        # Create the basic optimizer from the config file
        basic_optimizer = self._configure_basic_optimizer(model_parameters)
        log_dist(f"Using DeepSpeed Optimizer param name {self.optimizer_name()} as basic optimizer", ranks=[0])

    self._check_for_duplicates(basic_optimizer)
    # The raw optimizer: either a torch.optim optimizer or DeepSpeed's modified CPU implementation
    self.basic_optimizer = basic_optimizer
    log_dist("DeepSpeed Basic Optimizer = {}".format(basic_optimizer.__class__.__name__), ranks=[0])
    # optimizer_wrapper:str in ["fp16","bf16","zero_optimization","amp"]
    optimizer_wrapper = self._do_optimizer_sanity_check(basic_optimizer)
    # "fp16","bf16","zero_optimization","amp" 这几个是互斥的,只能选择一个
    if optimizer_wrapper == ZERO_OPTIMIZATION:  # 启用 ZeRO 优化,意味着 stage>0
        self.optimizer = self._configure_zero_optimizer(basic_optimizer)
    elif optimizer_wrapper == AMP:  # automatic mixed precision
        amp_params = self.amp_params()
        log_dist(f"Initializing AMP with these params: {amp_params}", ranks=[0])
        model, self.optimizer = amp.initialize(self.module, basic_optimizer, **amp_params)
        self._set_client_model(model)
        self._broadcast_model()
        # TODO: maybe need to broadcast experts differently?
    elif optimizer_wrapper == FP16:  # FP16 half-precision optimizer
        self.optimizer = self._configure_fp16_optimizer(basic_optimizer)
    elif optimizer_wrapper == BFLOAT16:  # BF16 half-precision optimizer
        self.optimizer = self._configure_bf16_optimizer(basic_optimizer)
    else:  # no wrapper at all
        self.optimizer = basic_optimizer

    log_dist("DeepSpeed Final Optimizer = {}".format(self.optimizer_name()), ranks=[0])

    self.compression_scheduler = self._configure_compression_scheduler()
    self.quantizer = self._configure_quantization()

At its core this method does two things:

  1. Create a raw basic optimizer from the argument or the config file.

  2. Depending on the configuration, wrap the basic optimizer in a higher-level optimizer: one of ZeRO optimization, automatic mixed precision (AMP), FP16, or BFLOAT16 (see the sketch below).
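A small sketch of step 1: the client optimizer can be handed over either as a ready-made instance or as a callable that receives the model parameters (the optimizer choice and learning rate below are arbitrary):

import torch
import deepspeed

model = torch.nn.Linear(64, 64)  # hypothetical model

# (a) A ready-made torch.optim optimizer: it passes the isinstance check against
#     self._supported_optims() and is used directly as basic_optimizer.
opt_instance = torch.optim.AdamW(model.parameters(), lr=1e-4)

# (b) A DeepSpeedOptimizerCallable: not an Optimizer instance, so DeepSpeed invokes it
#     as client_optimizer(model_parameters) to build basic_optimizer.
opt_callable = lambda params: torch.optim.AdamW(params, lr=1e-4)

# Either one can be passed as deepspeed.initialize(..., optimizer=opt_instance) or
# deepspeed.initialize(..., optimizer=opt_callable); the wrapper (ZeRO / AMP / FP16 /
# BF16) is then chosen by _do_optimizer_sanity_check based on the config.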

2.1.2. Creating the ZeRO Optimizer

The core implementation lives in _configure_zero_optimizer. This method takes a basic optimizer and selects the corresponding ZeRO optimizer according to the stage:

  • stage 1: the legacy implementation is deprecated.

  • stage 2: implemented by DeepSpeedZeroOptimizer.

  • stage 3: implemented by DeepSpeedZeroOptimizer_Stage3, which ultimately relies on DeepSpeedZeRoOffload; see Section 4 for a walkthrough of the stage-3 optimizer.

def _configure_zero_optimizer(self, optimizer):
    # ZeRO stage
    zero_stage = self.zero_optimization_stage()

    mics_shard_size = self.mics_shard_size()
    # model_dtype: data type of the model parameters
    # gradient_accumulation_dtype: data type used for gradient accumulation, one of ["fp32", "fp16", "bf16"]
    model_dtype, gradient_accumulation_dtype = self.get_data_types()

    timers = self.timers if self.wall_clock_breakdown() else NoopTimer()
    # If no basic optimizer was provided, create a dummy one
    if optimizer is None:
        optimizer = DummyOptim(list(self.module.parameters()))

    if self.zero_legacy_stage1():  # the legacy stage-1 implementation is deprecated; start from stage 2
        raise Exception(
            "The deprecated version of ZeRO Stage 1 is not supported in deepspeed >= 0.5.9. Please downgrade to a version less than 0.5.9 if you need to use this deprecated version of ZeRO."
        )

    if zero_stage <= ZeroStageEnum.gradients:  # stage <= 2
        # Config option, default False:
        # attempt to overlap gradient reduction with the backward computation.
        overlap_comm = self.zero_overlap_comm()
        # Config option, default True:
        # copy gradients into a contiguous buffer as they are produced,
        # avoiding memory fragmentation during the backward pass.
        contiguous_gradients = self.zero_contiguous_gradients()
        # Config option, default False:
        # a stage 1/2 optimization for CPU offloading that parallelizes the copy of
        # gradients to CPU memory via fine-grained gradient partitioning. The benefit
        # grows with the number of gradient accumulation steps (more copies between
        # optimizer steps) or the number of GPUs (more parallelism).
        round_robin_gradients = self.zero_round_robin_gradients()
        assert not isinstance(optimizer, DummyOptim), "zero stage {} requires an optimizer".format(zero_stage)

        log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0])
        # Overlap and contiguous grads are meaningless in stage 1 and are ignored
        if zero_stage == ZeroStageEnum.optimizer_states:  # stage==1
            # Stage 1 does not support these features, so turn them off
            overlap_comm = False
            round_robin_gradients = False
            # Non-MoE requires contiguous grads to be disabled w. stage 1
            if not self.has_moe_layers:
                contiguous_gradients = False

        if isinstance(self.module, PipelineModule):
            if overlap_comm:
                logger.warning("Pipeline parallelism does not support overlapped communication, will be disabled.")
                overlap_comm = False
        optimizer = DeepSpeedZeroOptimizer(
            optimizer,
            self.param_names,
            timers=timers,
            static_loss_scale=self.loss_scale(),
            dynamic_loss_scale=self.dynamic_loss_scale(),
            dynamic_loss_args=self.dynamic_loss_scale_args(),
            clip_grad=self.gradient_clipping(),
            contiguous_gradients=contiguous_gradients,
            reduce_bucket_size=self.zero_reduce_bucket_size(),
            allgather_bucket_size=self.zero_allgather_bucket_size(),
            dp_process_group=self.data_parallel_group,
            expert_parallel_group=self.expert_parallel_group if self.has_moe_layers else None,
            expert_data_parallel_group=self.expert_data_parallel_group if self.has_moe_layers else None,
            reduce_scatter=self.zero_reduce_scatter(),
            overlap_comm=overlap_comm,
            offload_optimizer_config=self.zero_offload_optimizer(),
            mpu=self.mpu,
            postscale_gradients=self.postscale_gradients(),
            gradient_predivide_factor=self.gradient_predivide_factor(),
            gradient_accumulation_steps=self.gradient_accumulation_steps(),
            ignore_unused_parameters=self.zero_ignore_unused_parameters(),
            partition_grads=zero_stage == ZeroStageEnum.gradients,
            round_robin_gradients=round_robin_gradients,
            has_moe_layers=self.has_moe_layers,
            fp16_master_weights_and_gradients=self.fp16_master_weights_and_gradients(),
            gradient_accumulation_dtype=gradient_accumulation_dtype,
            communication_data_type=self.communication_data_type,
            elastic_checkpoint=self.zero_elastic_checkpoint())

    elif zero_stage == ZeroStageEnum.weights:  # stage ==3
        assert not self.has_moe_layers, "MoE not supported with Stage 3"
        if isinstance(optimizer, DummyOptim):
            log_dist("Creating ZeRO Offload", ranks=[0])
            zero_param_parallel_group = groups._get_zero_param_intra_parallel_group()
            if self.zero_hpz_partition_size() > 1 and zero_param_parallel_group is None:
                self._set_zero_group_parallelism()
                zero_param_parallel_group = groups._get_zero_param_intra_parallel_group()
            optimizer = DeepSpeedZeRoOffload(
                self.module,
                timers=timers,
                ds_config=self.config,
                overlap_comm=self.zero_overlap_comm(),
                prefetch_bucket_size=self.zero_prefetch_bucket_size(),
                max_reuse_distance=self.zero_max_reuse_distance(),
                max_live_parameters=self.zero_max_live_parameters(),
                param_persistence_threshold=self.zero_param_persistence_threshold(),
                model_persistence_threshold=self.zero_model_persistence_threshold(),
                offload_param_config=self.zero_offload_param(),
                mpu=self.mpu,
                zero_param_parallel_group=zero_param_parallel_group,
                zero_quantized_weights=self.zero_quantized_weights(),
                zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),
            )
        else:
            log_dist(
                f'Creating fp16 ZeRO stage {zero_stage} optimizer,'
                f' MiCS is enabled {mics_shard_size > 0},'
                f' Hierarchical params gather {self._config.mics_hierarchial_params_gather}',
                ranks=[0])
            if mics_shard_size > 0:
                return self._return_mics_optimizer(optimizer, timers)
            # The stage-3 optimizer
            log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0])
            from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
            optimizer = DeepSpeedZeroOptimizer_Stage3(
                self.module,
                optimizer,
                timers=timers,
                ds_config=self.config,
                static_loss_scale=self.loss_scale(),
                dynamic_loss_scale=self.dynamic_loss_scale(),
                dynamic_loss_args=self.dynamic_loss_scale_args(),
                clip_grad=self.gradient_clipping(),
                contiguous_gradients=self.zero_contiguous_gradients(),
                reduce_bucket_size=self.zero_reduce_bucket_size(),
                prefetch_bucket_size=self.zero_prefetch_bucket_size(),
                max_reuse_distance=self.zero_max_reuse_distance(),
                max_live_parameters=self.zero_max_live_parameters(),
                param_persistence_threshold=self.zero_param_persistence_threshold(),
                model_persistence_threshold=self.zero_model_persistence_threshold(),
                dp_process_group=self.data_parallel_group,
                all2all_process_group=self.local_all_to_all_group,
                reduce_scatter=self.zero_reduce_scatter(),
                overlap_comm=self.zero_overlap_comm(),
                offload_optimizer_config=self.zero_offload_optimizer(),
                offload_param_config=self.zero_offload_param(),
                sub_group_size=self.zero_sub_group_size(),
                mpu=self.mpu,
                postscale_gradients=self.postscale_gradients(),
                gradient_predivide_factor=self.gradient_predivide_factor(),
                gradient_accumulation_steps=self.gradient_accumulation_steps(),
                aio_config=self.aio_config(),
                gradient_accumulation_dtype=gradient_accumulation_dtype,
                communication_data_type=self.communication_data_type,
                zero_hpz_partition_size=self.zero_hpz_partition_size(),
                zero_quantized_weights=self.zero_quantized_weights(),
                zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),
            )

    else:
        raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage))

    return optimizer
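Most of the self.zero_*() accessors used above simply read fields from the zero_optimization section of the DeepSpeed config. A hedged example of a stage-3 configuration touching the knobs referenced in this method (values are illustrative, not recommendations):

ds_config = {
    "train_batch_size": 16,
    "zero_optimization": {
        "stage": 3,                                 # self.zero_optimization_stage()
        "overlap_comm": True,                       # self.zero_overlap_comm()
        "contiguous_gradients": True,               # self.zero_contiguous_gradients()
        "reduce_bucket_size": 5e8,                  # self.zero_reduce_bucket_size()
        "stage3_prefetch_bucket_size": 5e8,         # self.zero_prefetch_bucket_size()
        "stage3_max_live_parameters": 1e9,          # self.zero_max_live_parameters()
        "stage3_max_reuse_distance": 1e9,           # self.zero_max_reuse_distance()
        "stage3_param_persistence_threshold": 1e5,  # self.zero_param_persistence_threshold()
        "offload_optimizer": {"device": "cpu"},     # self.zero_offload_optimizer()
        "offload_param": {"device": "cpu"},         # self.zero_offload_param()
    },
}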

2.1.3. Creating the FP16 Optimizer

2.1.4. Creating the BF16 Optimizer