2 Commits

Author SHA1 Message Date
  i-robot dff083c221
!7778 adapt group_lr for new llm_trainer. 5 days ago
  Lin-Bert 14efdddfef adapt group_lr for new llm_trainer. 2 weeks ago
6 changed files with 434 additions and 135 deletions
Split View
  1. +106
    -54
      configs/llm_template_v2_experimental/llm_config_whitelist_introduce.md
  2. +11
    -11
      mindformers/core/lr/lr_schedule.py
  3. +77
    -1
      mindformers/tools/register/llm_template_v2.py
  4. +231
    -65
      mindformers/trainer/llm_trainer_for_graph_experimental/llm_trainer.py
  5. +8
    -3
      mindformers/trainer/optimizer_grouped_parameters.py
  6. +1
    -1
      tests/st/test_ut/test_tools/test_template_v2/test_template_v2.py

+ 106
- 54
configs/llm_template_v2_experimental/llm_config_whitelist_introduce.md View File

@@ -11,10 +11,9 @@
- [CosineWithWarmUpLR](#cosinewithwarmuplr)
- [CosineWithRestartsAndWarmUpLR](#cosinewithrestartsandwarmuplr)
- [PolynomialWithWarmUpLR](#polynomialwithwarmuplr)
- [CosineAnnealingLR](#cosineannealinglr)
- [CosineAnnealingWarmRestarts](#cosineannealingwarmrestarts)
- [WarmUpStableDecayLR](#warmupstabledecaylr)
- [ConstantWithCoolDownLR](#constantwithcooldownlr)
- [参数分组学习率](#参数分组学习率)
- [优化器配置](#优化器配置)
- [AdamW优化器](#adamw优化器)
- [PmaAdamW优化器](#pmaadamw优化器)
@@ -415,55 +414,6 @@ lr_schedule: # 多项式衰减学习率调度器配置
decay_steps: null
```

#### CosineAnnealingLR

```yaml
lr_schedule: # 余弦退火学习率调度器配置
# type: 学习率调度器类型,CosineAnnealingLR余弦退火学习率调度器
# 该调度器按照余弦函数周期性地衰减学习率,在每个周期内从base_lr衰减到eta_min
# 当训练步数达到t_max的倍数时,学习率会重启到base_lr,开始新的周期
type: CosineAnnealingLR

# base_lr: 基础学习率值,每个重启周期的初始学习率
# 余弦退火会以此为起点,逐渐衰减到eta_min
base_lr: 1.e-6

# t_max: 余弦退火周期的步数,控制学习率衰减周期的长度
# 当训练步数达到t_max的倍数时,学习率会重启到base_lr,开始新的余弦退火周期
# 需要根据算法需求和训练总步数进行设置,通常设置为训练总步数的整数分之一
t_max: 10

# eta_min: 最小学习率值,余弦退火衰减的最终学习率
# 学习率会在每个周期内从base_lr衰减到该值
# null表示使用默认值0,也可以设置为一个很小的正数
eta_min: null
```

#### CosineAnnealingWarmRestarts

```yaml
lr_schedule: # 余弦退火重启学习率调度器配置
# type: 学习率调度器类型,CosineAnnealingWarmRestarts余弦退火重启学习率调度器
# 该调度器按照余弦函数周期性地衰减学习率,并在每个周期结束时重启学习率
type: CosineAnnealingWarmRestarts

# base_lr: 基础学习率值,每个重启周期的初始学习率
# 余弦退火会以此为起点,逐渐衰减到eta_min
base_lr: 1.e-6

# t_0: 第一个重启周期的步数,控制第一个学习率周期的长度
# 当训练步数达到t_0的倍数时,学习率会重启到base_lr,需要根据算法需求人工进行设置
t_0: 10

# t_mult: 周期倍数因子,控制后续重启周期长度的倍数
# 当t_mult=1时,所有周期长度相同;当t_mult>1时,后续周期会逐渐变长
t_mult: 1.

# eta_min: 最小学习率值,余弦退火衰减的最终学习率
# 学习率会在每个周期内从base_lr衰减到该值
eta_min: null
```

#### WarmUpStableDecayLR

```yaml
@@ -565,6 +515,108 @@ lr_schedule: # 恒定学习率带冷却调度器配置
decay_ratio: 0.
```

#### 参数分组学习率

```yaml
# grouped_lr_schedule: 参数分组学习率调度器配置
# 该配置允许为模型的不同参数组设置独立的学习率调度策略,实现细粒度的学习率控制
# 适用于迁移学习、微调、逐层衰减等场景,可以为不同组件(如embedding、attention、MLP等)设置不同的学习率
grouped_lr_schedule:
# default: 默认学习率调度器配置,应用于所有不匹配任何分组模式的参数
# 当某个参数不匹配任何grouped中的params模式时,将使用此默认调度器
default:
# type: 默认学习率调度器类型,支持所有标准的学习率调度器类型
# 例如:ConstantWarmUpLR、CosineWithWarmUpLR、PolynomialWithWarmUpLR等
type: "ConstantWarmUpLR"

# learning_rate: 默认学习率值,用于所有未匹配到特定分组的参数
# 该值将作为默认调度器的基础学习率
learning_rate: 1.e-4

# warmup_ratio: warmup比例,控制预热阶段占总训练步数的比例
# 0表示不使用学习率预热,直接使用恒定学习率,数值范围为[0, 1]
# 不为0时,会覆盖warmup_steps的设置, 且warmup_steps=total_steps*warmup_ratio
warmup_ratio: 0.

# total_steps: 总训练步数,-1表示使用默认的数据集大小计算出的总步数
# 设置为正整数时,将使用指定的步数作为训练总步数,如果设置了stop_step训练提前退出功能,建议手动调整该参数
total_steps: -1

# 注意:default配置中还可以包含对应调度器类型的其他参数
# 例如:warmup_steps、warmup_ratio、total_steps等,具体参数取决于选择的调度器类型

# grouped: 参数分组列表,每个分组包含参数匹配模式和对应的学习率调度器配置
# 系统会按照列表顺序匹配参数,第一个匹配的分组将被应用
grouped:
# 第一个参数分组:匹配所有embedding相关的参数
- # type: 该参数组的学习率调度器类型
# 可以为不同参数组选择不同的调度策略,实现差异化的学习率调整
type: "CosineWithWarmUpLR"

# params: 参数名称匹配模式列表,支持通配符模式(使用fnmatch进行匹配)
# 可以使用通配符 "*" 匹配任意字符序列,"?" 匹配单个字符
# 例如:"*.embedding*" 会匹配所有包含"embedding"的参数名
# 参数匹配是大小写敏感的,且支持多个模式,只要参数名匹配任一模式即可
params: ['embedding.*', 'output_layer.weight']

# learning_rate: 该参数组的基础学习率值
# 通常embedding层使用较小的学习率,以保持预训练的特征表示
learning_rate: 1.e-5

# warmup_ratio: warmup比例,控制预热阶段占总训练步数的比例
# 0表示不使用学习率预热,直接使用恒定学习率,数值范围为[0, 1]
# 不为0时,会覆盖warmup_steps的设置, 且warmup_steps=total_steps*warmup_ratio
warmup_ratio: 0.

# total_steps: 总训练步数,-1表示使用默认的数据集大小计算出的总步数
# 设置为正整数时,将使用指定的步数作为训练总步数,如果设置了stop_step训练提前退出功能,建议手动调整该参数
total_steps: -1

# 注意:每个分组可以包含对应调度器类型的完整配置参数
# 例如:warmup_steps、warmup_ratio、total_steps、lr_end等
# 这些参数仅对该分组内的参数生效

# 第二个参数分组:匹配所有attention相关的参数
- # type: attention层使用的学习率调度器类型
# 可以为attention层选择更激进的学习率策略
type: "PolynomialWithWarmUpLR"

# params: 匹配所有attention层相关的参数
# 例如:"*.attention*" 会匹配 "decoder.layers.0.self_attention.linear_qkv.weight" 等参数
params: ["*.self_attention*"]

# learning_rate: attention层的基础学习率值
# 通常attention层可以使用较大的学习率,以便快速适应新任务
learning_rate: 2.e-4

# warmup_ratio: warmup比例,控制预热阶段占总训练步数的比例
# 0表示不使用学习率预热,直接使用恒定学习率,数值范围为[0, 1]
# 不为0时,会覆盖warmup_steps的设置, 且warmup_steps=total_steps*warmup_ratio
warmup_ratio: 0.

# total_steps: 总训练步数,-1表示使用默认的数据集大小计算出的总步数
# 设置为正整数时,将使用指定的步数作为训练总步数,如果设置了stop_step训练提前退出功能,建议手动调整该参数
total_steps: -1
# 注意:可以继续添加更多参数分组,每个分组都有独立的调度器配置
# 例如:可以为MLP层、LayerNorm层等分别设置不同的学习率策略

# 使用说明:
# 1. 参数匹配优先级:系统按照grouped列表的顺序依次匹配参数,第一个匹配的分组将被应用
# 2. 默认调度器:如果参数不匹配任何grouped中的模式,将使用default配置的调度器
# 3. 通配符匹配:支持fnmatch风格的通配符,如"*.layer.*.weight"可以匹配多层结构
# 4. 配置完整性:每个分组(包括default)都需要包含对应调度器类型所需的完整参数
# 5. 使用场景:
# - 迁移学习:预训练层使用小学习率,新添加层使用大学习率
# - 微调:embedding层冻结或使用极小学习率,上层使用正常学习率
# - 逐层衰减:不同层使用不同的学习率衰减策略
# - 组件优化:为attention、MLP等不同组件设置不同的学习率
# 6. 注意事项:
# - 确保参数名称模式能够正确匹配到目标参数, 可以通过任意配置之后运行程序查看匹配结果,会打印出支持的匹配参数
# - 建议在训练前检查参数匹配情况,确保所有参数都被正确分组
# - 不同分组的total_steps应该保持一致,以确保训练步数同步
# - 如果某个参数同时匹配多个模式,将使用第一个匹配的分组配置
```

### 优化器配置

#### AdamW优化器
@@ -572,7 +624,7 @@ lr_schedule: # 恒定学习率带冷却调度器配置
```yaml
# Optimizer configuration
# 优化器配置,用于指定训练过程中使用的优化器类型及其相关参数
optimizer: # 优化器1
optimizer:
# type: 优化器类型
type: AdamW

@@ -609,7 +661,7 @@ optimizer: # 优化器1
#### PmaAdamW优化器

```yaml
optimizer: # 优化器2
optimizer:
# type: 优化器类型,PmaAdamW优化器
# Pre-trained Model Average(PMA)权重合并是指在训练过程中,
# 根据选择 Exponential Moving Average(EMA)算法或 Simple Moving Average(SMA)算法对权重进行融合合并,从而提升模型训练的效果。
@@ -676,7 +728,7 @@ Muon优化器具有以下特点:
3. 在MoE模型中,专家数量必须能被(optimizer_weight_shard_size * expert_model_parallel_size)整除

```yaml
optimizer: # 优化器3 - Muon优化器配置
optimizer:
# type: 优化器类型,指定使用Muon优化器
type: Muon



+ 11
- 11
mindformers/core/lr/lr_schedule.py View File

@@ -36,9 +36,9 @@ def _get_lr_steps(steps: int, ratio: float, total_steps: int, phase_tag: str):
"""check args and get specified steps."""
if ratio is None:
if not isinstance(steps, int):
raise TypeError(f"The type of {phase_tag}_step must be int, but got {type(steps)}")
raise TypeError(f"The type of {phase_tag}_steps must be int, but got {type(steps)}")
if steps < 0:
raise ValueError(f"The {phase_tag}_step must be >= 0, but got {steps}")
raise ValueError(f"The {phase_tag}_steps must be >= 0, but got {steps}")
return steps

if not isinstance(ratio, (float, int)):
@@ -173,7 +173,7 @@ class ConstantWithCoolDownLR(LearningRateSchedule):
lr_end2: float = None,
**kwargs
):
super(ConstantWithCoolDownLR, self).__init__()
super().__init__()
warmup_steps_ = _get_lr_steps(warmup_steps, warmup_ratio, total_steps, "warmup")
decay_steps = max(1, decay_steps) if decay_steps is not None else max(1, total_steps)
decay_steps_ = _get_lr_steps(decay_steps, decay_ratio, total_steps, "decay")
@@ -280,7 +280,7 @@ class ConstantWarmUpLR(LearningRateSchedule):
)
def __init__(self, learning_rate: float, warmup_steps: int = None, warmup_lr_init: float = 0.,
warmup_ratio: float = None, total_steps: int = None, **kwargs):
super(ConstantWarmUpLR, self).__init__()
super().__init__()
warmup_steps = _get_lr_steps(warmup_steps, warmup_ratio, total_steps, "warmup")
self.learning_rate = learning_rate
self.warmup_lr_init = warmup_lr_init
@@ -367,7 +367,7 @@ class LinearWithWarmUpLR(LearningRateSchedule):
def __init__(self, learning_rate: float, total_steps: int, warmup_steps: int = None,
warmup_lr_init: float = 0., warmup_ratio: float = None,
**kwargs):
super(LinearWithWarmUpLR, self).__init__()
super().__init__()
warmup_steps = _get_lr_steps(warmup_steps, warmup_ratio, total_steps, "warmup")
linear_steps = max(1, total_steps - warmup_steps)
self.kwargs = kwargs
@@ -463,7 +463,7 @@ class CosineWithWarmUpLR(LearningRateSchedule):
def __init__(self, learning_rate: float, warmup_steps: int = 0, total_steps: int = None,
num_cycles: float = 0.5, lr_end: float = 0., warmup_lr_init: float = 0.,
warmup_ratio: float = None, decay_steps: int = None, decay_ratio: float = None, **kwargs):
super(CosineWithWarmUpLR, self).__init__()
super().__init__()
_check_decay_method(decay_steps, total_steps)
warmup_steps = _get_lr_steps(warmup_steps, warmup_ratio, total_steps, "warmup")
cosine_steps = max(1, total_steps - warmup_steps)
@@ -568,7 +568,7 @@ class CosineWithRestartsAndWarmUpLR(LearningRateSchedule):
def __init__(self, learning_rate: float, warmup_steps: int = None, total_steps: int = None,
num_cycles: float = 1., lr_end: float = 0., warmup_lr_init: float = 0.,
warmup_ratio: float = None, decay_steps: int = None, **kwargs):
super(CosineWithRestartsAndWarmUpLR, self).__init__()
super().__init__()
_check_decay_method(decay_steps, total_steps)
warmup_steps = _get_lr_steps(warmup_steps, warmup_ratio, total_steps, "warmup")
cosine_steps = max(1, total_steps - warmup_steps)
@@ -687,7 +687,7 @@ class PolynomialWithWarmUpLR(LearningRateSchedule):
def __init__(self, learning_rate: float, total_steps: int, warmup_steps: int = None,
lr_end: float = 1e-7, power: float = 1.0, warmup_lr_init: float = 0.,
warmup_ratio: float = None, decay_steps: int = None, **kwargs):
super(PolynomialWithWarmUpLR, self).__init__()
super().__init__()
_check_decay_method(decay_steps, total_steps)
warmup_steps = _get_lr_steps(warmup_steps, warmup_ratio, total_steps, "warmup")
decay_steps = max(1, decay_steps) \
@@ -783,7 +783,7 @@ class LearningRateWiseLayer(LearningRateSchedule):
"""

def __init__(self, base_lr, lr_scale):
super(LearningRateWiseLayer, self).__init__()
super().__init__()
self.base_lr = base_lr
self.lr_scale = lr_scale

@@ -943,7 +943,7 @@ class CosineAnnealingLR(LearningRateSchedule):

@args_type_check(base_lr=(int, float), t_max=int, eta_min=(int, float))
def __init__(self, base_lr: float, t_max: int, eta_min: float = 0., **kwargs):
super(CosineAnnealingLR, self).__init__()
super().__init__()
if t_max < 1 or not isinstance(t_max, int):
raise ValueError(f"Expected positive integer T_max, but got {t_max}")
self.kwargs = kwargs
@@ -1015,7 +1015,7 @@ class CosineAnnealingWarmRestarts(LearningRateSchedule):

@args_type_check(base_lr=(int, float), t_0=int, t_mult=int, eta_min=(int, float))
def __init__(self, base_lr: float, t_0: int, t_mult: int = 1, eta_min: float = 0., **kwargs):
super(CosineAnnealingWarmRestarts, self).__init__()
super().__init__()
if t_0 < 1 or not isinstance(t_0, int):
raise ValueError(f"Expected positive integer t_0, but got {t_0}")
if t_mult < 1 or not isinstance(t_mult, int):


+ 77
- 1
mindformers/tools/register/llm_template_v2.py View File

@@ -20,6 +20,13 @@ from mindformers.tools.logger import logger
from .validate_types_and_ranges import validate_config_types_and_ranges


LR_SUPPORT_LIST = \
["ConstantWarmUpLR", "LinearWithWarmUpLR",
"CosineWithWarmUpLR", "CosineWithRestartsAndWarmUpLR",
"PolynomialWithWarmUpLR", "WarmUpStableDecayLR",
"ConstantWithCoolDownLR"]


class Config:
"""
A base class for applying structured configuration.
@@ -1337,12 +1344,69 @@ class LrScheduleConfig(SpecConfig):

type: str = "CosineWithWarmUpLR"
learning_rate: float = 5.e-5
lr_end: float = 0.
warmup_lr_init: float = 0.
warmup_ratio: float = 0.
total_steps: int = -1

_raise_error_for_unexpected_key: bool = False

_validation_rules = {
"type": {
"type": str,
"range": LR_SUPPORT_LIST,
"description": "Learning rate schedule type"
},
"learning_rate": {
"type": (float, int),
"range": None,
"description": "Learning rate"
},
"warmup_lr_init": {
"type": (float, int),
"range": None,
"description": "Initial learning rate during warmup"
},
"warmup_ratio": {
"type": (float, int),
"range": None,
"description": "Ratio of warmup steps to total steps"
},
"total_steps": {
"type": int,
"range": None,
"description": "Total number of training steps"
}
}


class GroupedLrScheduleConfig(SpecConfig):
"""Learning rate schedule configuration"""
_name: str = "grouped_lr_schedule"
_required_keys: Dict[str, List[str]] = {
"grouped_lr_schedule": ["default", "grouped"]
}

_raise_error_for_unexpected_key: bool = False

_validation_rules = {
"default": {
"type": dict,
"range": lambda x: all([x.get("type") in LR_SUPPORT_LIST,
(x.get("warmup_ratio") is not None
or x.get("warmup_steps") is not None)]),
"description": "Default learning rate schedule"
},
"grouped": {
"type": list,
"range": lambda x: all([all(isinstance(grouped_lr, dict) for grouped_lr in x),
all(grouped_lr.get("type") in LR_SUPPORT_LIST for grouped_lr in x),
all(isinstance(grouped_lr.get("params"), list) for grouped_lr in x),
all((grouped_lr.get("warmup_ratio") is not None or
grouped_lr.get("warmup_steps") is not None) for grouped_lr in x)]),
"description": "List of learning rate schedules for different groups"
}
}


class CallbackConfig(ListConfig):
"""Callback configuration"""
@@ -1698,6 +1762,7 @@ CONFIG_NAME_TO_CLASS: Dict[str, type] = {
"model_config": ModelConfig,
"optimizer": OptimizerConfig,
"lr_schedule": LrScheduleConfig,
"grouped_lr_schedule": GroupedLrScheduleConfig,
"callbacks": CallbackConfig,
"monitor_config": MonitorConfig,
"profile": ProfileConfig,
@@ -1733,6 +1798,7 @@ class ConfigTemplate:
"model_config",
"optimizer",
"lr_schedule",
"grouped_lr_schedule",
"callbacks",
"monitor_config",
"profile",
@@ -1747,6 +1813,13 @@ class ConfigTemplate:
"trainer",
"model_config"]

# Config modules that should not be merged when they are absent from the input config
_default_not_merge_configs: List[str] = [
"lr_schedule",
"grouped_lr_schedule",
"callbacks"
]

_run_modes: List[str] = ['train', 'predict', 'finetune']

@classmethod
@@ -1788,6 +1861,9 @@ class ConfigTemplate:

new_config = {}
for sub_config in template:
if sub_config not in config.keys() and sub_config in cls._default_not_merge_configs:
continue

if sub_config == "distribute_parallel_config":
origin_sub_config = config.pop(sub_config, None)
cls.update_distributed_parallel_config(sub_config, new_config, origin_sub_config, run_mode)


+ 231
- 65
mindformers/trainer/llm_trainer_for_graph_experimental/llm_trainer.py View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Trainer API For Import."""
"""LLMTrainer For Graph Mode."""
import os
import subprocess
from pprint import pprint
@@ -95,6 +95,8 @@ class LLMTrainer:
self.common_restore_info = None
self.network_delay_inited = False
self.optimizer_delay_inited = False
self.lr_scheduler = None
self.grouped_lr_scheduler = None

def _setup_config(self, config: MindFormerConfig, is_train: bool = True) -> None:
"""Initialize and setup configuration for training or inference.
@@ -286,7 +288,7 @@ class LLMTrainer:
self.config.model.model_config = self.config.model_config
del self.config.model_config

def _set_and_logging_training_step(self, dataset) -> None:
def _set_and_logging_training_step(self) -> None:
"""Check runner config and set training step parameters.

This method calculates and configures the training steps based on the dataset size,
@@ -294,9 +296,6 @@ class LLMTrainer:
is enabled and sink_size is specified. It also sets initial epoch and step values
for training resumption.

Args:
dataset: Training dataset used to determine data size for calculations.

The method performs the following operations:
1. Gets the training dataset size
2. Sets original epochs value
@@ -305,7 +304,7 @@ class LLMTrainer:
5. Adjusts epochs calculation based on sink mode and sink size
6. Updates configuration with dataset size and training parameters
"""
data_size = self._get_train_dataset_size(dataset)
data_size = self._get_train_dataset_size()
new_epochs = self.config.training_args.epochs
self.config.training_args.origin_epochs = new_epochs

@@ -683,6 +682,61 @@ class LLMTrainer:
logger.info(f'load_checkpoint config is None, which is invalid for run_mode: {self.config.run_mode}')
return load_checkpoint_path_or_file

def _set_learning_rate_scheduler(
self,
lr_scheduler: Optional[nn.learning_rate_schedule.LearningRateSchedule] = None,
grouped_lr_scheduler: List[dict] = None) -> None:
"""Set learning rate scheduler(s) for model training.

This method configures the learning rate scheduling strategy for the optimizer.
It supports two modes:
1. Standard mode: A single learning rate scheduler applied to all model parameters
2. Grouped mode: Different learning rate schedulers for different parameter groups

The grouped learning rate scheduler allows fine-grained control over learning rates
for different parts of the model (e.g., different layers, embeddings, or attention
mechanisms), which is useful for transfer learning, fine-tuning, or when different
components require different learning rate schedules.

Args:
lr_scheduler (Optional[nn.learning_rate_schedule.LearningRateSchedule]):
A single learning rate scheduler instance to be applied uniformly to all
model parameters. This is used when all parameters should follow the same
learning rate schedule. Examples include CosineAnnealingLR, PolynomialDecayLR, etc.
If None, the scheduler will not be set (useful when only grouped schedulers are used).

grouped_lr_scheduler (List[dict], optional):
A list of dictionaries, where each dictionary represents a parameter group
with its own learning rate scheduler. Each dictionary should contain:
- 'params': List[str] - Parameter name patterns to match (supports wildcards)
- 'lr_scheduler': LearningRateSchedule - The scheduler instance for this group
- 'lr_config': MindFormerConfig - Configuration used to create the scheduler

This allows different parameter groups to have independent learning rate schedules.
For example, embeddings might use a different schedule than transformer layers.
If None, grouped learning rate scheduling will not be used.

Note:
- Both schedulers can be set simultaneously. When both are provided, the grouped
scheduler takes precedence for matching parameters, while the standard scheduler
applies to unmatched parameters.
- This method is typically called during optimizer creation in
`create_optimizer_scheduler` method.
- The method will log informational messages when schedulers are successfully set.

Side Effects:
- Sets `self.lr_scheduler` if lr_scheduler is provided
- Sets `self.grouped_lr_scheduler` if grouped_lr_scheduler is provided
- Logs configuration information for debugging purposes
"""
if lr_scheduler is not None:
self.lr_scheduler = lr_scheduler
logger.info("The default learning rate scheduler has been set.")

if grouped_lr_scheduler is not None:
self.grouped_lr_scheduler = grouped_lr_scheduler
logger.info("The group learning rate scheduler has been set.")

def _train_dataset_restore_from_checkpoint(self, dataset: GeneratorDataset,
load_checkpoint_path_or_file: Optional[str]) -> None:
"""Restore training dataset state from checkpoint.
@@ -1137,7 +1191,7 @@ class LLMTrainer:

return model

def create_optimizer_scheduler(self, network: nn.Cell, dataset: GeneratorDataset) -> nn.Optimizer:
def create_optimizer_scheduler(self, network: nn.Cell) -> nn.Optimizer:
"""Create optimizer and learning rate scheduler for model training.

This method constructs an optimizer with grouped parameters and learning rate schedule
@@ -1146,7 +1200,6 @@ class LLMTrainer:

Args:
network (nn.Cell): The neural network model to optimize.
dataset (GeneratorDataset): Training dataset used for learning rate scheduling.

Returns:
nn.Optimizer: Configured optimizer with learning rate scheduler.
@@ -1165,7 +1218,13 @@ class LLMTrainer:
logger.info("Building optimizer and learning rate scheduler from configuration...")

# Build learning rate schedule
lr_schedule = self.create_lr_scheduler(dataset)
is_grouped_lr_scheduler = self.config.grouped_lr_schedule is not None
grouped_lr_scheduler = None
if is_grouped_lr_scheduler:
lr_scheduler, grouped_lr_scheduler = self.create_grouped_lr_scheduler()
else:
lr_scheduler = self.create_lr_scheduler()
self._set_learning_rate_scheduler(lr_scheduler, grouped_lr_scheduler)

optimizer_type = self.config.optimizer.type

@@ -1181,7 +1240,10 @@ class LLMTrainer:
# Group parameters with weight decay
weight_decay = self.config.optimizer.weight_decay if self.config.optimizer.weight_decay else 0.
group_params = get_optimizer_grouped_parameters(
network, weight_decay, optimizer_type=optimizer_type, model_params=model_params)
network, weight_decay,
optimizer_type=optimizer_type,
grouped_lr_schedule=grouped_lr_scheduler,
model_params=model_params)

def create_optimizer() -> nn.Optimizer:
"""Internal function to create optimizer instance.
@@ -1193,8 +1255,8 @@ class LLMTrainer:
ValueError: If learning_rate is not set in optimizer config.
"""
optimizer_kwargs = {"params": group_params}
if lr_schedule is not None:
optimizer_kwargs["learning_rate"] = lr_schedule
if lr_scheduler is not None:
optimizer_kwargs["learning_rate"] = lr_scheduler
else:
if self.config.optimizer.learning_rate is None:
raise ValueError("lr_scheduler is None, please set learning_rate in optimizer config.")
@@ -1221,68 +1283,81 @@ class LLMTrainer:
logger.info("Optimizer and learning rate scheduler created successfully")
return optimizer

def create_lr_scheduler(self, dataset) -> nn.learning_rate_schedule.LearningRateSchedule:
def create_lr_scheduler(
self, lr_schedule_config: MindFormerConfig = None) -> nn.learning_rate_schedule.LearningRateSchedule:
"""Create learning rate scheduler based on configuration.

This method builds a learning rate scheduler by processing the configuration settings
and calculating the appropriate parameters for the learning rate schedule.
and calculating the appropriate parameters for the learning rate schedule. It handles
warmup configuration, total steps calculation, and scheduler instantiation.

Args:
dataset: Training dataset used to calculate total training steps.
lr_schedule_config (MindFormerConfig, optional):
Learning rate scheduler configuration. If None, uses `self.config.lr_schedule`.
The configuration should contain:
- `type`: str - Scheduler type (e.g., "CosineWithWarmUpLR", "PolynomialWithWarmUpLR",
"ConstantWarmUpLR", "CosineAnnealingLR", etc.)
- `learning_rate`: float - Base learning rate value
- `warmup_ratio`: float, optional - Warmup ratio (required if warmup_epochs is set)
- `warmup_steps`: int, optional - Number of warmup steps (can be set directly)
- `warmup_lr_init`: float, optional - Initial learning rate during warmup
- `total_steps`: int, optional - Total training steps (-1 means auto-calculate)
- Other scheduler-specific parameters (e.g., `min_lr`, `max_lr`, `decay_steps`, etc.)

Returns:
object: Built learning rate scheduler instance.

Raises:
ValueError: If warmup_epochs is not a non-negative integer.

The method performs the following operations:
- Calculates total training steps based on dataset size and training configuration
- Processes warmup settings (epochs or ratio) and converts warmup_epochs to warmup_steps
- Handles conflicts between warmup_epochs and warmup_ratio settings
- Builds the learning rate scheduler using the `build_lr` function
- Applies default warmup_lr_init value if not explicitly set
nn.learning_rate_schedule.LearningRateSchedule:
Built learning rate scheduler instance. The scheduler will dynamically adjust
the learning rate during training based on the current step. Returns None if
no valid configuration is provided (both lr_schedule_config and self.config.lr_schedule
are None or empty).

Note:
- Total steps calculation:
- If `sink_mode` is False: total_steps = epochs * train_dataset_size
- If `sink_mode` is True: total_steps = epochs * sink_size
- If `total_steps` is None or -1 in config, it will be auto-calculated
- `warmup_epochs` is converted to `warmup_steps` using: warmup_steps = warmup_epochs * train_dataset_size
- The method modifies the input config by popping `warmup_epochs` (if present)
- If `warmup_lr_init` is not set and the scheduler supports it, a default value will be used
- Supported scheduler types include: CosineWithWarmUpLR, PolynomialWithWarmUpLR,
ConstantWarmUpLR, CosineAnnealingLR, CosineWithRestartsAndWarmUpLR, etc.

Example:
Configuration example:
```
lr_schedule:
type: "CosineWithWarmUpLR"
learning_rate: 1e-4
warmup_ratio: 0.1
total_steps: -1 # Auto-calculate
```
"""
logger.info("Building learning rate scheduler from configuration...")
train_data_size = self._get_train_dataset_size(dataset)
train_dataset_size = self._get_train_dataset_size()
warmup_lr_init = None
lr_schedule_config = self.config.lr_schedule if lr_schedule_config is None else lr_schedule_config

if self.config.lr_schedule:
warmup_epochs = self.config.lr_schedule.pop("warmup_epochs", None)
warmup_lr_init = self.config.lr_schedule.get("warmup_lr_init", None)

if warmup_epochs is not None:
if not isinstance(warmup_epochs, int):
raise ValueError(f"The type of warmup_epochs must be int, but got type {type(warmup_epochs)}.")
if warmup_epochs < 0:
raise ValueError(f"The value of warmup_epochs must be non-negative integer, "
f"but got {warmup_epochs}.")
if lr_schedule_config:
warmup_lr_init = lr_schedule_config.get("warmup_lr_init", 0.)
warmup_steps = lr_schedule_config.get("warmup_steps", 0)
warmup_ratio = lr_schedule_config.get("warmup_ratio", 0.)
lr_schedule_config.warmup_steps = warmup_steps
lr_schedule_config.warmup_ratio = warmup_ratio

# Calculate total training steps
if not self.config.training_args.sink_mode:
total_steps = int(self.config.training_args.epochs * train_data_size)
total_steps = int(self.config.training_args.epochs * train_dataset_size)
else:
total_steps = int(self.config.training_args.epochs * self.config.training_args.sink_size)

# Handle conflicts between warmup_epochs and warmup_ratio
if warmup_epochs is not None and self.config.lr_schedule.warmup_ratio is not None:
logger.warning("warmup_epochs and warmup_ratio are set simultaneously, "
"warmup_ratio takes precedence.")
warmup_epochs = None

# Convert warmup_epochs to warmup_steps
if warmup_epochs is not None:
logger.info("warmup_epochs was set in lr_schedule, "
"converting to warmup_steps based on dataset size")
self.config.lr_schedule.warmup_steps = int(warmup_epochs * train_data_size)

# Set total_steps in lr_schedule
self.config.lr_schedule.total_steps = total_steps \
if self.config.lr_schedule.total_steps is None or self.config.lr_schedule.total_steps == -1 \
else int(self.config.lr_schedule.total_steps)
# Set total_steps in `lr_schedule` if not explicitly defined
if lr_schedule_config.total_steps is None or lr_schedule_config.total_steps == -1:
lr_schedule_config.total_steps = total_steps
else:
lr_schedule_config.total_steps = int(lr_schedule_config.total_steps)

# Build learning rate scheduler
lr_schedule = build_lr(self.config.lr_schedule)
lr_schedule = build_lr(lr_schedule_config)

# Apply default warmup_lr_init if not set
if lr_schedule and hasattr(lr_schedule, "warmup_lr_init") and warmup_lr_init is None:
@@ -1291,6 +1366,95 @@ class LLMTrainer:
logger.info("Learning rate scheduler created successfully")
return lr_schedule

def create_grouped_lr_scheduler(self) -> tuple[nn.learning_rate_schedule.LearningRateSchedule, list[dict]]:
"""Create grouped learning rate schedulers from configuration.

This method builds a set of learning rate schedulers for different parameter groups,
allowing fine-grained control over learning rates for different parts of the model.
It creates a default scheduler for unmatched parameters and multiple group-specific
schedulers based on parameter name patterns.

The configuration structure (`self.config.grouped_lr_schedule`) should contain:
- `default`: Configuration for the default learning rate scheduler (applied to all
parameters that don't match any group patterns)
- `grouped`: A list of dictionaries, each containing:
- `params`: List[str] - Parameter name patterns to match (supports wildcards via fnmatch)
- Other LR scheduler configuration keys (e.g., `type`, `learning_rate`, `warmup_steps`, etc.)

This is particularly useful for:
- Transfer learning: Different learning rates for pretrained vs. new layers
- Fine-tuning: Lower learning rates for embeddings, higher for task-specific layers
- Layer-wise decay: Gradually decreasing learning rates from top to bottom layers
- Component-specific schedules: Different schedules for attention, MLP, embeddings, etc.

Returns:
tuple[nn.learning_rate_schedule.LearningRateSchedule, list[dict]]:
A tuple containing:
- Default learning rate scheduler: Applied to parameters that don't match
any group patterns in the grouped configuration
- Grouped learning rate scheduler list: A list of dictionaries, each containing:
- 'params': List[str] - Parameter name patterns for this group
- 'lr_scheduler': LearningRateSchedule - The scheduler instance for this group
- 'lr_config': MindFormerConfig - The configuration used to create this scheduler

Raises:
ValueError: If any group configuration in `grouped` is missing the 'params' field
or if 'params' is empty.

Note:
- Parameter matching is done by name patterns, supporting wildcards (e.g., "*.embedding*")
- The default scheduler is always created and will be used for unmatched parameters
- Each group can have its own independent learning rate schedule configuration
- The method modifies the input configuration dictionaries by popping 'params'

Example:
Configuration structure:
```
grouped_lr_schedule:
default:
type: "ConstantWarmUpLR"
learning_rate: 1.e-4
warmup_ratio: 0.
total_steps: -1
grouped:
- params: ["embedding*"]
type: "CosineWithWarmUpLR"
learning_rate: 1.e-5
warmup_ratio: 0.
total_steps: -1
- params: ["*.self_attention*"]
type: "PolynomialWithWarmUpLR"
learning_rate: 2.e-4
warmup_ratio: 0.
total_steps: -1
```
"""
logger.info("Building grouped learning rate scheduler from configuration...")
default_lr_schedule_config = self.config.grouped_lr_schedule.default
default_lr_scheduler = self.create_lr_scheduler(default_lr_schedule_config)

grouped_lr_scheduler = []
grouped_config = self.config.grouped_lr_schedule.grouped

# Iterate over each grouped LR configuration
for lr_config in grouped_config:
params = lr_config.pop('params', None)
if not params or not isinstance(params, list):
raise ValueError(
"Got invalid 'params' in grouped_lr_schedule.grouped: each item must include "
"a non-empty 'params' list."
)

lr_config = MindFormerConfig(**lr_config)
lr_scheduler = self.create_lr_scheduler(lr_config)
grouped_lr_scheduler.append({
'params': params,
'lr_scheduler': lr_scheduler,
'lr_config': lr_config
})

return default_lr_scheduler, grouped_lr_scheduler

def create_model_wrapper(self, network: nn.Cell, optimizer: nn.Optimizer) -> Union[
MFPipelineWithLossScaleCell, MFTrainOneStepCell]:
"""Create model wrapper with training tools and configurations.
@@ -1339,6 +1503,8 @@ class LLMTrainer:
use_skip_data_by_global_norm=use_skip_data_by_global_norm,
global_norm_spike_threshold=global_norm_spike_threshold,
print_separate_loss=self.config.training_args.get("print_separate_loss", True),
lr_scheduler=self.lr_scheduler,
grouped_lr_scheduler=self.grouped_lr_scheduler,
)
logger.info("Created MFPipelineWithLossScaleCell model wrapper for pipeline parallel training")
else:
@@ -1355,6 +1521,8 @@ class LLMTrainer:
use_skip_data_by_global_norm=use_skip_data_by_global_norm,
global_norm_spike_threshold=global_norm_spike_threshold,
print_separate_loss=self.config.training_args.get("print_separate_loss", True),
lr_scheduler=self.lr_scheduler,
grouped_lr_scheduler=self.grouped_lr_scheduler,
)
logger.info("Created MFTrainOneStepCell model wrapper for standard training")

@@ -1489,7 +1657,7 @@ class LLMTrainer:
self._train_dataset_restore_from_checkpoint(train_dataset, load_checkpoint_path_or_file)

# Configure training steps and epochs
self._set_and_logging_training_step(train_dataset)
self._set_and_logging_training_step()

# Record configuration to global environment
self._record_config_to_global_envs()
@@ -1507,7 +1675,7 @@ class LLMTrainer:
self._count_parameters(network, run_mode=self.config.run_mode)

# Create optimizer and learning rate scheduler
optimizer = self.create_optimizer_scheduler(network, train_dataset)
optimizer = self.create_optimizer_scheduler(network)

# Wrap model with training tools
model_forward_and_backward_wrapper = self.create_model_wrapper(network, optimizer)
@@ -1807,7 +1975,7 @@ class LLMTrainer:
if units == 'M':
count_params_m = sum(total_params) / 1e6
trainable_params_m = sum(trainable_params) / 1e6
if run_mode in ['train', 'finetune']:
if run_mode in ['train', 'finetune']:
logger.info(f"Network Parameters: {count_params_m:.0f} M.")
logger.info(f"Network Trainable Parameters: {trainable_params_m:.0f} M.")
else:
@@ -1815,7 +1983,7 @@ class LLMTrainer:
elif units == 'B':
count_params_m = sum(total_params) / 1e9
trainable_params_m = sum(trainable_params) / 1e9
if run_mode in ['train', 'finetune']:
if run_mode in ['train', 'finetune']:
logger.info(f"Network Parameters: {count_params_m:.1f} B.")
logger.info(f"Network Trainable Parameters: {trainable_params_m:.1f} B.")
else:
@@ -1894,21 +2062,19 @@ class LLMTrainer:

raise ValueError(f"{checkpoint} does not exist, please check load_checkpoint in yaml and set a correct value.")

@staticmethod
def _get_train_dataset_size(train_dataset: GeneratorDataset) -> int:
def _get_train_dataset_size(self) -> int:
"""Get the size of the training dataset.

This static method retrieves the total number of samples in the
training dataset, which is used for calculating training steps
and epochs.

Args:
train_dataset (GeneratorDataset): Training dataset instance.

Returns:
int: Total number of samples in the training dataset.
"""
return train_dataset.get_dataset_size()
if self.train_dataset is None:
raise RuntimeError("Please set train_dataset in yaml.")
return self.train_dataset.get_dataset_size()

@staticmethod
def _check_auto_parallel_mode_valid() -> bool:


+ 8
- 3
mindformers/trainer/optimizer_grouped_parameters.py View File

@@ -48,7 +48,7 @@ def filter_current_stage_parameters(model, model_params):
param.requires_grad = False


def _get_gouped_lr_map(model, grouped_lr_scheduler=None):
def _get_grouped_lr_map(model, grouped_lr_scheduler=None):
"""
Build parameter-to-group and group-to-learning-rate mappings
based on grouped learning rate scheduler configuration.
@@ -78,7 +78,11 @@ def _get_gouped_lr_map(model, grouped_lr_scheduler=None):
GROUPED_PARAMS = [[] for _ in range(len(grouped_lr_scheduler))]

# Match actual parameter names to group patterns
for param in model.trainable_params():
trainable_params = model.trainable_params()
support_matching_param_names = json.dumps([param.name for param in trainable_params], indent=2)
logger.info(f"Support matching grouped parameter names: {support_matching_param_names}")

for param in trainable_params:
for grouped_param_name in list(group_map.keys()):
group_id = group_map.get(grouped_param_name)
# Match exact or wildcard parameter names
@@ -86,6 +90,7 @@ def _get_gouped_lr_map(model, grouped_lr_scheduler=None):
param_group_map[param.name] = group_id
GROUPED_PARAMS[group_id].append(param.name)
break

for group_id, sub_params in enumerate(GROUPED_PARAMS):
if not sub_params:
raise ValueError(
@@ -130,7 +135,7 @@ def get_optimizer_grouped_parameters(model: Optional[PreTrainedModel] = None,
filter_current_stage_parameters(model, model_params)

# Build mapping from params to LR groups
param_group_map, lr_scheduler_map = _get_gouped_lr_map(model, grouped_lr_schedule)
param_group_map, lr_scheduler_map = _get_grouped_lr_map(model, grouped_lr_schedule)
parameter_group_names = {} # For logging
parameter_group_vars = {} # Actual optimizer groups



+ 1
- 1
tests/st/test_ut/test_tools/test_template_v2/test_template_v2.py View File

@@ -984,7 +984,6 @@ class TestLrScheduleConfig:
default_config = LrScheduleConfig.default_value()
assert default_config['type'] == 'CosineWithWarmUpLR'
assert default_config['learning_rate'] == 5.e-5
assert default_config['lr_end'] == 0.
assert default_config['warmup_ratio'] == 0.
assert default_config['total_steps'] == -1

@@ -1697,6 +1696,7 @@ class TestConfigTemplate:
'model_config',
'optimizer',
'lr_schedule',
'grouped_lr_schedule',
'callbacks',
'monitor_config',
'profile',


Loading…
Cancel
Save
Baidu
map