|
|
|
@@ -0,0 +1,236 @@ |
|
|
|
# Copyright 2025 Huawei Technologies Co., Ltd |
|
|
|
# |
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
# you may not use this file except in compliance with the License. |
|
|
|
# You may obtain a copy of the License at |
|
|
|
# |
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
# |
|
|
|
# Unless required by applicable law or agreed to in writing, software |
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
# See the License for the specific language governing permissions and |
|
|
|
# limitations under the License. |
|
|
|
# ============================================================================ |
|
|
|
"""Run Muon optimizer accuracy test with configurable parameters via args""" |
|
|
|
import argparse |
|
|
|
import numpy as np |
|
|
|
import mindspore as ms |
|
|
|
from mindspore import nn, Tensor |
|
|
|
|
|
|
|
from mindformers.core.context.build_context import build_context |
|
|
|
from mindformers.core.optim.muon import Muon |
|
|
|
|
|
|
|
np.random.seed(1024) |
|
|
|
|
|
|
|
# Test weight initialization - same as optimizer_util.py |
|
|
|
FC1_WEIGHT = np.array([[0.72346634, 0.95608497, 0.4084163, 0.18627149, |
|
|
|
0.6942514, 0.39767185, 0.24918061, 0.4548748], |
|
|
|
[0.7203382, 0.19086994, 0.76286614, 0.87920564, |
|
|
|
0.3169892, 0.9462494, 0.62827677, 0.27504718], |
|
|
|
[0.3544535, 0.2524781, 0.5370583, 0.8313121, |
|
|
|
0.6670143, 0.0488653, 0.62225235, 0.7546456], |
|
|
|
[0.17985944, 0.05106374, 0.31064633, 0.4863033, |
|
|
|
0.848814, 0.5523157, 0.20295663, 0.7213356]]).astype("float32") |
|
|
|
|
|
|
|
FC1_BIAS = np.array([0.79708564, 0.13728078, 0.66322654, 0.88128525]).astype("float32") |
|
|
|
|
|
|
|
FC2_WEIGHT = np.array([[0.8473515, 0.50923985, 0.42287776, 0.29769543]]).astype("float32") |
|
|
|
|
|
|
|
FC2_BIAS = np.array([0.09996348]).astype("float32") |
|
|
|
|
|
|
|
|
|
|
|
class MockTransformerConfig: |
|
|
|
"""Mock transformer config for testing Muon optimizer.""" |
|
|
|
def __init__(self): |
|
|
|
self.multi_latent_attention = True |
|
|
|
self.tensor_model_parallel_size = 1 |
|
|
|
self.data_parallel_size = 1 |
|
|
|
|
|
|
|
|
|
|
|
class MockModel: |
|
|
|
""" |
|
|
|
Mock model class that provides required interfaces for Muon optimizer. |
|
|
|
This simulates the model interface that Muon optimizer expects. |
|
|
|
""" |
|
|
|
def __init__(self): |
|
|
|
self.config = MockTransformerConfig() |
|
|
|
|
|
|
|
def get_gpt_transformer_config(self): |
|
|
|
"""Return transformer config.""" |
|
|
|
return self.config |
|
|
|
|
|
|
|
def make_model_muon_fns(self): |
|
|
|
"""Return muon split and merge functions.""" |
|
|
|
def muon_split_fn(param_name, tensor): # pylint: disable=unused-argument |
|
|
|
"""Split function - returns tensor as list.""" |
|
|
|
return [tensor] |
|
|
|
|
|
|
|
def muon_merge_fn(param_name, tensor_list): # pylint: disable=unused-argument |
|
|
|
"""Merge function - returns first tensor.""" |
|
|
|
return tensor_list[0] |
|
|
|
|
|
|
|
return muon_split_fn, muon_merge_fn |
|
|
|
|
|
|
|
# pylint: disable=unused-argument |
|
|
|
def apply_qk_clip_scaling(self, params, param_names, param_layer, logit_threshold, |
|
|
|
muon_split_fn, muon_merge_fn): |
|
|
|
"""Apply query-key clipping scaling.""" |
|
|
|
return [(0, params[0])] |
|
|
|
|
|
|
|
def get_param_layer_indices(self, params): |
|
|
|
"""Return layer indices for parameters.""" |
|
|
|
return {p.name: 0 for p in params} |
|
|
|
|
|
|
|
def get_muon_filter(self): |
|
|
|
"""Return filter function to determine which params use Muon.""" |
|
|
|
def muon_filter(param): |
|
|
|
# Apply Muon to weight parameters with 2D shape (not bias) |
|
|
|
return len(param.shape) == 2 and 'bias' not in param.name |
|
|
|
return muon_filter |
|
|
|
|
|
|
|
def get_tp_dims(self, params): |
|
|
|
"""Return tensor parallel dimensions.""" |
|
|
|
return tuple(-1 for _ in params) |
|
|
|
|
|
|
|
def get_op_groups_info(self, params, op): # pylint: disable=unused-argument |
|
|
|
"""Return optimizer parallel group info.""" |
|
|
|
ops = tuple(1 for _ in params) |
|
|
|
op_groups = tuple("" for _ in params) |
|
|
|
return ops, op_groups |
|
|
|
|
|
|
|
|
|
|
|
class FakeNet(nn.Cell): |
|
|
|
"""Build fake net for testing.""" |
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
super().__init__() |
|
|
|
self.fc1 = nn.Dense(in_channels=8, out_channels=4, |
|
|
|
weight_init=Tensor(FC1_WEIGHT), |
|
|
|
bias_init=Tensor(FC1_BIAS)) |
|
|
|
self.fc2 = nn.Dense(in_channels=4, out_channels=1, |
|
|
|
weight_init=Tensor(FC2_WEIGHT), |
|
|
|
bias_init=Tensor(FC2_BIAS)) |
|
|
|
self.relu = nn.ReLU() |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
x = self.relu(self.fc1(x)) |
|
|
|
x = self.fc2(x) |
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
class NetWithLoss(nn.Cell): |
|
|
|
"""Build net with loss.""" |
|
|
|
|
|
|
|
def __init__(self, network, loss_fn): |
|
|
|
super().__init__() |
|
|
|
self.network = network |
|
|
|
self.loss = loss_fn |
|
|
|
|
|
|
|
def construct(self, x, label): |
|
|
|
out = self.network(x) |
|
|
|
loss = self.loss(out, label) |
|
|
|
return loss |
|
|
|
|
|
|
|
|
|
|
|
def make_fake_data(): |
|
|
|
"""Make fake data for testing.""" |
|
|
|
data, label = [], [] |
|
|
|
for i in range(20): |
|
|
|
data.append(ms.Tensor(np.array(np.ones((2, 8)) * i, dtype=np.float32))) |
|
|
|
label.append(ms.Tensor(np.array(np.ones((2, 1)) * (i + 1), dtype=np.float32))) |
|
|
|
return data, label |
|
|
|
|
|
|
|
|
|
|
|
class MuonRunner: |
|
|
|
"""Class to manage Muon optimizer test and training.""" |
|
|
|
|
|
|
|
def __init__(self, args_from_parser): |
|
|
|
self.args = args_from_parser |
|
|
|
self.learning_rate = self.args.learning_rate |
|
|
|
self.weight_decay = self.args.weight_decay |
|
|
|
self.momentum = self.args.momentum |
|
|
|
self.nesterov = self.args.nesterov |
|
|
|
self.num_steps = self.args.num_steps |
|
|
|
|
|
|
|
def build_network(self): |
|
|
|
"""Build network with Muon optimizer.""" |
|
|
|
net = FakeNet() |
|
|
|
mock_model = MockModel() |
|
|
|
|
|
|
|
loss_fn = nn.L1Loss(reduction='mean') |
|
|
|
networkwithloss = NetWithLoss(net, loss_fn) |
|
|
|
networkwithloss.set_train() |
|
|
|
|
|
|
|
params = networkwithloss.trainable_params() |
|
|
|
|
|
|
|
# Create Muon optimizer |
|
|
|
optimizer = Muon( |
|
|
|
params=params, |
|
|
|
learning_rate=self.learning_rate, |
|
|
|
weight_decay=self.weight_decay, |
|
|
|
matched_adamw_rms=0.2, |
|
|
|
momentum=self.momentum, |
|
|
|
nesterov=self.nesterov, |
|
|
|
adamw_betas=(0.95, 0.95), |
|
|
|
adamw_eps=1e-8, |
|
|
|
model=mock_model, |
|
|
|
) |
|
|
|
|
|
|
|
return networkwithloss, optimizer, mock_model |
|
|
|
|
|
|
|
def run(self): |
|
|
|
"""Run the training with Muon optimizer.""" |
|
|
|
networkwithloss, optimizer, mock_model = self.build_network() |
|
|
|
trainonestepcell = nn.TrainOneStepCell(networkwithloss, optimizer) |
|
|
|
|
|
|
|
losses = [] |
|
|
|
data, label = make_fake_data() |
|
|
|
for i in range(self.num_steps): |
|
|
|
loss = trainonestepcell(data[i], label[i]) |
|
|
|
losses.append(loss.asnumpy()) |
|
|
|
|
|
|
|
# Save results |
|
|
|
output_dict = { |
|
|
|
"losses": np.array(losses), |
|
|
|
"num_muon_m": len(optimizer.muon_m), |
|
|
|
"num_moments1": len(optimizer.moments1), |
|
|
|
"num_moments2": len(optimizer.moments2), |
|
|
|
} |
|
|
|
|
|
|
|
# Save muon momentum values for weight parameters |
|
|
|
muon_filter = mock_model.get_muon_filter() |
|
|
|
# pylint: disable=protected-access |
|
|
|
for idx, param in enumerate(optimizer._parameters): |
|
|
|
if muon_filter(param): |
|
|
|
muon_m_value = optimizer.muon_m[idx].asnumpy() |
|
|
|
output_dict[f"muon_m_{idx}"] = muon_m_value |
|
|
|
|
|
|
|
np.savez(self.args.output_path, **output_dict) |
|
|
|
print(f"Results saved to {self.args.output_path}") |
|
|
|
|
|
|
|
|
|
|
|
def main(): |
|
|
|
parser = argparse.ArgumentParser(description="Run Muon optimizer test") |
|
|
|
parser.add_argument("--learning_rate", type=float, default=0.02) |
|
|
|
parser.add_argument("--weight_decay", type=float, default=0.1) |
|
|
|
parser.add_argument("--momentum", type=float, default=0.95) |
|
|
|
parser.add_argument("--nesterov", type=lambda x: x.lower() == "true", default=True) |
|
|
|
parser.add_argument("--num_steps", type=int, default=20) |
|
|
|
parser.add_argument("--output_path", type=str, default="output_muon.npz") |
|
|
|
|
|
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
# Set context |
|
|
|
build_context({"use_legacy": False, "use_parallel": True}) |
|
|
|
ms.set_deterministic(True) |
|
|
|
ms.set_context(mode=ms.GRAPH_MODE) |
|
|
|
ms.set_seed(42) |
|
|
|
|
|
|
|
# Run training |
|
|
|
runner = MuonRunner(args) |
|
|
|
runner.run() |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
main() |