@@ -56,6 +56,7 @@ __all__ = ['TransformCkpt']
class TransformCkpt:
"""Transform src_checkpoint from src_strategy to dst_strategy."""
def __init__(self,
auto_trans_ckpt: bool = False,
rank_id: Optional[int] = None,
@@ -106,21 +107,25 @@ class TransformCkpt:
self.npu_num_per_node = npu_num_per_node or get_device_num_per_node()
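# Number of server nodes, assuming world_size is an exact multiple of npu_num_per_node.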
self.node_num = self.world_size // self.npu_num_per_node
if not is_power_of_two(self.npu_num_per_node):
raise ValueError(
f"The `npu_num_per_node` must be a power of 2, but get {npu_num_per_node}")
err_msg = f"The `npu_num_per_node` must be a power of 2, but get {npu_num_per_node}"
logger.error(err_msg)
raise ValueError(err_msg)
# Before obtaining transform_rank_id_list, check 1 ≤ transform_process_num ≤ world_size.
if transform_process_num < 1:
raise ValueError("transform_process_num should not smaller than 1,"
f"but got {transform_process_num}.")
err_msg = f"transform_process_num should not smaller than 1, but got {transform_process_num}."
logger.error(err_msg)
raise ValueError(err_msg)
if transform_process_num > self.world_size:
logger.warning("transform_process_num: %d should not bigger than world_size: %d. \
transform_process_num is set to %d.",
transform_process_num, self.world_size, self.world_size)
transform_process_num = self.world_size
if self.world_size % transform_process_num != 0:
raise ValueError(f"transform_process_num: {transform_process_num} "
f"should be divided by world_size: {self.world_size}.")
err_msg = (f"transform_process_num: {transform_process_num} "
f"should be divided by world_size: {self.world_size}.")
logger.error(err_msg)
raise ValueError(err_msg)
if check_in_modelarts() and 1 < transform_process_num < self.node_num:
logger.warning("transform_process_num: %d should not smaller than \
node_num = world_size // npu_num_per_node = %d when training on AICC. \
@@ -130,13 +135,13 @@ class TransformCkpt:
if check_in_modelarts() and transform_process_num == 1:
# The 0th NPU of each node is responsible for transforming all checkpoints.
# For example, if world_size is 16 and npu_num_per_node is 8, then transform_rank_id_list should be [0,8].
self.transform_rank_id_list = [i for i in range(0, self.world_size, self.npu_num_per_node)]
self.transform_rank_id_list = list(range(0, self.world_size, self.npu_num_per_node))
else:
# Obtain transform_rank_id_list. For example,
# if world_size is 8 and transform_process_num is 2, then transform_rank_id_list should be [0,4],
# which means that the 0th rank and the 4th rank are responsible for transforming checkpoints.
self.transform_rank_id_list = \
[i for i in range(0, self.world_size, self.world_size // transform_process_num)]
self.transform_rank_id_list = list(range(0, self.world_size, self.world_size // transform_process_num))
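# Recompute transform_process_num from the final rank list, since it may have been adjusted above.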
self.transform_process_num = len(self.transform_rank_id_list)
if auto_trans_ckpt:
@@ -153,10 +158,7 @@ class TransformCkpt:
self.auto_trans_ckpt = auto_trans_ckpt
self.transform_by_rank = transform_by_rank
if transform_process_num > 1:
self.transform_by_rank = True
elif self.world_size == 1:
self.transform_by_rank = False
self.update_transform_by_rank(transform_process_num)
self.cache_list = []
logger.info(f"rank_id: {self.rank_id}")
@@ -164,6 +166,13 @@ class TransformCkpt:
logger.info(f"transform_process_num: {self.transform_process_num}")
logger.info(f"transform_rank_id_list: {self.transform_rank_id_list}")
def update_transform_by_rank(self, transform_process_num):
"""Update transform_by_rank according to transform_process_num and world_size."""
if transform_process_num > 1:
self.transform_by_rank = True
elif self.world_size == 1:
self.transform_by_rank = False
def __call__(self,
src_checkpoint: str,
dst_checkpoint_dir: Optional[str] = None,
@@ -237,23 +246,29 @@ class TransformCkpt:
if self.world_size > 1:
dst_strategy_list = glob(os.path.join(self.dst_strategy_dir, f"*_rank_{self.rank_id}.ckpt"))
if not dst_strategy_list:
raise RuntimeError(f"The `dst_strategy`={self.dst_strategy_dir} \
does not contain strategy file of rank_{self.rank_id}.")
err_msg = (f"The `dst_strategy`={self.dst_strategy_dir} "
f"does not contain strategy file of rank_{self.rank_id}.")
logger.error(err_msg)
raise RuntimeError(err_msg)
if len(dst_strategy_list) > 1:
raise RuntimeError(f"There can only be one strategy file corresponding to rank_{self.rank_id}, \
but multiple strategy files corresponding to rank_{self.rank_id} were found \
in {self.dst_strategy_dir}.")
err_msg = (f"There can only be one strategy file corresponding to rank_{self.rank_id}, "
f"but multiple strategy files corresponding to rank_{self.rank_id} "
f"were found in {self.dst_strategy_dir}.")
logger.error(err_msg)
raise RuntimeError(err_msg)
dst_strategy = dst_strategy_list[0]
else:
dst_strategy = None
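# On ModelArts, verify that the corresponding OBS directories exist before transforming.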
if check_in_modelarts():
if not mox.file.exists(self.transformed_checkpoint_dir_obs):
raise ValueError(f"transformed_checkpoint_dir_obs: "
f"{self.transformed_checkpoint_dir_obs} is not found!")
err_log = f"transformed_checkpoint_dir_obs: {self.transformed_checkpoint_dir_obs} is not found!"
logger.error(err_log)
raise ValueError(err_log)
if self.world_size > 1 and not mox.file.exists(self.dst_strategy_dir_obs):
raise ValueError(f"dst_strategy_dir_obs: {self.dst_strategy_dir_obs} is not found!")
err_log = f"dst_strategy_dir_obs: {self.dst_strategy_dir_obs} is not found!"
logger.error(err_log)
raise ValueError(err_log)
# Get final dst_strategy in auto_trans_ckpt mode.
dst_strategy = self.get_dst_strategy(dst_strategy)
@@ -272,13 +287,7 @@ class TransformCkpt:
barrier_world(f"Remake {dst_ckpt_dir} by main rank.")
logger.info("The transformed checkpoint will be saved under %s.", dst_ckpt_dir)
self.transform_ckpt(
src_checkpoint=src_ckpt_dir,
dst_checkpoint_dir=dst_ckpt_dir,
src_strategy=src_strategy,
dst_strategy=dst_strategy,
prefix=prefix
)
self.transform_ckpt(src_ckpt_dir, dst_ckpt_dir, src_strategy, dst_strategy, prefix)
self.clear_cache()
return dst_checkpoint_dir
@@ -292,7 +301,9 @@ class TransformCkpt:
"""Transform ckpt using mindspore.transform_checkpoint"""
self.check_src_checkpoint_and_strategy(src_checkpoint, src_strategy)
if src_strategy is None and dst_strategy is None:
raise ValueError("`src_strategy` and `dst_strategy` cannot both be None!")
err_msg = "`src_strategy` and `dst_strategy` cannot both be None!"
logger.error(err_msg)
raise ValueError(err_msg)
if check_in_modelarts():
dst_checkpoint_dir_obs = os.path.join(self.transformed_checkpoint_dir_obs,
os.path.basename(dst_checkpoint_dir))
@@ -364,7 +375,7 @@ class TransformCkpt:
dst_strategy):
"""transform checkpoints using mindspore.transform_checkpoint_by_rank"""
for current_transform_rank_id in \
range(self.rank_id, self.rank_id + self.world_size // self.transform_process_num):
range(self.rank_id, self.rank_id + self.world_size // self.transform_process_num):
logger.info(".........Transforming Ckpt For Rank: %d.........", current_transform_rank_id)
src_rank_list = ms.rank_list_for_transform(current_transform_rank_id,
src_strategy,
@@ -374,13 +385,15 @@ class TransformCkpt:
checkpoint_rank_dir = os.path.join(src_checkpoint, f"rank_{src_rank_id}")
checkpoint_file_list = glob(os.path.join(checkpoint_rank_dir, "*.ckpt"))
if not checkpoint_file_list:
raise ValueError(f"The checkpoint of rank_{src_rank_id} is not found!")
err_msg = f"The checkpoint of rank_{src_rank_id} is not found!"
logger.error(err_msg)
raise ValueError(err_msg)
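# If a rank directory holds several checkpoints, take the most recently modified one.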
checkpoint_file_list = sorted(checkpoint_file_list, key=os.path.getmtime)
checkpoint_file_map[src_rank_id] = checkpoint_file_list[-1]
save_checkpoint_dir = os.path.join(dst_checkpoint, "rank_{}".format(current_transform_rank_id) )
save_checkpoint_dir = os.path.join(dst_checkpoint, f"rank_{current_transform_rank_id}" )
os.makedirs(save_checkpoint_dir, exist_ok=True)
save_checkpoint_path = os.path.join(save_checkpoint_dir,
"{}.ckpt".format(prefix + str(current_transform_rank_id)) )
f"{prefix + str(current_transform_rank_id)}.ckpt" )
logger.info("rank_list: %s", src_rank_list)
logger.info("checkpoint_file_map: %s", checkpoint_file_map)
logger.info("save_checkpoint_path: %s", save_checkpoint_path)
@@ -396,11 +409,15 @@ class TransformCkpt:
def build_soft_link_of_checkpoint(checkpoint, soft_link_dir):
"""Build softlink of src checkpoint"""
if os.path.isdir(checkpoint) and not check_rank_folders(checkpoint, 0) and \
not check_ckpt_file_exist(checkpoint):
raise ValueError(f"No rank_0 folder or ckpt files are found under {checkpoint}.")
not check_ckpt_file_exist(checkpoint):
err_msg = f"No rank_0 folder or ckpt files are found under {checkpoint}."
logger.error(err_msg)
raise ValueError(err_msg)
if os.path.isfile(checkpoint) and not checkpoint.endswith('.ckpt'):
raise ValueError(f"The value of load_checkpoint must be a folder or a file with suffix '.ckpt', "
f"but got {checkpoint}")
err_msg = ("The value of load_checkpoint must be a folder or a file with suffix '.ckpt', "
f"but got {checkpoint}")
logger.error(err_msg)
raise ValueError(err_msg)
if os.path.isdir(checkpoint):
if check_rank_folders(checkpoint, 0):
@@ -443,7 +460,9 @@ class TransformCkpt:
return None
if not os.path.exists(strategy_path):
raise ValueError(f'strategy_path: {strategy_path} not found!')
err_msg = f'strategy_path: {strategy_path} not found!'
logger.error(err_msg)
raise ValueError(err_msg)
if os.path.isfile(strategy_path):
return strategy_path
@@ -452,7 +471,7 @@ class TransformCkpt:
if rank_id:
merge_path = os.path.join(strategy_path, f'merged_ckpt_strategy_by_rank_{rank_id}.ckpt')
else:
merge_path = os.path.join(strategy_path, f 'merged_ckpt_strategy.ckpt')
merge_path = os.path.join(strategy_path, 'merged_ckpt_strategy.ckpt')
merged_succeed_txt = os.path.join(strategy_path, "merge_succeed.txt")
if self.is_main_rank:
@@ -479,8 +498,9 @@ class TransformCkpt:
if not (dst_strategy.endswith(f"_rank_{self.rank_id}.ckpt") and
os.path.exists(dst_strategy)):
raise ValueError(f"dst_strategy: {dst_strategy} is not found!")
err_msg = f"dst_strategy: {dst_strategy} is not found!"
logger.error(err_msg)
raise ValueError(err_msg)
logger.info(".........Collecting strategy.........")
if check_in_modelarts():
@@ -521,9 +541,11 @@ class TransformCkpt:
"""check src checkpoint and strategy"""
check_path(src_checkpoint, "src_checkpoint")
if not os.path.isdir(src_checkpoint) or not glob(os.path.join(src_checkpoint, "rank_*")):
raise ValueError("The load_checkpoint must be a dir and "
"ckpt should be stored in the format of load_checkpoint/rank_x/xxx.ckpt,"
f"but get {src_checkpoint}.")
err_msg = ("The load_checkpoint must be a dir "
"and ckpt should be stored in the format of load_checkpoint/rank_x/xxx.ckpt, "
f"but get {src_checkpoint}.")
logger.error(err_msg)
raise ValueError(err_msg)
# Check that the rank dirs are continuous.
# For example, rank_0, rank_1, rank_4 is not continuous because it is missing rank_2 and rank_3.
src_checkpoint_rank_dir_list = glob(os.path.join(src_checkpoint, "rank_*"))
@@ -532,7 +554,9 @@ class TransformCkpt:
src_checkpoint_rank_num = len(src_checkpoint_rank_id_list)
for i in range(src_checkpoint_rank_num):
if src_checkpoint_rank_id_list[i] != i:
raise FileNotFoundError(f"The rank_{i} folder was not found under src_checkpoint folder.")
err_msg = f"The rank_{i} folder was not found under src_checkpoint folder."
logger.error(err_msg)
raise FileNotFoundError(err_msg)
# A full checkpoint does not require a strategy.
if len(src_checkpoint_rank_id_list) == 1 and src_strategy:
@@ -540,7 +564,9 @@ class TransformCkpt:
src_strategy = None
# Distributed checkpoints must be accompanied by strategy.
if len(src_checkpoint_rank_id_list) > 1 and src_strategy is None:
raise ValueError("`src_strategy` should not be None when `src_checkpoint` is sliced.")
err_msg = "`src_strategy` should not be None when `src_checkpoint` is sliced."
logger.error(err_msg)
raise ValueError(err_msg)
def send_strategy_to_obs(self, strategy):
"""Local rank send strategy file to obs."""
@@ -587,7 +613,9 @@ class TransformCkpt:
last_strategy_num = dst_strategy_num
if dst_strategy_num < self.world_size:
if time.time() - start_time > 7200:
raise TimeoutError("Timeout while collecting all strategy!")
err_msg = "Timeout while collecting all strategy!"
logger.error(err_msg)
raise TimeoutError(err_msg)
time.sleep(5)
else:
break
@@ -599,14 +627,16 @@ class TransformCkpt:
if check_in_modelarts():
transformed_ckpt_dir_obs = os.path.join(self.transformed_checkpoint_dir_obs, os.path.basename(ckpt_dir))
transform_failed_txts = mox.file.glob(os.path.join(transformed_ckpt_dir_obs,
f 'transform_failed_rank_*.txt'))
'transform_failed_rank_*.txt'))
transform_succeed_txts = mox.file.glob(os.path.join(transformed_ckpt_dir_obs,
f 'transform_succeed_rank_*.txt'))
'transform_succeed_rank_*.txt'))
else:
transform_failed_txts = glob(os.path.join(ckpt_dir, f 'transform_failed_rank_*.txt'))
transform_succeed_txts = glob(os.path.join(ckpt_dir, f 'transform_succeed_rank_*.txt'))
transform_failed_txts = glob(os.path.join(ckpt_dir, 'transform_failed_rank_*.txt'))
transform_succeed_txts = glob(os.path.join(ckpt_dir, 'transform_succeed_rank_*.txt'))
if transform_failed_txts:
raise ValueError(f"Transform failed, find {transform_failed_txts}.")
err_msg = f"Transform failed, find {transform_failed_txts}."
logger.error(err_msg)
raise ValueError(err_msg)
current_count = len(transform_succeed_txts)
progress = (current_count / self.transform_process_num) * 100
if current_count != last_count:
@@ -617,6 +647,7 @@ class TransformCkpt:
else:
break
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--src_checkpoint',