2 Commits

6 changed files with 235 additions and 115 deletions
  1. mindformers/tools/ckpt_transform/transform_checkpoint.py  (+84 −53)
  2. mindformers/tools/resume_ckpt.py  (+39 −18)
  3. mindformers/trainer/base_trainer.py  (+11 −6)
  4. mindformers/trainer/utils.py  (+42 −15)
  5. mindformers/utils/load_checkpoint_utils.py  (+56 −22)
  6. mindformers/utils/resume_ckpt_utils.py  (+3 −1)

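Nearly every hunk below applies the same refactor: the exception message is first bound to a local variable, written to the log with `logger.error`, and only then raised, so the failure reason lands in the training log as well as in the traceback. A minimal sketch of that pattern, assuming a plain standard-library logger in place of the project's own `mindformers` logger:

import logging

logger = logging.getLogger("mindformers")

def check_transform_process_num(transform_process_num: int, world_size: int) -> None:
    # Log-then-raise: record the reason before surfacing the exception.
    if transform_process_num < 1:
        err_msg = f"transform_process_num should not smaller than 1, but got {transform_process_num}."
        logger.error(err_msg)
        raise ValueError(err_msg)
    if world_size % transform_process_num != 0:
        err_msg = (f"transform_process_num: {transform_process_num} "
                   f"should be divided by world_size: {world_size}.")
        logger.error(err_msg)
        raise ValueError(err_msg)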
mindformers/tools/ckpt_transform/transform_checkpoint.py  (+84 −53)

@@ -56,6 +56,7 @@ __all__ = ['TransformCkpt']

class TransformCkpt:
"""Transform src_checkpoint from src_strategy to dst_strategy."""

def __init__(self,
auto_trans_ckpt: bool = False,
rank_id: Optional[int] = None,
@@ -106,21 +107,25 @@ class TransformCkpt:
self.npu_num_per_node = npu_num_per_node or get_device_num_per_node()
self.node_num = self.world_size // self.npu_num_per_node
if not is_power_of_two(self.npu_num_per_node):
raise ValueError(
f"The `npu_num_per_node` must be a power of 2, but get {npu_num_per_node}")
err_msg = f"The `npu_num_per_node` must be a power of 2, but get {npu_num_per_node}"
logger.error(err_msg)
raise ValueError(err_msg)

# Before obtaining transform_rank_id_list, check 1 ≤ transform_process_num ≤ world_size.
if transform_process_num < 1:
raise ValueError("transform_process_num should not smaller than 1,"
f"but got {transform_process_num}.")
err_msg = f"transform_process_num should not smaller than 1, but got {transform_process_num}."
logger.error(err_msg)
raise ValueError(err_msg)
if transform_process_num > self.world_size:
logger.warning("transform_process_num: %d should not bigger than world_size: %d. \
transform_process_num is set to %d.",
transform_process_num, self.world_size, self.world_size)
transform_process_num = self.world_size
if self.world_size % transform_process_num != 0:
raise ValueError(f"transform_process_num: {transform_process_num} "
f"should be divided by world_size: {self.world_size}.")
err_msg = (f"transform_process_num: {transform_process_num} "
f"should be divided by world_size: {self.world_size}.")
logger.error(err_msg)
raise ValueError(err_msg)
if check_in_modelarts() and 1 < transform_process_num < self.node_num:
logger.warning("transform_process_num: %d should not smaller than \
node_num = world_size // npu_num_per_node = %d when training on AICC. \
@@ -130,13 +135,13 @@ class TransformCkpt:
if check_in_modelarts() and transform_process_num == 1:
# The 0th NPU of each node is responsible for transform all checkpoints.
# For example, if world_size is 16 and npu_num_per_node is 8, then transform_rank_id_list should be [0,8].
self.transform_rank_id_list = [i for i in range(0, self.world_size, self.npu_num_per_node)]
self.transform_rank_id_list = list(range(0, self.world_size, self.npu_num_per_node))
else:
# Obtain transform_rank_id_list. For example,
# if world_size is 8 and transform_process_num is 2, then transform_rank_id_list should be [0,4].
# which means that the 0th rank and the 4th rank responsible for transform checkpoints.
self.transform_rank_id_list = \
[i for i in range(0, self.world_size, self.world_size // transform_process_num)]
self.transform_rank_id_list = list(range(0, self.world_size, self.world_size // transform_process_num))
self.transform_process_num = len(self.transform_rank_id_list)

if auto_trans_ckpt:
@@ -153,10 +158,7 @@ class TransformCkpt:
self.auto_trans_ckpt = auto_trans_ckpt

self.transform_by_rank = transform_by_rank
if transform_process_num > 1:
self.transform_by_rank = True
elif self.world_size == 1:
self.transform_by_rank = False
self.update_transform_by_rank(transform_process_num)

self.cache_list = []
logger.info(f"rank_id: {self.rank_id}")
@@ -164,6 +166,13 @@ class TransformCkpt:
logger.info(f"transform_process_num: {self.transform_process_num}")
logger.info(f"transform_rank_id_list: {self.transform_rank_id_list}")

def update_transform_by_rank(self, transform_process_num):
"""Update transform_by_rank according to transform_process_num and world_size."""
if transform_process_num > 1:
self.transform_by_rank = True
elif self.world_size == 1:
self.transform_by_rank = False

def __call__(self,
src_checkpoint: str,
dst_checkpoint_dir: Optional[str] = None,
@@ -237,23 +246,29 @@ class TransformCkpt:
if self.world_size > 1:
dst_strategy_list = glob(os.path.join(self.dst_strategy_dir, f"*_rank_{self.rank_id}.ckpt"))
if not dst_strategy_list:
raise RuntimeError(f"The `dst_strategy`={self.dst_strategy_dir} \
does not contain strategy file of rank_{self.rank_id}.")
err_msg = (f"The `dst_strategy`={self.dst_strategy_dir} "
f"does not contain strategy file of rank_{self.rank_id}.")
logger.error(err_msg)
raise RuntimeError(err_msg)
if len(dst_strategy_list) > 1:
raise RuntimeError(f"There can only be one strategy file corresponding to rank_{self.rank_id}, \
but multiple strategy files corresponding to rank_{self.rank_id} were found \
in {self.dst_strategy_dir}.")
err_msg = (f"There can only be one strategy file corresponding to rank_{self.rank_id}, "
f"but multiple strategy files corresponding to rank_{self.rank_id} "
f"were found in {self.dst_strategy_dir}.")
logger.error(err_msg)
raise RuntimeError(err_msg)
dst_strategy = dst_strategy_list[0]
else:
dst_strategy = None

if check_in_modelarts():
if not mox.file.exists(self.transformed_checkpoint_dir_obs):
raise ValueError(f"transformed_checkpoint_dir_obs: "
f"{self.transformed_checkpoint_dir_obs} is not found!")
err_log = f"transformed_checkpoint_dir_obs: {self.transformed_checkpoint_dir_obs} is not found!"
logger.error(err_log)
raise ValueError(err_log)
if self.world_size > 1 and not mox.file.exists(self.dst_strategy_dir_obs):
raise ValueError(f"dst_strategy_dir_obs: {self.dst_strategy_dir_obs} is not found!")

err_log = f"dst_strategy_dir_obs: {self.dst_strategy_dir_obs} is not found!"
logger.error(err_log)
raise ValueError(err_log)

# Get final dst_strategy in auto_trans_ckpt mode.
dst_strategy = self.get_dst_strategy(dst_strategy)
@@ -272,13 +287,7 @@ class TransformCkpt:
barrier_world(f"Remake {dst_ckpt_dir} by main rank.")

logger.info("The transformed checkpoint will be saved under %s.", dst_ckpt_dir)
self.transform_ckpt(
src_checkpoint=src_ckpt_dir,
dst_checkpoint_dir=dst_ckpt_dir,
src_strategy=src_strategy,
dst_strategy=dst_strategy,
prefix=prefix
)
self.transform_ckpt(src_ckpt_dir, dst_ckpt_dir, src_strategy, dst_strategy, prefix)

self.clear_cache()
return dst_checkpoint_dir
@@ -292,7 +301,9 @@ class TransformCkpt:
"""Transform ckpt using mindspore.transform_checkpoint"""
self.check_src_checkpoint_and_strategy(src_checkpoint, src_strategy)
if src_strategy is None and dst_strategy is None:
raise ValueError("`src_strategy` and `dst_strategy` cannot both be None!")
err_msg = "`src_strategy` and `dst_strategy` cannot both be None!"
logger.error(err_msg)
raise ValueError(err_msg)
if check_in_modelarts():
dst_checkpoint_dir_obs = os.path.join(self.transformed_checkpoint_dir_obs,
os.path.basename(dst_checkpoint_dir))
@@ -364,7 +375,7 @@ class TransformCkpt:
dst_strategy):
"""transform checkpoints using mindspore.transform_checkpoint_by_rank"""
for current_transform_rank_id in \
range(self.rank_id, self.rank_id + self.world_size // self.transform_process_num):
range(self.rank_id, self.rank_id + self.world_size // self.transform_process_num):
logger.info(".........Transforming Ckpt For Rank: %d.........", current_transform_rank_id)
src_rank_list = ms.rank_list_for_transform(current_transform_rank_id,
src_strategy,
@@ -374,13 +385,15 @@ class TransformCkpt:
checkpoint_rank_dir = os.path.join(src_checkpoint, f"rank_{src_rank_id}")
checkpoint_file_list = glob(os.path.join(checkpoint_rank_dir, "*.ckpt"))
if not checkpoint_file_list:
raise ValueError(f"The checkpoint of rank_{src_rank_id} is not found!")
err_msg = f"The checkpoint of rank_{src_rank_id} is not found!"
logger.error(err_msg)
raise ValueError(err_msg)
checkpoint_file_list = sorted(checkpoint_file_list, key=os.path.getmtime)
checkpoint_file_map[src_rank_id] = checkpoint_file_list[-1]
save_checkpoint_dir = os.path.join(dst_checkpoint, "rank_{}".format(current_transform_rank_id))
save_checkpoint_dir = os.path.join(dst_checkpoint, f"rank_{current_transform_rank_id}")
os.makedirs(save_checkpoint_dir, exist_ok=True)
save_checkpoint_path = os.path.join(save_checkpoint_dir,
"{}.ckpt".format(prefix + str(current_transform_rank_id)))
f"{prefix + str(current_transform_rank_id)}.ckpt")
logger.info("rank_list: %s", src_rank_list)
logger.info("checkpoint_file_map: %s", checkpoint_file_map)
logger.info("save_checkpoint_path: %s", save_checkpoint_path)
@@ -396,11 +409,15 @@ class TransformCkpt:
def build_soft_link_of_checkpoint(checkpoint, soft_link_dir):
"""Build softlink of src checkpoint"""
if os.path.isdir(checkpoint) and not check_rank_folders(checkpoint, 0) and \
not check_ckpt_file_exist(checkpoint):
raise ValueError(f"No rank_0 folder or ckpt files are found under {checkpoint}.")
not check_ckpt_file_exist(checkpoint):
err_msg = f"No rank_0 folder or ckpt files are found under {checkpoint}."
logger.error(err_msg)
raise ValueError(err_msg)
if os.path.isfile(checkpoint) and not checkpoint.endswith('.ckpt'):
raise ValueError(f"The value of load_checkpoint must be a folder or a file with suffix '.ckpt', "
f"but got {checkpoint}")
err_msg = ("The value of load_checkpoint must be a folder or a file with suffix '.ckpt', "
f"but got {checkpoint}")
logger.error(err_msg)
raise ValueError(err_msg)

if os.path.isdir(checkpoint):
if check_rank_folders(checkpoint, 0):
@@ -443,7 +460,9 @@ class TransformCkpt:
return None

if not os.path.exists(strategy_path):
raise ValueError(f'strategy_path: {strategy_path} not found!')
err_msg = f'strategy_path: {strategy_path} not found!'
logger.error(err_msg)
raise ValueError(err_msg)

if os.path.isfile(strategy_path):
return strategy_path
@@ -452,7 +471,7 @@ class TransformCkpt:
if rank_id:
merge_path = os.path.join(strategy_path, f'merged_ckpt_strategy_by_rank_{rank_id}.ckpt')
else:
merge_path = os.path.join(strategy_path, f'merged_ckpt_strategy.ckpt')
merge_path = os.path.join(strategy_path, 'merged_ckpt_strategy.ckpt')

merged_succeed_txt = os.path.join(strategy_path, "merge_succeed.txt")
if self.is_main_rank:
@@ -479,8 +498,9 @@ class TransformCkpt:

if not (dst_strategy.endswith(f"_rank_{self.rank_id}.ckpt") and
os.path.exists(dst_strategy)):
raise ValueError(f"dst_strategy: {dst_strategy} is not found!")

err_msg = f"dst_strategy: {dst_strategy} is not found!"
logger.error(err_msg)
raise ValueError(err_msg)

logger.info(".........Collecting strategy.........")
if check_in_modelarts():
@@ -521,9 +541,11 @@ class TransformCkpt:
"""check src checkpoint and strategy"""
check_path(src_checkpoint, "src_checkpoint")
if not os.path.isdir(src_checkpoint) or not glob(os.path.join(src_checkpoint, "rank_*")):
raise ValueError("The load_checkpoint must be a dir and "
"ckpt should be stored in the format of load_checkpoint/rank_x/xxx.ckpt,"
f"but get {src_checkpoint}.")
err_msg = ("The load_checkpoint must be a dir "
"and ckpt should be stored in the format of load_checkpoint/rank_x/xxx.ckpt, "
f"but get {src_checkpoint}.")
logger.error(err_msg)
raise ValueError(err_msg)
# Check rank_dirs is continuous.
# For example, rank_0, rank_1, rank_4 is not continuous because it is missing rank_3
src_checkpoint_rank_dir_list = glob(os.path.join(src_checkpoint, "rank_*"))
@@ -532,7 +554,9 @@ class TransformCkpt:
src_checkpoint_rank_num = len(src_checkpoint_rank_id_list)
for i in range(src_checkpoint_rank_num):
if src_checkpoint_rank_id_list[i] != i:
raise FileNotFoundError(f"The rank_{i} folder was not found under src_checkpoint folder.")
err_msg = f"The rank_{i} folder was not found under src_checkpoint folder."
logger.error(err_msg)
raise FileNotFoundError(err_msg)

# A full checkpoint do not require a strategy.
if len(src_checkpoint_rank_id_list) == 1 and src_strategy:
@@ -540,7 +564,9 @@ class TransformCkpt:
src_strategy = None
# Distributed checkpoints must be accompanied by strategy.
if len(src_checkpoint_rank_id_list) > 1 and src_strategy is None:
raise ValueError("`src_strategy` should not be None when `src_checkpoint` is sliced.")
err_msg = "`src_strategy` should not be None when `src_checkpoint` is sliced."
logger.error(err_msg)
raise ValueError(err_msg)

def send_strategy_to_obs(self, strategy):
"""Local rank send strategy file to obs."""
@@ -587,7 +613,9 @@ class TransformCkpt:
last_strategy_num = dst_strategy_num
if dst_strategy_num < self.world_size:
if time.time() - start_time > 7200:
raise TimeoutError("Timeout while collecting all strategy!")
err_msg = "Timeout while collecting all strategy!"
logger.error(err_msg)
raise TimeoutError(err_msg)
time.sleep(5)
else:
break
@@ -599,14 +627,16 @@ class TransformCkpt:
if check_in_modelarts():
transformed_ckpt_dir_obs = os.path.join(self.transformed_checkpoint_dir_obs, os.path.basename(ckpt_dir))
transform_failed_txts = mox.file.glob(os.path.join(transformed_ckpt_dir_obs,
f'transform_failed_rank_*.txt'))
'transform_failed_rank_*.txt'))
transform_succeed_txts = mox.file.glob(os.path.join(transformed_ckpt_dir_obs,
f'transform_succeed_rank_*.txt'))
'transform_succeed_rank_*.txt'))
else:
transform_failed_txts = glob(os.path.join(ckpt_dir, f'transform_failed_rank_*.txt'))
transform_succeed_txts = glob(os.path.join(ckpt_dir, f'transform_succeed_rank_*.txt'))
transform_failed_txts = glob(os.path.join(ckpt_dir, 'transform_failed_rank_*.txt'))
transform_succeed_txts = glob(os.path.join(ckpt_dir, 'transform_succeed_rank_*.txt'))
if transform_failed_txts:
raise ValueError(f"Transform failed, find {transform_failed_txts}.")
err_msg = f"Transform failed, find {transform_failed_txts}."
logger.error(err_msg)
raise ValueError(err_msg)
current_count = len(transform_succeed_txts)
progress = (current_count / self.transform_process_num) * 100
if current_count != last_count:
@@ -617,6 +647,7 @@ class TransformCkpt:
else:
break


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--src_checkpoint',

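The two branches that build `transform_rank_id_list` are easier to follow with concrete numbers. A standalone sketch of just that selection logic (the function name is illustrative, not part of the class API), reproducing the examples from the comments in the hunks above:

def pick_transform_ranks(world_size, npu_num_per_node, transform_process_num, on_aicc):
    # Sketch of the two branches that build transform_rank_id_list.
    if on_aicc and transform_process_num == 1:
        # The 0th NPU of each node transforms checkpoints.
        return list(range(0, world_size, npu_num_per_node))
    # Otherwise spread the work over transform_process_num evenly spaced ranks.
    return list(range(0, world_size, world_size // transform_process_num))

print(pick_transform_ranks(16, 8, 1, on_aicc=True))    # [0, 8]
print(pick_transform_ranks(8, 8, 2, on_aicc=False))    # [0, 4]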

mindformers/tools/resume_ckpt.py  (+39 −18)

@@ -82,7 +82,9 @@ def get_resume_checkpoint_by_meta(checkpoint_dir, ckpt_format='ckpt',
if check_in_modelarts():
resume_record_dir = os.path.join(get_remote_save_url(), "resume_record")
if not Validator.is_obs_url(resume_record_dir):
raise ValueError(f"{resume_record_dir} is not a valid obs path.")
err_msg = f"{resume_record_dir} is not a valid obs path."
logger.error(err_msg)
raise ValueError(err_msg)
else:
resume_record_dir = os.path.join(get_output_root_path(), "resume_record")
remake_folder(resume_record_dir, permissions=0o750)
@@ -149,7 +151,9 @@ def checkpoint_health_monitor(health_ckpts_record_dir, resume_ckpt_list):
health_ckpts.append(item)

if not health_ckpts:
raise ValueError("The training has no healthy checkpoints yet, please start training again.")
err_msg = "The training has no healthy checkpoints yet, please start training again."
logger.error(err_msg)
raise ValueError(err_msg)

if not not_health_ckpts:
not_health_ckpts_set = set(not_health_ckpts)
@@ -165,12 +169,16 @@ def get_resume_ckpt(latest_checkpointed_iteration_txt, rank_id):

if not check_in_modelarts():
if not os.path.exists(latest_checkpointed_iteration_txt):
raise ValueError(f"Can not find {latest_checkpointed_iteration_txt}")
err_msg = f"Can not find {latest_checkpointed_iteration_txt}"
logger.error(err_msg)
raise ValueError(err_msg)
with open(latest_checkpointed_iteration_txt, 'r', encoding='utf-8') as f:
resume_info = [line.strip() for line in f.readlines()]
else:
if not mox.file.exists(latest_checkpointed_iteration_txt):
raise ValueError(f"OBS: Can not find {latest_checkpointed_iteration_txt}")
err_msg = f"OBS: Can not find {latest_checkpointed_iteration_txt}"
logger.error(err_msg)
raise ValueError(err_msg)
with mox.file.File(latest_checkpointed_iteration_txt, 'r') as f:
resume_info = [line.strip() for line in f.readlines()]

@@ -180,7 +188,9 @@ def get_resume_ckpt(latest_checkpointed_iteration_txt, rank_id):
return True

if resume_info[0].startswith("Error"):
raise ValueError(f"Get resume-able checkpoint failed, due to {resume_info[0]}")
err_msg = f"Get resume-able checkpoint failed, due to {resume_info[0]}"
logger.error(err_msg)
raise ValueError(err_msg)

resume_ckpt = replace_rank_id_in_ckpt_name(resume_info[-1], rank_id)
logger.info("Get resume checkpoint: %s", resume_ckpt)
@@ -240,7 +250,9 @@ def get_resume_ckpt_list(checkpoint_dir, last_ckpt_file, rank_id, rank_dir_num,
ckpt_prefix_tmp = ckpt_prefix.replace(f"rank_{original_rank}", f"rank_{rank_id_tmp}")
checkpoint_rank_dir = os.path.join(checkpoint_dir, f"rank_{rank_id_tmp}")
if not os.path.exists(checkpoint_rank_dir):
raise FileNotFoundError(f"{checkpoint_rank_dir} is not found!")
err_msg = f"{checkpoint_rank_dir} is not found!"
logger.error(err_msg)
raise FileNotFoundError(err_msg)
for ckpt_file in os.listdir(checkpoint_rank_dir):
health_ckpt_match = (ckpt_file.startswith(ckpt_prefix_tmp[:ckpt_prefix_tmp.rfind("_")])
and use_checkpoint_health_monitor)
@@ -260,10 +272,14 @@ def get_resume_ckpt_list(checkpoint_dir, last_ckpt_file, rank_id, rank_dir_num,
ckpt_file = replace_rank_id_in_ckpt_name(ckpts[0], rank_id)
resume_ckpt = os.path.join(checkpoint_dir, f"rank_{rank_id}", ckpt_file)
if not os.path.exists(resume_ckpt):
raise FileNotFoundError(f"{resume_ckpt} is not found!")
err_msg = f"{resume_ckpt} is not found!"
logger.error(err_msg)
raise FileNotFoundError(err_msg)
resume_ckpt_list.append(resume_ckpt)
if not resume_ckpt_list:
raise RuntimeError("No checkpoint could be resumed.")
err_msg = "No checkpoint could be resumed."
logger.error(err_msg)
raise RuntimeError(err_msg)

if use_checkpoint_health_monitor:
resume_ckpt_list.sort(key=lambda x: get_times_epoch_and_step_from_ckpt_name(x, ckpt_format))
@@ -323,8 +339,9 @@ def check_last_timestamp_checkpoints(checkpoint_dir, rank_dir_num, ckpt_format='
checkpoint_rank_dir = os.path.join(checkpoint_dir, f"rank_{rank_id_tmp}")
last_checkpoint = get_last_checkpoint(checkpoint_rank_dir, ckpt_format)
if not last_checkpoint:
raise ValueError(f"Checkpoint not found under {checkpoint_rank_dir} "
f"with config.load_ckpt_format:{ckpt_format}.")
err_msg = f"Checkpoint not found under {checkpoint_rank_dir} with config.load_ckpt_format:{ckpt_format}."
logger.error(err_msg)
raise ValueError(err_msg)
if check_ckpt_file_name(last_checkpoint, ckpt_format):
compared_checkpoint_name = replace_rank_id_in_ckpt_name(last_checkpoint, 0)
compared_original_checkpoint_name = os.path.basename(last_checkpoint)
@@ -351,12 +368,16 @@ def check_last_timestamp_checkpoints(checkpoint_dir, rank_dir_num, ckpt_format='
compared_checkpoint_name = current_checkpoint_name
compared_original_checkpoint_name = original_checkpoint_name
elif compared_checkpoint_name != current_checkpoint_name:
raise ValueError(f"Check name of the checkpoint file with the last timestamp Failed.\n"
f"1. Find 2 different checkpoints name: {compared_original_checkpoint_name} and "
f"{original_checkpoint_name}.\n2. Checkpoint file name should follow rule: "
f"{{prefix}}-{{epoch}}_{{step}}.{ckpt_format}, and not corrupted across all rank "
f"folders.\n 3. Rename `resume_training` checkpoint such as "
f"llama_7b_rank_0-3_2.{ckpt_format} may solve the problem.")
err_msg = (f"Check name of the checkpoint file with the last timestamp Failed.\n"
f"1. Find 2 different checkpoints name: {compared_original_checkpoint_name} and "
f"{original_checkpoint_name}.\n2. Checkpoint file name should follow rule: "
f"{{prefix}}-{{epoch}}_{{step}}.{ckpt_format}, and not corrupted across all rank "
f"folders.\n 3. Rename `resume_training` checkpoint such as "
f"llama_7b_rank_0-3_2.{ckpt_format} may solve the problem.")
logger.error(err_msg)
raise ValueError(err_msg)
if find_diff_ckpt:
raise ValueError(f"Some checkpoints follow the {{prefix}}-{{epoch}}_{{step}}.{ckpt_format} "
f"naming convention, while others do not.")
err_msg = (f"Some checkpoints follow the {{prefix}}-{{epoch}}_{{step}}.{ckpt_format} "
f"naming convention, while others do not.")
logger.error(err_msg)
raise ValueError(err_msg)

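Several of the checks above hinge on the `{prefix}-{epoch}_{step}.{ckpt_format}` naming rule and on swapping the rank id inside a checkpoint file name. A rough sketch of that substitution, assuming the rank id appears as a single `rank_<n>` token; the project's `replace_rank_id_in_ckpt_name` helper may cover more cases:

import re

def replace_rank_id(ckpt_name: str, rank_id: int) -> str:
    # Assumes exactly one rank_<n> token in the name; illustrative only.
    return re.sub(r"rank_\d+", f"rank_{rank_id}", ckpt_name)

print(replace_rank_id("llama_7b_rank_0-3_2.ckpt", 5))   # llama_7b_rank_5-3_2.ckpt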
mindformers/trainer/base_trainer.py  (+11 −6)

@@ -1067,7 +1067,9 @@ class BaseTrainer:
logger.info(".............Start load resume context from common.json..................")
common_file = os.path.join(config.load_checkpoint, 'common.json')
if not os.path.exists(common_file):
raise FileNotFoundError(f"No common.json found in directory '{config.load_checkpoint}'.")
error_msg = f"No common.json found in directory '{config.load_checkpoint}'."
logger.error(error_msg)
raise FileNotFoundError(error_msg)
common_info = CommonInfo.load_common(common_file)
step_scale = common_info.global_batch_size / config.runner_config.global_batch_size
config.runner_config.initial_step = int(common_info.step_num * step_scale)
@@ -1101,9 +1103,11 @@ class BaseTrainer:
logger.info("..............Start resume checkpoint path from strategy..............")
resume_ckpt_path = self.resume_ckpt_path_with_strategy(config)
if resume_ckpt_path is None:
raise ValueError(f"Try to resume from checkpoints with strategy in directory "
f"'{config.load_checkpoint}' failed, please specify load_checkpoint to "
f"specific checkpoint file to resume training.")
err_msg = (f"Try to resume from checkpoints with strategy in directory "
f"'{config.load_checkpoint}' failed, please specify load_checkpoint to "
f"specific checkpoint file to resume training.")
logger.error(err_msg)
raise ValueError(err_msg)
config.load_checkpoint = resume_ckpt_path
load_resume_context_from_checkpoint(config, dataset)
resume_dict = {
@@ -1184,8 +1188,9 @@ class BaseTrainer:
if hasattr(network, "get_model_parameters"):
model_params.update(network.get_model_parameters())
else:
raise NotImplementedError(f"The {type(network)} has not implemented the interface: "
f"get_model_parameters.")
err_msg = f"The {type(network)} has not implemented the interface: `get_model_parameters`."
logger.error(err_msg)
raise NotImplementedError(err_msg)

is_moe_model = False
is_mtp_model = False

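The resume hunk above keeps the existing rescaling of the restored step counter; with concrete (illustrative) numbers the intent is clearer:

# Illustrative numbers only: a run saved with global_batch_size=256 at step 1000,
# resumed with global_batch_size=128, restarts its step counter at 2000.
saved_global_batch_size = 256
saved_step_num = 1000
new_global_batch_size = 128

step_scale = saved_global_batch_size / new_global_batch_size   # 2.0
initial_step = int(saved_step_num * step_scale)                 # 2000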

mindformers/trainer/utils.py  (+42 −15)

@@ -47,6 +47,7 @@ from mindformers.models.base_model import BaseModel
from mindformers.models.modeling_utils import PreTrainedModel
from mindformers.version_control import need_nz


# pylint: disable=import-outside-toplevel


@@ -58,9 +59,9 @@ class BaseEnum(str, Enum):
@classmethod
def _missing_(cls, value):
"""Enum with more explicit error message for missing values."""
raise ValueError(
f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
)
err_msg = f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
logger.error(err_msg)
raise ValueError(err_msg)


class IntervalStrategy(BaseEnum):
@@ -134,12 +135,17 @@ def preload_ckpt(config):
return
mindio_pool_capacity = config.get("mindio_pool_capacity", 128)
set_mindio_server_info(mindio_pool_capacity)

if hasattr(_init_mindio(), "preload"):
logger.info("MindIO is initialized successfully!")
else:
return

if not os.path.realpath(ckpt_path) or not os.path.exists(ckpt_path):
raise FileNotFoundError(f"The load_checkpoint must be correct, but get {ckpt_path}")
err_log = f"The load_checkpoint must be correct, but get {ckpt_path}"
logger.error(err_log)
raise FileNotFoundError(err_log)

if os.path.isfile(ckpt_path):
# only preload the ckpt file once.
if is_main_rank() or f"rank_{rank_id}" in ckpt_path:
@@ -158,7 +164,10 @@ def preload_ckpt(config):
logger.info(f"MindIO preloading `{checkpoint_path}`...")
mindio_preload(checkpoint_path)
else:
raise ValueError(f"{ckpt_path} is not a valid path to load checkpoint when auto_trans_ckpt is False.")
err_msg = f"{ckpt_path} is not a valid path to load checkpoint when auto_trans_ckpt is False."
logger.error(err_msg)
raise ValueError(err_msg)

if config.use_parallel:
barrier()

@@ -198,8 +207,10 @@ def check_runner_config(config, dataset):
if config.runner_config.sink_mode:
if config.runner_config.sink_size != -1:
if config.runner_config.sink_size <= 0:
raise ValueError("per epoch size must be more than 0 or equal to -1, "
f"but get {config.runner_config.sink_size}")
err_log = (f"per epoch size must be more than 0 or equal to -1, "
f"but get {config.runner_config.sink_size}")
logger.error(err_log)
raise ValueError(err_log)
if data_size < config.runner_config.sink_size:
logger.warning("The data size %s (get from dataset.get_dataset_size()) is smaller "
"than the sink_size %s (get from config.runner_config.sink_size), "
@@ -321,7 +332,9 @@ def get_distribute_checkpoint_path(checkpoint_dir, rank_id=None, ckpt_format='ck
logger.info("Your load_checkpoint is file, it will be load in network.")
distribute_checkpoint_path = checkpoint_dir
else:
raise FileNotFoundError(f"{checkpoint_dir} is not found.")
err_msg = f"{checkpoint_dir} is not found."
logger.error(err_msg)
raise FileNotFoundError(err_msg)
return distribute_checkpoint_path


@@ -349,7 +362,9 @@ def load_resume_context_from_checkpoint(config, dataset):
"""resume training, load training info from checkpoint to config"""
if not os.path.realpath(config.load_checkpoint) or \
not os.path.exists(config.load_checkpoint):
raise FileNotFoundError(f"The load_checkpoint must be correct, but get {config.load_checkpoint}")
err_log = f"The load_checkpoint must be correct, but get {config.load_checkpoint}"
logger.error(err_log)
raise FileNotFoundError(err_log)

if os.path.isdir(config.load_checkpoint):
# When graceful exit is enabled or auto checkpoint transformation is disabled,
@@ -376,6 +391,10 @@ def load_resume_context_from_checkpoint(config, dataset):
else:
checkpoint_tmp = os.path.join(config.load_checkpoint, f"rank_{rank_id}",
replace_rank_id_in_ckpt_name(config.resume_training, rank_id))
if not os.path.isfile(checkpoint_tmp):
err_msg = f"{checkpoint_tmp} is not found!"
logger.error(err_msg)
raise FileNotFoundError(err_msg)
resume_dict = load_checkpoint(
checkpoint_tmp,
choice_func=lambda x: x in ["loss_scale", "epoch_num", "step_num", "global_batch_size"],
@@ -513,15 +532,20 @@ def transform_and_load_checkpoint(config, model, network, dataset, optimizer=Non


def check_checkpoint_config_valid(config):
"""Check valid load checkpoint path in config."""
# check valid load checkpoint path
if not config.only_save_strategy and (not os.path.realpath(config.load_checkpoint) or
not os.path.exists(config.load_checkpoint)):
raise FileNotFoundError(f"The load_checkpoint must be correct, but get {config.load_checkpoint}")
err_msg = f"The load_checkpoint must be correct, but get {config.load_checkpoint}"
logger.error(err_msg)
raise FileNotFoundError(err_msg)

# check valid format
if config.load_ckpt_format is not None and config.load_ckpt_format not in CkptFormat.support_type():
raise ValueError(
f"config.load_ckpt_format only support for 'ckpt' or 'safetensors', but got {config.load_ckpt_format}.")
err_msg = ("config.load_ckpt_format only support for 'ckpt' or 'safetensors', "
f"but got {config.load_ckpt_format}.")
logger.error(err_msg)
raise ValueError(err_msg)


def check_path_include_total_ckpt(path):
@@ -571,7 +595,9 @@ def load_slora_ckpt(checkpoint_dict, config, network):
logger.info("............Start load slora checkpoint ............")
adapter_path = os.path.join(pet_config.adapter_path, "lora_adapter.json")
if not os.path.exists(adapter_path):
raise FileNotFoundError(f"The adapter_path must be correct, but get {adapter_path}")
err_msg = f"The adapter_path must be correct, but get {adapter_path}"
logger.error(err_msg)
raise FileNotFoundError(err_msg)
with open(adapter_path, 'r', encoding='utf-8') as file:
path_dict = json.load(file)
adapter_list = []
@@ -688,8 +714,9 @@ def get_load_checkpoint_result(config):
else:
checkpoint_dict = load_distributed_checkpoint(config.load_checkpoint)
else:
raise ValueError(f"{config.load_checkpoint} is not a valid path to load checkpoint "
f"when auto_trans_ckpt is False.")
err_msg = f"{config.load_checkpoint} is not a valid path to load checkpoint when auto_trans_ckpt is False."
logger.error(err_msg)
raise ValueError(err_msg)
return checkpoint_dict if checkpoint_dict else checkpoint_future


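The `_missing_` hunk only changes how the error is reported; the hook itself is standard `enum` behaviour. A self-contained sketch of the same idea (the member values here are assumed for illustration):

import logging
from enum import Enum

logger = logging.getLogger("mindformers")

class BaseEnum(str, Enum):
    """String enum that logs and raises a readable error for unknown values."""
    @classmethod
    def _missing_(cls, value):
        err_msg = (f"{value} is not a valid {cls.__name__}, "
                   f"please select one of {list(cls._value2member_map_.keys())}")
        logger.error(err_msg)
        raise ValueError(err_msg)

class IntervalStrategy(BaseEnum):
    NO = "no"
    STEPS = "steps"
    EPOCH = "epoch"

IntervalStrategy("steps")    # ok
IntervalStrategy("weekly")   # logs the message, then raises ValueError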


mindformers/utils/load_checkpoint_utils.py  (+56 −22)

@@ -94,9 +94,13 @@ def get_load_path_after_hf_convert(config, network):
def _check_checkpoint_path(path):
"""check checkpoint path."""
if not isinstance(path, str) or isinstance(path, os.PathLike):
raise ValueError(f"config.load_checkpoint must be a `str`, but got `{path}` as type `{type(path)}`.")
err_msg = f"config.load_checkpoint must be a `str`, but got `{path}` as type `{type(path)}`."
logger.error(err_msg)
raise ValueError(err_msg)
if not os.path.exists(path):
raise FileNotFoundError(f"config.load_checkpoint `{path}` does not exist.")
err_msg = f"config.load_checkpoint `{path}` does not exist."
logger.error(err_msg)
raise FileNotFoundError(err_msg)

if path[-1] == '/': # remove last '/' in path
return path[:-1]
@@ -116,7 +120,9 @@ def _get_checkpoint_mode(config):

# check path is dir
if not os.path.isdir(checkpoint_path):
raise ValueError("Provided path is neither a file nor a directory.")
err_msg = "Provided path is neither a file nor a directory."
logger.error(err_msg)
raise ValueError(err_msg)

dir_files = os.listdir(checkpoint_path)
if any(folder_name.startswith('rank_') for folder_name in dir_files):
@@ -125,7 +131,9 @@ def _get_checkpoint_mode(config):
if any(file_name.endswith(config.load_ckpt_format) for file_name in dir_files):
return CheckpointFileMode.MULTI_CHECKPOINT_FILE.value

raise ValueError("not support mode: no valid checkpoint files found")
err_msg = "not support mode: no valid checkpoint files found"
logger.error(err_msg)
raise ValueError(err_msg)


def _get_src_strategy(config):
@@ -142,8 +150,10 @@ def _get_src_strategy(config):
src_strategy_path = os.path.join(upper_dir, 'strategy')
logger.info(f"src_strategy_path_or_dir is empty, load source strategy from {src_strategy_path}.")
else:
raise ValueError("when use checkpoint after train/finetune, src_strategy_path_or_dir should be set "
"as a folder contained strategy ckpt files.")
err_msg = ("when use checkpoint after train/finetune, src_strategy_path_or_dir should be set "
"as a folder contained strategy ckpt files.")
logger.error(err_msg)
raise ValueError(err_msg)
logger.info(f"load source strategy from {src_strategy_path}.")
return src_strategy_path

@@ -249,7 +259,9 @@ def _get_src_file(checkpoint_dir, checkpoint_name=None, ckpt_format='ckpt'):
else:
ckpt_path = get_last_checkpoint(checkpoint_rank_dir, ckpt_format)
if not os.path.exists(ckpt_path):
raise FileNotFoundError(f"{ckpt_path} is not found.")
err_msg = f"{ckpt_path} is not found."
logger.error(err_msg)
raise FileNotFoundError(err_msg)
return ckpt_path


@@ -265,7 +277,9 @@ def load_checkpoint_with_safetensors(config, model, network, input_data, do_eval

pet_config = config.model.model_config.get("pet_config")
if pet_config and pet_config.pet_type == "slora" and network.lora_list:
raise ValueError(f"slora only support .ckpt file, {config.load_ckpt_format} file will be compatible soon.")
err_msg = f"slora only support .ckpt file, {config.load_ckpt_format} file will be compatible soon."
logger.error(err_msg)
raise ValueError(err_msg)
ckpt_file_mode = _get_checkpoint_mode(config)
validate_config_with_file_mode(ckpt_file_mode, config.use_parallel, config.auto_trans_ckpt)
# reduce compile time in prediction
@@ -369,18 +383,26 @@ def validate_config_with_file_mode(ckpt_file_mode, use_parallel, auto_trans_ckpt
"""validate use_parallel and auto_trans_ckpt config with different file mode"""
if ckpt_file_mode == CheckpointFileMode.SINGLE_CHECKPOINT_FILE.value:
if use_parallel:
raise ValueError("When load checkpoint is a single file and use_parallel is True, please change "
"load_checkpoint in yaml from file name to the directory where only this file is located.")
err_msg = ("When load checkpoint is a single file and use_parallel is True, please change "
"load_checkpoint in yaml from file name to the directory where only this file is located.")
logger.error(err_msg)
raise ValueError(err_msg)
elif ckpt_file_mode == CheckpointFileMode.MULTI_CHECKPOINT_FILE.value:
if use_parallel and not auto_trans_ckpt:
raise ValueError("When load checkpoint is complete and use_parallel is True, please set auto_trans_ckpt: "
"True to enable automatic slicing function.")
err_msg = ("When load checkpoint is complete and use_parallel is True, please set auto_trans_ckpt: "
"True, to enable automatic slicing function.")
logger.error(err_msg)
raise ValueError(err_msg)
elif ckpt_file_mode == CheckpointFileMode.MULTI_CHECKPOINT_FILE_WITH_RANK_ID.value:
if not use_parallel:
raise ValueError("when input checkpoint file is rank dir, Please set use_parallel: True to enable "
"distributed ckpt load.")
err_msg = ("when input checkpoint file is rank dir, Please set use_parallel: "
"True, to enable distributed ckpt load.")
logger.error(err_msg)
raise ValueError(err_msg)
else:
raise ValueError("not support mode: no valid checkpoint files found")
err_msg = "not support mode: no valid checkpoint files found"
logger.error(err_msg)
raise ValueError(err_msg)


def unify_safetensors(src_checkpoint, src_strategy_path, unified_path, use_parallel=False,
@@ -425,6 +447,7 @@ def load_safetensors_checkpoint(config, load_checkpoint_files, network, strategy
logger.info("......obtain name map for HF safetensors.....")
name_map = origin_network.obtain_name_map(load_checkpoint_files)
except Exception as e:
logger.error(f"Please complete abstract function obtain_name_map. Details: {e}")
raise TypeError(f"Please complete abstract function obtain_name_map. Details: {e}") from e
if is_main_rank():
_convert_index_json(load_ckpt_path, load_ckpt_path, origin_network.convert_map_dict, False)
@@ -442,7 +465,9 @@ def load_safetensors_checkpoint(config, load_checkpoint_files, network, strategy
hyper_param_file = os.path.join(load_ckpt_path, 'hyper_param.safetensors')
if optimizer and config.resume_training:
if not os.path.exists(hyper_param_file):
raise FileNotFoundError(rf"No hyper_param.safetensors in given dir: {load_ckpt_path}")
err_msg = rf"No hyper_param.safetensors in given dir: {load_ckpt_path}"
logger.error(err_msg)
raise FileNotFoundError(err_msg)
logger.info("......Start load hyper param into optimizer......")
hyper_param_dict = ms.load_checkpoint(ckpt_file_name=hyper_param_file, format='safetensors')
update_global_step(config, hyper_param_dict)
@@ -508,7 +533,9 @@ def process_hf_checkpoint(model, output_dir=None, load_checkpoint=None):

# If the child process exits abnormally, the main process should also throw an exception.
if p.exitcode != 0:
raise RuntimeError("convert HuggingFace weight failed.")
err_msg = "convert HuggingFace weight failed."
logger.error(err_msg)
raise RuntimeError(err_msg)

# If used parallel mode, other cards are waiting for the main card to complete the weight conversion.
barrier_world("for the main rank to convert HuggingFace weight...")
@@ -521,11 +548,14 @@ def process_hf_checkpoint(model, output_dir=None, load_checkpoint=None):
def get_last_checkpoint(checkpoint_dir, ckpt_format='ckpt'):
"""get last checkpoint for resuming or finetune."""
if not os.path.isdir(checkpoint_dir):
raise NotADirectoryError(
err_msg = (
f"{checkpoint_dir} is not a real directory,"
f"When distributed loads are sliced weights,"
f"load_checkpoint should be a checkpoint directory containing the directory of rank_{{0-*}},"
f"The directory structure is as follows: **checkpoint_root_dir/rank_{{0-*}}/**.{ckpt_format}")
logger.error(err_msg)
raise NotADirectoryError(err_msg)

output_checkpoint_path = [
checkpoint
for checkpoint in os.listdir(checkpoint_dir)
@@ -565,11 +595,15 @@ def validate_qkv_concat(model_cls_or_instance, qkv_concat_config, load_checkpoin
break

if is_qkv_concat and not qkv_concat_config:
raise ValueError("The qkv concat check failed! The qkv in the model weights has been concatenated,"
" but qkv_concat is set to false.")
err_msg = ("The qkv concat check failed! The qkv in the model weights has been concatenated, "
"but qkv_concat is set to false.")
logger.error(err_msg)
raise ValueError(err_msg)
if not is_qkv_concat and qkv_concat_config:
raise ValueError("The qkv concat check failed! The qkv in the model weights has been not concatenated,"
" but qkv_concat is set to true.")
err_msg = ("The qkv concat check failed! The qkv in the model weights has been not concatenated, "
"but qkv_concat is set to true.")
logger.error(err_msg)
raise ValueError(err_msg)
if is_qkv_concat and qkv_concat_config:
logger.info("The qkv concat check succeed! The qkv in the model weights has been concatenated and "
"qkv_concat is set to true.")

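`validate_config_with_file_mode` reduces to a small compatibility check between the detected checkpoint layout and the `use_parallel` / `auto_trans_ckpt` flags. A reduced sketch with plain strings standing in for the `CheckpointFileMode` enum members named in the hunk (the string values are assumptions):

def check_mode(mode: str, use_parallel: bool, auto_trans_ckpt: bool) -> None:
    # Mirrors the three rejected combinations from validate_config_with_file_mode.
    if mode == "single_checkpoint_file" and use_parallel:
        raise ValueError("point load_checkpoint at the directory that holds the single file")
    if mode == "multi_checkpoint_file" and use_parallel and not auto_trans_ckpt:
        raise ValueError("set auto_trans_ckpt: True to slice the complete checkpoint")
    if mode == "multi_checkpoint_file_with_rank_id" and not use_parallel:
        raise ValueError("set use_parallel: True to load a rank-sliced checkpoint")

check_mode("multi_checkpoint_file", use_parallel=True, auto_trans_ckpt=True)   # passes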

mindformers/utils/resume_ckpt_utils.py  (+3 −1)

@@ -56,7 +56,9 @@ def load_resume_checkpoint(load_checkpoint_path, remove_redundancy, load_ckpt_fo
"""resume training, load training info from checkpoint to config"""
if not os.path.realpath(load_checkpoint_path) or \
not os.path.exists(load_checkpoint_path):
raise FileNotFoundError(f"The load_checkpoint_path must be correct, but get {load_checkpoint_path}")
err_msg = f"The load_checkpoint_path must be correct, but get {load_checkpoint_path}"
logger.error(err_msg)
raise FileNotFoundError(err_msg)

if os.path.isdir(load_checkpoint_path):
hyper_param_file = os.path.join(load_checkpoint_path, 'hyper_param.safetensors')

