@@ -3,6 +3,7 @@ import torch.distributed
from ..config.config import CrossRegionDataParallelConfig
from ..comms.global_process_group_manager import GlobalProcessGroupManager
from ..comms.all_reduce_hook import all_reduce_hook
from .cross_region_params_manager import _ParamAndGradBuffer
DTYPE_MAP = {
@@ -22,7 +23,9 @@ class CrossRegionDataParallel:
self.module = module
self.global_process_group_manager = global_process_group_manager
self.comms_hook = self.defualt_comms_hook
self.comms_hook_state = None
self.comms_hook = None
self._init_default_comms_hook()
self.buffer = None
@@ -40,18 +43,17 @@ class CrossRegionDataParallel:
bucket_size=self.config.bucket_size
)
def defualt_comms_hook(self):
    """Use the process-group manager's ``all_reduce`` as the comms hook.

    Stores ``self.global_process_group_manager.all_reduce`` on
    ``self.comms_hook``. Nothing is returned.

    NOTE: the name is misspelled ("defualt") but is kept exactly as the
    rest of the file refers to it, so existing lookups of this attribute
    keep resolving.
    """
    self.comms_hook = self.global_process_group_manager.all_reduce
def _init_default_comms_hook(self):
    """Register ``all_reduce_hook`` as the comms hook, with no hook state.

    Delegates to ``register_comms_hook`` so the default hook is installed
    through the same path as a user-supplied one.
    """
    self.register_comms_hook(None, all_reduce_hook)
def _make_comms_hook(self, hook):
    """Adapt a three-argument hook into a callable that takes only a bucket.

    Args:
        hook: Callable invoked as ``hook(state, bucket, process_group)``.

    Returns:
        A function ``wrapper(bucket)`` that calls ``hook`` with the
        currently registered hook state, the given bucket and the global
        process group, and returns whatever ``hook`` returns.
    """

    def wrapper(bucket):
        # The state and the process group are looked up at call time, not
        # captured when the wrapper is built, so the hook always receives
        # the values that are current when it runs.
        return hook(
            self.comms_hook_state,
            bucket,
            self.global_process_group_manager.get_gobal_pg(),
        )

    return wrapper
def register_comms_hook(self, state, hook):
    """Register a communication hook together with the state it receives.

    Args:
        state: Object stored on ``self.comms_hook_state``; may be ``None``.
        hook: Callable handed to ``self._make_comms_hook``; the wrapped
            result is stored on ``self.comms_hook``.

    A later call replaces both the state and the hook set by an earlier one.
    """
    # Store the state before building the wrapper so the two attributes are
    # always updated as a pair.
    self.comms_hook_state = state
    self.comms_hook = self._make_comms_hook(hook)
def start_grad_sync(self, async_op: bool = False):
for bucket in self.buffer.get_buckets():
@@ -61,13 +63,15 @@ class CrossRegionDataParallel:
for bucket in self.buffer.get_buckets():
bucket.finish_grad_sync()
def bradcast_cross_region_params(self):
def bro adcast_cross_region_params(self):
global_pg = self.global_process_group_manager.get_gobal_pg()
opts = torch.distributed.Barrier Options()
opts = torch.distributed.Broadcast Options()
opts.rootRank = 0
opts.rootTensor = 0
opts.asyncOp = False
for param in self.module.parameters():
global_pg.broadcast([param.data], opts)
global_pg.broadcast([param.data], opts).wait()
def caculate_cross_region_grads(self):
    """Delegate gradient calculation to the buffer's ``caculate_grads``.

    Calls ``self.buffer.caculate_grads()`` and discards its result.
    """
    grad_buffer = self.buffer
    grad_buffer.caculate_grads()