5. Stage3 - Hook Registration
After the parameters have been partitioned, they must be restored before the forward and backward passes run; symmetrically, once the forward and backward passes have finished, the parameters that are no longer needed must be released again. DeepSpeed uses PyTorch's hook mechanism to insert the corresponding actions at these four key points. PyTorch's Module class provides a family of register_xxx_hook methods for exactly this purpose.
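As a quick warm-up, here is a minimal, self-contained sketch (not DeepSpeed code) of how register_forward_pre_hook and register_forward_hook wrap a layer's forward pass; the print calls merely stand in for the gather/release work that ZeRO-3 actually performs at these points.

import torch
import torch.nn as nn

# Minimal sketch, not DeepSpeed code: module hooks fire around forward().
layer = nn.Linear(4, 4)

def pre_forward(module, inputs):
    # ZeRO-3 would all-gather this module's partitioned parameters here.
    print(f"pre-forward on {module.__class__.__name__}")

def post_forward(module, inputs, output):
    # ZeRO-3 would release the full parameters again here.
    print(f"post-forward on {module.__class__.__name__}")

pre_handle = layer.register_forward_pre_hook(pre_forward)
post_handle = layer.register_forward_hook(post_forward)

layer(torch.randn(2, 4))                   # triggers pre_forward, forward, post_forward
pre_handle.remove(); post_handle.remove()  # registration returns removable handles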
All of DeepSpeed's hook registration lives in the DeepSpeedZeRoOffload class, specifically in the method DeepSpeedZeRoOffload::setup_zero_stage3_hooks:
def setup_zero_stage3_hooks(self):
    """Register the stage3-related hook functions."""
    self.hierarchy = 0

    #reset step if in inference mode
    @instrument_w_nvtx
    def _end_of_forward_hook(module, *args):
        if not torch._C.is_grad_enabled():
            self.get_param_coordinator(training=False).reset_step()

    #likely one of them should be enough but just to be safe
    # Register the hooks on every submodule:
    # pre_forward, post_forward, pre_backward and post_backward
    self._register_hooks_recursively(self.module)
    self.module.register_forward_hook(_end_of_forward_hook)

    # Add top module to stack trace
    global FWD_MODULE_STACK
    FWD_MODULE_STACK.append(self.module)

def _register_hooks_recursively(self, module, count=[0]):
    """Recursively register the hooks on this module and all of its submodules."""
    my_count = count[0]
    module.id = my_count

    #print(f"{module.__class__} : {module.id}")

    for child in module.children():
        count[0] = count[0] + 1
        self._register_hooks_recursively(child, count=count)

    @instrument_w_nvtx
    def _pre_forward_module_hook(module, *args):
        self.pre_sub_module_forward_function(module)

    @instrument_w_nvtx
    def _post_forward_module_hook(module, input, output):
        global FWD_MODULE_STACK
        FWD_MODULE_STACK.pop()
        if output is None:
            output = []
        elif not isinstance(output, (list, tuple)):
            if torch.is_tensor(output):
                output = [output]
            else:
                #print(f'got UNKNOWN type {type(output)}')
                outputs = []
                output = output if isinstance(output, dict) else vars(output)
                for name, val in output.items():
                    if not name.startswith('__') and torch.is_tensor(val):
                        outputs.append(val)
                output = outputs

        for item in filter(lambda item: is_zero_param(item) or hasattr(item, 'ds_param_alias'), output):
            key = id(item) if hasattr(item, 'ds_id') else id(item.ds_param_alias)
            actual_external_param = item if hasattr(item, 'ds_id') else item.ds_param_alias

            if not any(key in m._external_params for m in FWD_MODULE_STACK):
                actual_external_param.is_external_param = True
                module_to_register = FWD_MODULE_STACK[-1]
                register_external_parameter(module_to_register, actual_external_param)
                print_rank_0(
                    f'Registering dangling parameter for module {module_to_register.__class__.__name__}, ds_id = {actual_external_param.ds_id}.',
                    force=False)

                # It's possible that the parameter was already external to the completed module. If so,
                # remove the registration as it will be covered by the outer module instead.
                if key in module._external_params:
                    print_rank_0(
                        f'  Unregistering nested dangling parameter from module {module.__class__.__name__}, ds_id = {actual_external_param.ds_id}',
                        force=False)
                    unregister_external_parameter(module, actual_external_param)

                actual_external_param.all_gather()

        self.post_sub_module_forward_function(module)

    def _pre_backward_module_hook(module, inputs, output):

        @instrument_w_nvtx
        def _run_before_backward_function(sub_module):
            # some models (e.g. Albert) may run multiple forwards on the same layer in a loop
            # before doing backwards, so each backward will need a pre-fetch - using reference
            # counting to support this scenario
            #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}")
            if sub_module.applied_pre_backward_ref_cnt > 0:
                self.pre_sub_module_backward_function(sub_module)
                sub_module.applied_pre_backward_ref_cnt -= 1
                #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}")

        return _apply_to_tensors_only(module, PreBackwardFunction, _run_before_backward_function, output)

    #This is an alternate to doing _post_backward_module_hook
    #it uses tensor.register_hook instead of using torch.autograd.Function
    def _alternate_post_backward_module_hook(module, inputs):
        module.ds_grads_remaining = 0

        #print(f"Before Forward {module.__class__.__name__}")

        def _run_after_backward_hook(*unused):
            module.ds_grads_remaining = module.ds_grads_remaining - 1
            if module.ds_grads_remaining == 0:
                #print(f"After backward {module.__class__.__name__}")
                self.post_sub_module_backward_function(module)

        def _run_before_forward_function(input):
            if input.requires_grad:
                module.ds_grads_remaining += 1

        return _apply_forward_and_backward_to_tensors_only(module, _run_before_forward_function,
                                                           _run_after_backward_hook, inputs)

    def _post_backward_module_hook(module, inputs):
        module.ds_grads_remaining = 0

        @instrument_w_nvtx
        def _run_after_backward_function(sub_module):
            if sub_module.ds_grads_remaining == 0:
                self.post_sub_module_backward_function(sub_module)

        return _apply_to_tensors_only(module, PostBackwardFunction, _run_after_backward_function, inputs)

    # Pre forward hook
    self.forward_hooks.append(module.register_forward_pre_hook(_pre_forward_module_hook))

    # Post forward hook
    self.forward_hooks.append(module.register_forward_hook(_post_forward_module_hook))

    # Pre backward hook
    self.backward_hooks.append(module.register_forward_hook(_pre_backward_module_hook))

    # post backward hook
    self.backward_hooks.append(module.register_forward_pre_hook(_post_backward_module_hook))
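Note that the backward-side hooks are registered through forward hooks as well: _pre_backward_module_hook wraps the module's forward outputs in a torch.autograd.Function (PreBackwardFunction) whose backward() fires right before gradients flow back into the module, and _post_backward_module_hook does the symmetric thing on the inputs. Below is a minimal sketch of that trick (my own RunBeforeBackward class, not DeepSpeed's PreBackwardFunction), just to show why attaching a Function to the output gives a "pre-backward" callback.

import torch

# Minimal sketch, not DeepSpeed's PreBackwardFunction: an identity in the forward
# direction whose backward() runs a callback before gradients reach the tensor's
# producer (i.e. the submodule that created it).
class RunBeforeBackward(torch.autograd.Function):

    @staticmethod
    def forward(ctx, callback, tensor):
        ctx.callback = callback
        return tensor.clone()          # identity on the forward path

    @staticmethod
    def backward(ctx, grad_output):
        ctx.callback()                 # e.g. re-gather this module's parameters
        return None, grad_output       # no gradient for the callback argument

x = torch.randn(3, requires_grad=True)
out = x * 2.0                          # stands in for a submodule's forward output
out = RunBeforeBackward.apply(lambda: print("pre-backward fired"), out)
out.sum().backward()                   # prints the message, then backprops through x * 2.0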
Before the forward pass: pre_forward
Before a module's forward pass runs, its partitioned parameters obviously have to be restored; this is done with an AllGather communication that rebuilds the full parameters of that layer. The hook is registered via module.register_forward_pre_hook(_pre_forward_module_hook); following _pre_forward_module_hook down the call chain leads to:
@torch.no_grad()
def pre_sub_module_forward_function(self, sub_module):
    see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", force=False)

    global FWD_MODULE_STACK
    FWD_MODULE_STACK.append(sub_module)

    param_coordinator = self.get_param_coordinator(training=sub_module.training)
    param_coordinator.trace_prologue(sub_module)
    if param_coordinator.is_record_trace():
        param_coordinator.record_module(sub_module)
    # The actual parameter gathering happens here;
    # param_coordinator is a deepspeed.runtime.zero.PartitionedParameterCoordinator
    param_coordinator.fetch_sub_module(sub_module, forward=True)

    see_memory_usage(f"Before sub module function {sub_module.__class__.__name__} after fetch", force=False)
After the forward pass: post_forward
@torch.no_grad()
def post_sub_module_forward_function(self, sub_module):
    see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release",
                     force=False)

    # Release the parameters again
    param_coordinator = self.get_param_coordinator(training=sub_module.training)
    # The actual work is in deepspeed.runtime.zero.PartitionedParameterCoordinator::release_sub_module
    param_coordinator.release_sub_module(sub_module, backward=False)

    see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release",
                     force=False)
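release_sub_module hands the parameters back to their partitioned state: the gathered full tensor is dropped so its memory can be reclaimed, and only the local shard survives between uses. A rough sketch of just that idea (not the PartitionedParameterCoordinator implementation; the shard bookkeeping here is simplified for illustration):

import torch

# Rough sketch only, not release_sub_module: after forward, drop the gathered
# full tensor; only the local shard is kept for the next fetch.
def release_full_param(param: torch.nn.Parameter) -> None:
    # Point .data at a 0-element placeholder so the full tensor's memory can be
    # freed (ZeRO-3 keeps the shard on the parameter object itself; here we
    # assume the caller keeps it elsewhere).
    param.data = torch.empty(0, dtype=param.dtype, device=param.device)

p = torch.nn.Parameter(torch.randn(4, 4))
local_shard = p.detach().flatten()[:4].clone()   # pretend this is this rank's partition
release_full_param(p)
print(p.data.numel())                            # 0 -- the full tensor has been released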
Before the backward pass: pre_backward
@torch.no_grad()
def pre_sub_module_backward_function(self, sub_module):
    assert sub_module.training, "backward pass is invalid for module in evaluation mode"
    param_coordinator = self.get_param_coordinator(training=True)
    param_coordinator.trace_prologue(sub_module)
    if param_coordinator.is_record_trace():
        param_coordinator.record_module(sub_module)
    param_coordinator.fetch_sub_module(sub_module, forward=False)
After the backward pass: post_backward
@torch.no_grad()
def post_sub_module_backward_function(self, sub_module):
    assert sub_module.training, "backward pass is invalid for module in evaluation mode"
    see_memory_usage(
        f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
        force=False)

    self.get_param_coordinator(training=True).release_sub_module(sub_module, backward=True)

    see_memory_usage(
        f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
        force=False)
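The alternate path shown earlier (_alternate_post_backward_module_hook) relies on tensor.register_hook, which fires when the gradient of a given tensor is produced: a hook on a module's output therefore runs right before its backward, and a hook on its inputs runs right after it. The toy below (not DeepSpeed code) uses exactly that to print the order in which the four ZeRO-3 points fire during a single training step.

import torch
import torch.nn as nn

# Toy demo, not DeepSpeed code: observe the four hook points around one layer.
layer = nn.Linear(4, 4)

def pre_forward(module, inputs):
    print("1. pre-forward  : gather params")
    # The gradient of an *input* is produced only once this module's backward
    # has finished, so a tensor hook on it acts as a post-backward hook.
    for t in inputs:
        if t.requires_grad:
            t.register_hook(lambda grad: print("4. post-backward: release params"))

def post_forward(module, inputs, output):
    print("2. post-forward : release params")
    # The gradient of the *output* arrives just before this module's backward,
    # so a tensor hook on it acts as a pre-backward hook.
    output.register_hook(lambda grad: print("3. pre-backward : re-gather params"))

layer.register_forward_pre_hook(pre_forward)
layer.register_forward_hook(post_forward)

x = torch.randn(2, 4, requires_grad=True)
layer(x).sum().backward()   # prints 1, 2, 3, 4 in that order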
As you can see, the real work behind each of these hook points lives in deepspeed.runtime.zero.PartitionedParameterCoordinator; every hook ultimately dispatches into that class.