5. Stage3 - hook registration

After the parameters have been partitioned, they must be restored before the forward and backward passes can run. Likewise, once the forward and backward passes have finished, the parameters that each submodule no longer needs have to be released again. DeepSpeed uses PyTorch's hook mechanism to insert the corresponding actions at these four key points. PyTorch's Module class provides a family of register_xxx_hook methods for exactly this purpose.
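
As a quick reminder of that API, here is a minimal, self-contained PyTorch example (a toy nn.Linear, not DeepSpeed code): a pre-hook fires just before Module.forward and a post-hook just after, which is exactly where ZeRO-3 hangs its gather/release logic.

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)

def pre_forward(module, args):
    # Fires right before module.forward - the point where ZeRO-3 gathers the partitioned params.
    print(f"pre-forward on {module.__class__.__name__}")

def post_forward(module, args, output):
    # Fires right after module.forward - the point where ZeRO-3 releases them again.
    print(f"post-forward on {module.__class__.__name__}")

layer.register_forward_pre_hook(pre_forward)
layer.register_forward_hook(post_forward)

layer(torch.randn(2, 4))   # prints pre-forward, then post-forward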

DeepSpeed's hook logic is implemented in the DeepSpeedZeRoOffload class, specifically in the method DeepSpeedZeRoOffload::setup_zero_stage3_hooks.

def setup_zero_stage3_hooks(self):
    """
    Register the stage3-related hook functions.
    Returns:

    """
    self.hierarchy = 0

    #reset step if in inference mode
    @instrument_w_nvtx
    def _end_of_forward_hook(module, *args):

        if not torch._C.is_grad_enabled():
            self.get_param_coordinator(training=False).reset_step()

    #likely one of them should be enough but just to be safe
    # Register the various hooks:
    # pre_forward, pre_backward, post_forward and post_backward
    self._register_hooks_recursively(self.module)
    self.module.register_forward_hook(_end_of_forward_hook)

    # Add top module to stack trace
    global FWD_MODULE_STACK
    FWD_MODULE_STACK.append(self.module)

def _register_hooks_recursively(self, module, count=[0]):
    """真正执行hook操作"""
    my_count = count[0]
    module.id = my_count

    #print(f"{module.__class__} : {module.id}")

    for child in module.children():
        count[0] = count[0] + 1
        self._register_hooks_recursively(child, count=count)

    @instrument_w_nvtx
    def _pre_forward_module_hook(module, *args):
        self.pre_sub_module_forward_function(module)

    @instrument_w_nvtx
    def _post_forward_module_hook(module, input, output):
        global FWD_MODULE_STACK
        FWD_MODULE_STACK.pop()
        if output is None:
            output = []
        elif not isinstance(output, (list, tuple)):
            if torch.is_tensor(output):
                output = [output]
            else:
                #print(f'got UNKNOWN type {type(output)}')
                outputs = []
                output = output if isinstance(output, dict) else vars(output)
                for name, val in output.items():
                    if not name.startswith('__') and torch.is_tensor(val):
                        outputs.append(val)
                output = outputs

        for item in filter(lambda item: is_zero_param(item) or hasattr(item, 'ds_param_alias'), output):
            key = id(item) if hasattr(item, 'ds_id') else id(item.ds_param_alias)
            actual_external_param = item if hasattr(item, 'ds_id') else item.ds_param_alias

            if not any(key in m._external_params for m in FWD_MODULE_STACK):
                actual_external_param.is_external_param = True
                module_to_register = FWD_MODULE_STACK[-1]
                register_external_parameter(module_to_register, actual_external_param)
                print_rank_0(
                    f'Registering dangling parameter for module {module_to_register.__class__.__name__}, ds_id = {actual_external_param.ds_id}.',
                    force=False)

                # It's possible that the parameter was already external to the completed module. If so, remove the
                # registration as it will be covered by the outer module instead.
                if key in module._external_params:
                    print_rank_0(
                        f'  Unregistering nested dangling parameter from module {module.__class__.__name__}, ds_id = {actual_external_param.ds_id}',
                        force=False)
                    unregister_external_parameter(module, actual_external_param)

                actual_external_param.all_gather()

        self.post_sub_module_forward_function(module)

    def _pre_backward_module_hook(module, inputs, output):

        @instrument_w_nvtx
        def _run_before_backward_function(sub_module):
            # some models (e.g. Albert) may run multiple forwards on the same layer in a loop
            # before doing backwards, so each backward will need a pre-fetch - using reference
            # counting to support this scenario
            #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}")
            if sub_module.applied_pre_backward_ref_cnt > 0:
                self.pre_sub_module_backward_function(sub_module)
                sub_module.applied_pre_backward_ref_cnt -= 1
            #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}")

        return _apply_to_tensors_only(module, PreBackwardFunction, _run_before_backward_function, output)

    #This is an alternate to doing _post_backward_module_hook
    #it uses tensor.register_hook instead of using torch.autograd.Function
    def _alternate_post_backward_module_hook(module, inputs):
        module.ds_grads_remaining = 0

        #print(f"Before Forward {module.__class__.__name__}")

        def _run_after_backward_hook(*unused):
            module.ds_grads_remaining = module.ds_grads_remaining - 1
            if module.ds_grads_remaining == 0:
                #print(f"After backward {module.__class__.__name__}")
                self.post_sub_module_backward_function(module)

        def _run_before_forward_function(input):
            if input.requires_grad:
                module.ds_grads_remaining += 1

        return _apply_forward_and_backward_to_tensors_only(module, _run_before_forward_function,
                                                           _run_after_backward_hook, inputs)

    def _post_backward_module_hook(module, inputs):
        module.ds_grads_remaining = 0

        @instrument_w_nvtx
        def _run_after_backward_function(sub_module):
            if sub_module.ds_grads_remaining == 0:
                self.post_sub_module_backward_function(sub_module)

        return _apply_to_tensors_only(module, PostBackwardFunction, _run_after_backward_function, inputs)

    # Pre forward hook
    self.forward_hooks.append(module.register_forward_pre_hook(_pre_forward_module_hook))

    # Post forward hook
    self.forward_hooks.append(module.register_forward_hook(_post_forward_module_hook))

    # Pre backward hook
    self.backward_hooks.append(module.register_forward_hook(_pre_backward_module_hook))

    # Post backward hook
    self.backward_hooks.append(module.register_forward_pre_hook(_post_backward_module_hook))
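
One thing worth spelling out: the "pre backward" and "post backward" hooks above are not native PyTorch backward hooks. They are attached as forward (pre-)hooks whose only job is to wrap the module's outputs and inputs in the custom autograd functions PreBackwardFunction and PostBackwardFunction; it is the backward() of those functions that fires the callback at the right point of the autograd graph. A simplified, self-contained sketch of that trick (the idea only, not DeepSpeed's actual classes):

import torch
import torch.nn as nn

# An identity autograd.Function lets us run a callback when autograd reaches
# this point of the graph during the backward pass.
class _RunOnBackward(torch.autograd.Function):

    @staticmethod
    def forward(ctx, tensor, callback):
        ctx.callback = callback
        return tensor.clone()          # identity in the forward direction

    @staticmethod
    def backward(ctx, grad_output):
        ctx.callback()                 # e.g. fetch the submodule's params for ZeRO-3
        return grad_output, None       # no gradient for the callback argument


lin = nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)

out = lin(x)
# Wrapping the module *output* means the callback fires right before gradients
# flow back into `lin` - a "pre-backward" point for that submodule. Wrapping the
# *inputs* instead yields a "post-backward" point, which is what
# _post_backward_module_hook does.
out = _RunOnBackward.apply(out, lambda: print("pre-backward for lin"))
out.sum().backward()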

Before the forward pass: pre_forward

Obviously, before a submodule's forward pass runs, the partitioned parameters have to be restored, which is done by gathering this layer's parameters with an AllGather. The hook is registered via module.register_forward_pre_hook(_pre_forward_module_hook); following _pre_forward_module_hook leads to:

@torch.no_grad()
def pre_sub_module_forward_function(self, sub_module):
    see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", force=False)

    global FWD_MODULE_STACK
    FWD_MODULE_STACK.append(sub_module)

    param_coordinator = self.get_param_coordinator(training=sub_module.training)
    param_coordinator.trace_prologue(sub_module)
    if param_coordinator.is_record_trace():
        param_coordinator.record_module(sub_module)
    # The actual parameter gathering (AllGather) happens here.
    # param_coordinator is a deepspeed.runtime.zero.PartitionedParameterCoordinator
    param_coordinator.fetch_sub_module(sub_module, forward=True)

    see_memory_usage(f"Before sub module function {sub_module.__class__.__name__} after fetch", force=False)

After the forward pass: post_forward

@torch.no_grad()
def post_sub_module_forward_function(self, sub_module):
    see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release",
                     force=False)
    # Release the parameters again
    param_coordinator = self.get_param_coordinator(training=sub_module.training)
    # The actual work happens in deepspeed.runtime.zero.PartitionedParameterCoordinator::release_sub_module
    param_coordinator.release_sub_module(sub_module, backward=False)

    see_memory_usage(f"After sub module function {sub_module.__class__.__name__}  {sub_module.id} after release",
                     force=False)
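
release_sub_module is the inverse operation: once the submodule has run, the gathered copy is dropped and each rank keeps only its own shard. A simplified sketch under the same assumptions as above:

import torch

def release_sub_module_sketch(sub_module):
    """Conceptual sketch only: free the gathered full parameters again."""
    for param in sub_module.parameters(recurse=False):
        # `param.ds_tensor` is assumed to still hold the local shard, so the full
        # tensor can be replaced by an empty placeholder and garbage-collected.
        param.data = torch.empty(0, dtype=param.dtype, device=param.device)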

Before the backward pass: pre_backward

@torch.no_grad()
def pre_sub_module_backward_function(self, sub_module):
    assert sub_module.training, "backward pass is invalid for module in evaluation mode"
    param_coordinator = self.get_param_coordinator(training=True)
    param_coordinator.trace_prologue(sub_module)
    if param_coordinator.is_record_trace():
        param_coordinator.record_module(sub_module)
    param_coordinator.fetch_sub_module(sub_module, forward=False)

After the backward pass: post_backward

@torch.no_grad()
def post_sub_module_backward_function(self, sub_module):
    assert sub_module.training, "backward pass is invalid for module in evaluation mode"
    see_memory_usage(
        f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
        force=False)

    self.get_param_coordinator(training=True).release_sub_module(sub_module, backward=True)

    see_memory_usage(
        f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
        force=False)

As you can see, the actual implementation behind each of these four points lives in deepspeed.runtime.zero.PartitionedParameterCoordinator; execution ultimately jumps into that class.
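
To see all four points fire in order end to end, here is a small self-contained toy in plain PyTorch. It uses tensor-level register_hook calls (the same idea as _alternate_post_backward_module_hook) instead of DeepSpeed's autograd functions, purely to illustrate the ordering:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))

def pre_fwd(module, args):
    print(f"pre_forward   {module.__class__.__name__}")   # ZeRO-3: fetch_sub_module(forward=True)
    # A grad hook on an *input* fires after this module's backward has produced
    # d(loss)/d(input), i.e. at the post_backward point.
    for t in args:
        if torch.is_tensor(t) and t.requires_grad:
            t.register_hook(lambda g, m=module: print(f"post_backward {m.__class__.__name__}"))

def post_fwd(module, args, output):
    print(f"post_forward  {module.__class__.__name__}")   # ZeRO-3: release_sub_module(backward=False)
    # A grad hook on the *output* fires before this module's backward runs,
    # i.e. at the pre_backward point.
    if torch.is_tensor(output) and output.requires_grad:
        output.register_hook(lambda g, m=module: print(f"pre_backward  {m.__class__.__name__}"))

for sub in model:
    sub.register_forward_pre_hook(pre_fwd)
    sub.register_forward_hook(post_fwd)

model(torch.randn(3, 4, requires_grad=True)).sum().backward()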