4. Stage 3 - Parameter Partitioning

Inside DeepSpeedEngine, an optimizer instance matching the configured ZeRO stage is created. For Stage 3, the corresponding optimizer is the class DeepSpeedZeroOptimizer_Stage3, so we start drilling down from this class.
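
To make the setup concrete, here is a minimal sketch (not taken from the article's code base; the toy model and all config values are illustrative only) of how a ZeRO Stage-3 engine is typically created:

import torch
import deepspeed

model = torch.nn.Linear(1024, 1024)  # any torch.nn.Module stands in for a real model

ds_config = {
    "train_batch_size": 8,
    "fp16": {"enabled": True},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
    "zero_optimization": {"stage": 3},  # stage 3 selects DeepSpeedZeroOptimizer_Stage3 internally
}

# deepspeed.initialize wraps the model in a DeepSpeedEngine; with stage 3 configured,
# the engine builds a DeepSpeedZeroOptimizer_Stage3 around the base optimizer it creates.
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)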

4.1. DeepSpeedZeroOptimizer_Stage3

First, let's see where this class comes from:

from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3

Next, look at this class's __init__ method:

def __init__(
    self,
    module,   # the model to be wrapped
    init_optimizer,  # the base optimizer
    timers,
    ds_config,  # the DeepSpeed configuration
    static_loss_scale=1.0,
    dynamic_loss_scale=False,
    dynamic_loss_args=None,
    verbose=True,
    contiguous_gradients=True,
    reduce_bucket_size=500000000,
    prefetch_bucket_size=50000000,
    max_reuse_distance=1000000000,
    max_live_parameters=1000000000,
    param_persistence_threshold=100000,
    model_persistence_threshold=sys.maxsize,
    dp_process_group=None,
    reduce_scatter=True,
    overlap_comm=False,
    offload_optimizer_config=None,
    offload_param_config=None,
    sub_group_size=1000000000000,
    mpu=None,
    clip_grad=0.0,
    gradient_accumulation_dtype=torch.float32,
    communication_data_type=torch.float16,
    postscale_gradients=True,
    gradient_predivide_factor=1.0,
    gradient_accumulation_steps=1,
    elastic_checkpoint=False,
    aio_config=None,
    all2all_process_group=None,
    zero_hpz_partition_size=1,
    zero_quantized_weights=False,
    zero_quantized_nontrainable_weights=False,
):

 ...
 self.optimizer = init_optimizer  # the base optimizer; may be a plain torch optimizer object
 ...

 # the core parameter handling happens here
 # self.parameter_offload is a DeepSpeedZeRoOffload object
 self.parameter_offload : DeepSpeedZeRoOffload = self.initialize_ds_offload(
        module=module,
        timers=timers,
        ds_config=ds_config,
        overlap_comm=overlap_comm,
        prefetch_bucket_size=prefetch_bucket_size,
        max_reuse_distance=max_reuse_distance,
        max_live_parameters=max_live_parameters,
        param_persistence_threshold=param_persistence_threshold,
        model_persistence_threshold=model_persistence_threshold,
        offload_param_config=offload_param_config,
        mpu=mpu,
        zero_param_parallel_group=zero_param_parallel_group,
        zero_quantized_weights=zero_quantized_weights,
        zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights)

 ...

Following the trail, look at the implementation of the class method initialize_ds_offload; it simply creates a DeepSpeedZeRoOffload object.

def initialize_ds_offload(
    self,
    module,
    timers,
    ds_config,
    overlap_comm,
    prefetch_bucket_size,
    max_reuse_distance,
    max_live_parameters,
    param_persistence_threshold,
    model_persistence_threshold,
    offload_param_config,
    mpu,
    zero_param_parallel_group,
    zero_quantized_weights,
    zero_quantized_nontrainable_weights,
):
    return DeepSpeedZeRoOffload(module=module,
                                timers=timers,
                                ds_config=ds_config,
                                overlap_comm=overlap_comm,
                                prefetch_bucket_size=prefetch_bucket_size,
                                max_reuse_distance=max_reuse_distance,
                                max_live_parameters=max_live_parameters,
                                param_persistence_threshold=param_persistence_threshold,
                                model_persistence_threshold=model_persistence_threshold,
                                offload_param_config=offload_param_config,
                                mpu=mpu,
                                zero_param_parallel_group=zero_param_parallel_group,
                                zero_quantized_weights=zero_quantized_weights,
                                zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights)

4.2. DeepSpeedZeRoOffload

The core parameter-partitioning logic of Stage 3 lives in the class DeepSpeedZeRoOffload, defined in deepspeed.runtime.zero.parameter_offload. As before, let's look at its initialization method.

class DeepSpeedZeRoOffload(object):

    def __init__(
        self,
        module,
        timers,
        ds_config,
        overlap_comm=True,
        prefetch_bucket_size=50000000,
        max_reuse_distance=1000000000,
        max_live_parameters=1000000000,
        param_persistence_threshold=100000,
        model_persistence_threshold=sys.maxsize,
        offload_param_config=None,
        mpu=None,
        zero_param_parallel_group=None,
        zero_quantized_weights=False,
        zero_quantized_nontrainable_weights=False,
    ):
        ...
        # the parameter partitioning logic
        self._convert_to_zero_parameters(ds_config, module, mpu)

        for m in module.modules():
            _init_external_params(m)

        _inject_parameters(module, ZeROOrderedDict)

        ...

The __init__ method calls _convert_to_zero_parameters; the name alone suggests this is what we are looking for.

def _convert_to_zero_parameters(self, ds_config, module, mpu):
    # parameters without a ds_id attribute, i.e. not yet processed by DeepSpeed
    non_zero_params = [p for p in module.parameters() if not is_zero_param(p)]
    if non_zero_params: # there are parameters that have not been processed yet

        zero_params = [p for p in module.parameters() if is_zero_param(p)]
        if zero_params: # some parameters have already been processed
            # this means the model was initialized before and has since gained new, unprocessed parameters
            # to stay consistent, reuse the original conversion procedure
            zero_params[0].convert_to_zero_parameters(param_list=non_zero_params)
        else: # the model's parameters are being initialized for the first time
            group = None
            if mpu:
                group = mpu.get_data_parallel_group()
            # the core partitioning happens here, via deepspeed.runtime.zero.partition_parameters.Init
            Init(module=module,
                 data_parallel_group=group,
                 dtype=self.dtype,
                 config_dict_or_path=ds_config,
                 remote_device=self.offload_device,
                 pin_memory=self.offload_param_pin_memory,
                 mpu=mpu,
                 zero_param_parallel_group=self.zero_param_parallel_group,
                 zero_quantized_weights=self.zero_quantized_weights,
                 zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights)
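
The is_zero_param test used above is essentially an attribute check; a hedged paraphrase (not copied verbatim from DeepSpeed) of what it does:

def is_zero_param(parameter) -> bool:
    # a parameter counts as already converted once Init has stamped a ds_id on it
    return hasattr(parameter, "ds_id")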

This method in turn jumps to the Init class and creates an Init instance. Quite a detour.

4.3. The Init Class

This Init class comes from:

from deepspeed.runtime.zero.partition_parameters import Init

First, the official docstring:

A context to enable massive model construction for training with ZeRO-3.
Models are automatically partitioned (or, sharded) across the system and converted to half precision.
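
Before drilling into the implementation, it helps to see how this context is normally used. A hedged sketch (MyLargeModel and ds_config are placeholders, not from the article):

import deepspeed

with deepspeed.zero.Init(config_dict_or_path=ds_config):
    # parameters are converted to "zero parameters" and partitioned across the
    # data-parallel ranks as each submodule is constructed, so the full model
    # never has to fit on a single device
    model = MyLargeModel()

The code path we are following instead passes an already built module via the module= argument (as in _convert_to_zero_parameters above), which converts the existing parameters in place.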

This Init class is the heart of Stage-3 parameter partitioning. As before, we start from its __init__ method.

# Replaces all parameters in module with Scattered Parameters
class Init(InsertPostInitMethodToModuleSubClasses):
    param_id = 0
    param_persistence_threshold = get_config_default(DeepSpeedZeroConfig, "param_persistence_threshold")
    model_persistence_threshold = get_config_default(DeepSpeedZeroConfig, "model_persistence_threshold")
    num_persisted_parameters = 0
    num_persisted_elements = 0
    apply_param_persistence = False
    override_module_apply = get_config_default(DeepSpeedZeroConfig, "override_module_apply")

    def __init__(
            self,
            module=None,
            data_parallel_group=None,
            mem_efficient_linear=True,
            remote_device=None,
            pin_memory=False,
            config_dict_or_path=None,
            config=None,
            enabled=True,
            dtype=None,
            mpu=None,
            zero_param_parallel_group=None,
            zero_quantized_weights=False,
            zero_quantized_nontrainable_weights=False,
    ):

        ...
        # the device this process should run on
        self.local_device = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"]))
        get_accelerator().set_device(self.local_device)


        if module is not None:
            assert isinstance(module, torch.nn.Module)
            # parameter partitioning happens here
            self._convert_to_zero_parameters(module.parameters(recurse=True))
        ...

4.3.1. _convert_to_deepspeed_param

Only from here do we actually enter the parameter-partitioning process:

def _convert_to_zero_parameters(self, param_list):
    for param in param_list:
        # skip parameters that have already been processed
        if is_zero_param(param):
            continue
        # first move the parameter data to the target device
        param.data = param.data.to(self.local_device)
        # process each parameter one by one
        self._zero_init_param(param)

def _zero_init_param(self, param):
    # first, extend the original torch Parameter object
    # "extend" here means attaching the necessary attributes and methods
    self._convert_to_deepspeed_param(param)

    if dist.get_world_group() == self.get_dp_process_group():
        dist.broadcast(param, 0, self.get_dp_process_group())
    else:
        dist.broadcast(param, dist.get_global_rank(self.get_dp_process_group(), 0), self.get_dp_process_group())
    # self._convert_to_deepspeed_param(param) above attached a partition() method to param
    # here partition() is called to actually split the parameter
    param.partition()

def _convert_to_deepspeed_param(self, param):
    """
    This method extends the original param instance with the necessary attributes and methods.
    Frankly, the code here is hard to put into words...
    """

    # Partitioned, Normal, Remote
    param.ds_param_type = ZeroParamType.PARTITIONED

    # Replicated vs Partitioned vs Inflight
    param.ds_status = ZeroParamStatus.AVAILABLE

    # Stores the shape of the original tensor
    param.ds_shape = param.shape

    # Stores the number of elements in the original parameter without padding
    param.ds_numel = param.numel()

    # Stores the partitioned copy of the tensor
    # holds the partitioned slice of the tensor, i.e. only a fraction of the original parameter
    param.ds_tensor = None

    # Keeps track of how many active sub-modules need this param at any given point in time
    param.ds_active_sub_modules = set()

    # If this flag is true, then the parameters are replicated throughput training
    # And only partitioned before the step
    if Init.apply_param_persistence and param.ds_numel <= Init.param_persistence_threshold and Init.num_persisted_elements + param.ds_numel <= Init.model_persistence_threshold:
        param.ds_persist = True
        Init.num_persisted_parameters += 1
        Init.num_persisted_elements += param.ds_numel
    else:
        param.ds_persist = False

    param.is_external_param = False

    # The group that the parameter is scattered across.
    param.ds_process_group = self.ds_process_group

    # Stores the secondary partitioned copy of the tensor
    param.ds_secondary_tensor = None

    # Process group for secondary partition all (group) gather
    param.ds_zero_param_process_group = self.zero_param_process_group
    param.ds_secondary_tensor_group_size = self.num_ranks_in_param_group
    param.ds_secondary_tensor_num_of_groups = self.num_param_groups

    # This is set to the Async Param swapper if remote device is nvme
    # else this is set to None
    param.nvme_swapper = self.param_swapper

    # DeepSpeed Param ID
    param.ds_id = Init.param_id
    Init.param_id += 1

    def all_gather(param_list=None, async_op=False, hierarchy=0):
        cls = param
        if param_list is None:
            param_list = [cls]
        return self._all_gather(param_list, async_op=async_op, hierarchy=hierarchy)

    @instrument_w_nvtx
    def all_gather_coalesced(params: Iterable[Parameter],
                             forward: bool = True,
                             safe_mode: bool = False,
                             quantize: bool = False) -> AllGatherCoalescedHandle:
        """Restores (re-gathers) this parameter; discussed in detail later."""
        ...
        # the body is long and omitted here; we will revisit it when covering the forward pass.

    def partition(param_list=None, backward=False, hierarchy=0, has_been_updated=False):
        cls = param
        print_rank_0(f"{'--' * hierarchy}----Partitioning param {debug_param2name_id_shape_device(cls)}",
                     force=False)
        if param_list is None:
            param_list = [cls]
        self._partition(param_list, has_been_updated=has_been_updated)

    def reduce_gradients_at_owner(param_list=None, hierarchy=0):
        cls = param
        if param_list is None:
            param_list = [cls]
        print_rank_0(
            f"{'--' * hierarchy}----Reducing Gradients for param with ids {[param.ds_id for param in param_list]} to owner"
        )
        self._reduce_scatter_gradients(param_list)

    def partition_gradients(param_list=None, partition_buffers=None, hierarchy=0, accumulate=False):
        cls = param
        print_rank_0(
            f"{'--' * hierarchy}----Partitioning param gradient with id {debug_param2name_id_shape_device(cls)}")
        if param_list is None:
            param_list = [cls]
            if isinstance(partition_buffers, torch.Tensor):
                partition_buffers = [partition_buffers]

        self._partition_gradients(param_list, partition_buffers=partition_buffers, accumulate=accumulate)

    def aligned_size():
        return self._aligned_size(param)

    def padding_size():
        return self._padding_size(param)

    def partition_numel():
        return self._partition_numel(param)

    def item_override():
        param.all_gather()
        return param._orig_item()

    def ds_summary(slf: torch.Tensor, use_debug_name: bool = False) -> dict:
        return {
            "id": debug_param2name_id(slf) if use_debug_name else slf.ds_id,
            "status": slf.ds_status.name,
            "numel": slf.numel(),
            "ds_numel": slf.ds_numel,
            "shape": tuple(slf.shape),
            "ds_shape": tuple(slf.ds_shape),
            "requires_grad": slf.requires_grad,
            "grad_shape": tuple(slf.grad.shape) if slf.grad is not None else None,
            "persist": slf.ds_persist,
            "active_sub_modules": slf.ds_active_sub_modules,
            "ds_tensor.shape": slf.ds_tensor.shape if slf.ds_tensor is not None else None
        }

    def convert_to_zero_parameters(param_list):
        self._convert_to_zero_parameters(param_list)

    def allgather_before(func: Callable) -> Callable:

        def wrapped(*args, **kwargs):
            param.all_gather()
            return func(*args, **kwargs)

        return wrapped

    # Collectives for gathering and partitioning parameters
    param.all_gather = all_gather
    param.all_gather_coalesced = all_gather_coalesced
    param.partition = partition

    # Collective for averaging gradients
    param.reduce_gradients_at_owner = reduce_gradients_at_owner
    param.partition_gradients = partition_gradients

    # Partitioning size utilities
    param.aligned_size = aligned_size
    param.padding_size = padding_size
    param.partition_numel = partition_numel
    param.ds_summary = types.MethodType(ds_summary, param)

    param.item = allgather_before(param.item)

    param.convert_to_zero_parameters = convert_to_zero_parameters
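
To make the effect of this extension concrete, here is a small hedged sketch (model is assumed to hold parameters already converted by Init) that pokes at the attributes and methods attached above:

for p in model.parameters():
    # every converted parameter now carries DeepSpeed metadata
    print(p.ds_id, p.ds_status, p.ds_numel, p.ds_persist)
    # ds_summary() was bound in _convert_to_deepspeed_param above
    print(p.ds_summary())
    # the collectives are likewise plain attributes on the parameter object
    p.all_gather()   # rebuild the full tensor on this rank
    p.partition()    # drop it again, keeping only the local slice in p.ds_tensor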

In the code above, param is given a partition method, repeated below for reference. It accepts an optional parameter list and forwards to Init::_partition, which performs the actual partitioning.

def partition(param_list=None, backward=False, hierarchy=0, has_been_updated=False):
    cls = param
    print_rank_0(f"{'--' * hierarchy}----Partitioning param {debug_param2name_id_shape_device(cls)}",
                 force=False)
    if param_list is None:
        param_list = [cls]
    self._partition(param_list, has_been_updated=has_been_updated)

4.3.2. partition

Another gripe: this code reads as if it were written to keep anyone from understanding it.

self.remote_device = self.local_device if remote_device in [None, OffloadDeviceEnum.none] else remote_device
# Enable fp16 param swapping to NVMe
#  self.dtype may be half precision, depending on the config file
if self.remote_device == OffloadDeviceEnum.nvme:
    self.param_swapper = AsyncPartitionedParameterSwapper(_ds_config, self.dtype)
else:
    self.param_swapper = None
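
self.remote_device ultimately comes from the offload_param section of the ZeRO config. A hedged configuration sketch (the NVMe path is a placeholder):

ds_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": "nvme",            # or "cpu"; maps to OffloadDeviceEnum above
            "nvme_path": "/local_nvme",  # placeholder mount point for the NVMe drive
            "pin_memory": True,
        },
    },
}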

With the relevant variables laid out, we finally reach the partitioning logic itself. It took a while...

def _partition_param(self, param, buffer=None, has_been_updated=False):

    assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight"
    global reuse_buffers
    print_rank_0(f"Param id {param.ds_id} status is {param.ds_status}", force=False)
    if param.ds_status is ZeroParamStatus.AVAILABLE:
        print_rank_0(f"Partitioning param id {param.ds_id} reuse buffers {reuse_buffers}", force=False)
        # this parameter has already been partitioned
        if param.ds_tensor is not None and not has_been_updated:  ##param already partitioned
            see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}', force=False)
            # param.data does not store anything meaningful in partitioned state
            # free the original parameter tensor
            free_param(param)
            see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False)

            if param.ds_tensor.final_location == OffloadDeviceEnum.nvme:
                print_rank_0(f"Param {param.ds_id} partition released since it exists in nvme", force=False)
                param.nvme_swapper.remove_partition_and_release_buffers([param])
                print_rank_0(
                    f"after swap Param {param.ds_id} {param.ds_tensor.shape} partition released since it exists in nvme",
                    force=False)

            return
        # the parameter tensor needs to be split into self.num_partitions slices
        # if its element count is not a multiple of self.num_partitions, it is padded for alignment
        tensor_size = self._aligned_size(param)
        # self.num_partitions == self.dp_world_size == dist.get_world_size(group=self.ds_process_group)
        # self.ds_process_group defaults to the global process group, but can be overridden via the data_parallel_group argument
        partition_size = tensor_size // self.num_partitions  # size of each slice
        if param.ds_tensor is None:
            # records which device this slice is ultimately offloaded to
            final_location = None
            # self.remote_device is the final offload target for parameters
            # self.param_swapper is the helper that swaps tensors with the NVMe device; it is already set up for self.dtype
            if self.remote_device == OffloadDeviceEnum.nvme and self.param_swapper.swappable_tensor(
                    numel=partition_size):
                # if offloading to NVMe and the NVMe buffer still has room, partitioned_tensor can be created directly in the NVMe buffer
                final_location = OffloadDeviceEnum.nvme
                buffer = self.param_swapper.get_buffer(param, partition_size)
                # note: the new tensor keeps the same dtype as the original parameter
                partitioned_tensor = torch.empty(0, dtype=param.dtype, device=buffer.device)
                partitioned_tensor.data = buffer.data
                print_rank_0(f"ID {param.ds_id} Initializing partition for the first time for nvme offload.")

            else: # not offloading to NVMe, or NVMe has run out of buffer space
                if param.ds_persist:  # ds_persist means this parameter is never offloaded
                    device = self.local_device  # keep it on the local device
                elif self.remote_device == OffloadDeviceEnum.nvme:
                    # NVMe has run out of buffer space; even if the final target is NVMe, stage the slice in CPU memory for now
                    device = OffloadDeviceEnum.cpu
                else:
                    device = self.remote_device
                # allocate an empty tensor on device = gpu_x or cpu
                partitioned_tensor = torch.empty(partition_size, dtype=param.dtype, device=device)
                # quantize the tensor if it's not trainable
                # if the parameter does not require gradients and quantization is enabled, quantize it
                if not param.requires_grad and self.quantized_nontrainable_weights:
                    partitioned_tensor, partitioned_tensor.ds_quant_scale = self.quantizer_module.quantize(
                        partitioned_tensor)

                if device == OffloadDeviceEnum.cpu and self.pin_memory:
                    # use pinned memory
                    partitioned_tensor = get_accelerator().pin_memory(partitioned_tensor)
            #
            partitioned_tensor.requires_grad = False  # the slice is plain storage and never needs gradients
            param.ds_tensor = partitioned_tensor  # this attribute stores the partitioned slice
            param.ds_tensor.ds_numel = partition_size
            param.ds_tensor.status = PartitionedParamStatus.AVAILABLE  # mark the slice as available
            param.ds_tensor.final_location = final_location  # record where this slice finally lives
        # use the rank to compute the start offset of the slice owned by this process
        start = partition_size * self.get_partition_rank()
        end = start + partition_size
        # flatten the tensor into a contiguous 1-D view
        one_dim_param = param.contiguous().view(-1)

        if start < param.ds_numel and end <= param.ds_numel:
            # take this rank's slice and copy it into param.ds_tensor
            src_tensor = one_dim_param.narrow(0, start, partition_size)
            param.ds_tensor.copy_(src_tensor)
        else:
            if start < param.ds_numel:
                elements_to_copy = param.ds_numel - start
                param.ds_tensor.narrow(0, 0,
                                       elements_to_copy).copy_(one_dim_param.narrow(0, start, elements_to_copy))

        see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}', force=False)
        # free the original parameter tensor
        free_param(param)
        see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False)
        # if the final target is NVMe, swap the slice out via self.param_swapper
        if param.ds_tensor.final_location == OffloadDeviceEnum.nvme:
            self.param_swapper.swap_out_and_release([param])
            print_rank_0(f"ID {param.ds_id} Offloaded to nvme offload and buffers released.")
            see_memory_usage(f"ID {param.ds_id} Offloaded to nvme offload and buffers released.", force=False)

        print_rank_0(f"ID {param.ds_id} partitioned type {param.dtype} dev {param.device} shape {param.shape}")

The partitioning logic itself is not complicated, just full of details: each rank keeps only one slice of every parameter, and that slice can additionally live in CPU memory or on a fast NVMe SSD, which minimizes the demand on GPU memory.
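
The core arithmetic can be reproduced in a few lines of plain PyTorch. This is an illustration of the logic above, not DeepSpeed code; zero padding here stands in for DeepSpeed's alignment handling:

import torch

def partition_locally(param: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Shows what _partition_param keeps on a single rank."""
    numel = param.numel()
    # pad the element count up to a multiple of world_size (the "aligned size")
    aligned = ((numel + world_size - 1) // world_size) * world_size
    partition_size = aligned // world_size
    flat = torch.nn.functional.pad(param.contiguous().view(-1), (0, aligned - numel))
    start = rank * partition_size
    # this slice is all the rank keeps; the full tensor can then be freed
    return flat.narrow(0, start, partition_size).clone()

p = torch.randn(10, 7)                              # 70 elements
shard = partition_locally(p, rank=1, world_size=4)
print(shard.numel())                                # 18 elements per rank (70 padded to 72)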

Now that the partitioning process is clear, the next step is to see how parameters are restored during the forward and backward passes; jump to Section 5 to follow that.
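
As a preview of that restoration step, DeepSpeed also exposes the re-gather as a public context manager. A hedged sketch (model is assumed to be a module converted under ZeRO Stage 3):

import deepspeed

with deepspeed.zero.GatheredParameters(list(model.parameters())):
    # inside the context every parameter is temporarily whole again
    for p in model.parameters():
        print(tuple(p.shape))
# on exit the parameters are re-partitioned automatically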

4.3.3. partition_param_sec

To be added...