6 Commits

7 changed files with 222 additions and 182 deletions
Split View
  1. +4
    -5
      configs/qwen3/finetune_qwen3.yaml
  2. +4
    -5
      configs/qwen3/pretrain_qwen3_32b_4k.yaml
  3. +4
    -5
      configs/qwen3_moe/pretrain_qwen3_30b_a3b_4k.yaml
  4. +100
    -19
      tests/st/test_ut/test_dataset/test_dataloader/test_blended_megatron_dataset_builder.py
  5. +92
    -18
      tests/st/test_ut/test_dataset/test_dataloader/test_gpt_dataset.py
  6. +7
    -11
      tests/st/test_ut/test_megatron_format_checkpoint/test_metadata.py
  7. +11
    -119
      tests/st/test_ut/test_megatron_format_checkpoint/test_sharded_tensor.py

+ 4
- 5
configs/qwen3/finetune_qwen3.yaml View File

@@ -38,8 +38,8 @@ lr_schedule:

# Dataset configuration
train_dataset: &train_dataset
input_columns: ["input_ids", "labels", "loss_mask", "position_ids", "attention_mask"]
construct_args_key: ["input_ids", "labels", "loss_mask", "position_ids", "attention_mask"]
input_columns: ["input_ids", "labels", "loss_mask", "position_ids"]
construct_args_key: ["input_ids", "labels", "loss_mask", "position_ids"]

data_loader:
type: HFDataLoader
@@ -50,7 +50,7 @@ train_dataset: &train_dataset
split: "train"

# MindFormers dataset arguments
create_attention_mask: True
create_attention_mask: False
create_compressed_eod_mask: False
compressed_eod_mask_length: 128
use_broadcast_data: True
@@ -112,8 +112,7 @@ parallel:
[*dp, 1],
[*dp, 1],
[*dp, 1],
[*dp, 1],
[*dp, 1, 1, 1]
[*dp, 1]
] # Must match the length of train_dataset.input_columns
search_mode: "sharding_propagation" # Fully-automatic parallel strategy search mode
strategy_ckpt_config:


+ 4
- 5
configs/qwen3/pretrain_qwen3_32b_4k.yaml View File

@@ -50,7 +50,7 @@ train_dataset: &train_dataset
seq_length: 4096 # Sequence length of the dataset
eod_mask_loss: False # Whether to calculate loss at the end-of-document (EOD)
reset_position_ids: False # Whether to reset position_ids at EOD
create_attention_mask: True # Whether to include attention_mask in the dataset
create_attention_mask: False # Whether to include attention_mask in the dataset
reset_attention_mask: False # Whether to reset attention_mask at EOD, creating a stepped attention_mask
create_compressed_eod_mask: False # Whether to include a compressed attention_mask
eod_pad_length: 128 # Length of the compressed attention_mask
@@ -59,8 +59,8 @@ train_dataset: &train_dataset
data_path: # Sampling proportion and path for the Megatron dataset
- '1'
- "/path/to/wiki103-megatron_text_document"
input_columns: ["input_ids", "labels", "loss_mask", "position_ids", "attention_mask"]
construct_args_key: ["input_ids", "labels", "loss_mask", "position_ids", "attention_mask"]
input_columns: ["input_ids", "labels", "loss_mask", "position_ids"]
construct_args_key: ["input_ids", "labels", "loss_mask", "position_ids"]
num_parallel_workers: 8
python_multiprocessing: False
drop_remainder: True
@@ -103,8 +103,7 @@ parallel:
[*dp, 1],
[*dp, 1],
[*dp, 1],
[*dp, 1],
[*dp, 1, 1, 1]
[*dp, 1]
] # Must match the length of train_dataset.input_columns
search_mode: "sharding_propagation" # Fully-automatic parallel strategy search mode
strategy_ckpt_config:


+ 4
- 5
configs/qwen3_moe/pretrain_qwen3_30b_a3b_4k.yaml View File

@@ -54,7 +54,7 @@ train_dataset: &train_dataset
seq_length: 4096 # 数据集返回数据的序列长度
eod_mask_loss: False # 是否在eod处计算loss
reset_position_ids: False # 是否在eod处重置position_ids
create_attention_mask: True # 是否返回attention_mask
create_attention_mask: False # 是否返回attention_mask
reset_attention_mask: False # 是否在eod处重置attention_mask,返回阶梯状attention_mask
create_compressed_eod_mask: False # 是否返回压缩后的attention_mask
eod_pad_length: 128 # 设置压缩后attention_mask的长度
@@ -65,8 +65,8 @@ train_dataset: &train_dataset
- '1'
- "/path/to/wiki103-megatron_text_document"

input_columns: ["input_ids", "labels", "loss_mask", "position_ids", "attention_mask"]
construct_args_key: ["input_ids", "labels", "loss_mask", "position_ids", "attention_mask"]
input_columns: ["input_ids", "labels", "loss_mask", "position_ids"]
construct_args_key: ["input_ids", "labels", "loss_mask", "position_ids"]

num_parallel_workers: 8
python_multiprocessing: False
@@ -116,8 +116,7 @@ parallel:
[*dp, 1],
[*dp, 1],
[*dp, 1],
[*dp, 1],
[*dp, 1, 1, 1]
[*dp, 1]
]
search_mode: "sharding_propagation"
enable_parallel_optimizer: False


+ 100
- 19
tests/st/test_ut/test_dataset/test_dataloader/test_blended_megatron_dataset_builder.py View File

@@ -14,6 +14,10 @@
# ============================================================================
"""Test cases for BlendedMegatronDatasetBuilder"""

import os
import subprocess
import time
import glob
from unittest.mock import patch
import pytest
import numpy as np
@@ -23,8 +27,87 @@ from mindformers.dataset.blended_datasets.blended_megatron_dataset_builder impor
_get_size_per_split_per_dataset
)
from mindformers.dataset.blended_datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
from mindformers.dataset.blended_datasets.utils import Split, compile_helpers
from mindformers.dataset.blended_datasets.utils import Split
from mindformers.dataset.blended_datasets import utils as blended_utils_module
from mindformers.tools.logger import logger

try:
from filelock import FileLock

HAS_FILELOCK = True
except ImportError:
FileLock = None
HAS_FILELOCK = False


def _check_helpers_exists(helpers_dir):
"""Check if helpers.so exists and is valid."""
so_pattern = os.path.join(helpers_dir, "helpers*.so")
existing_so_files = glob.glob(so_pattern)
return existing_so_files and any(os.path.getsize(f) > 1000 for f in existing_so_files)


def _compile_helpers_safe(helpers_dir, worker_id):
"""Compile helpers if not already compiled."""
if _check_helpers_exists(helpers_dir):
return

logger.info(f"[{worker_id}] Starting compilation...")
result = subprocess.run(["make", "-C", helpers_dir], capture_output=True, text=True, check=False)

if result.returncode != 0 and not _check_helpers_exists(helpers_dir):
raise RuntimeError(f"Failed to compile helpers: {result.stderr}")

logger.info(f"[{worker_id}] Compilation completed")


@pytest.fixture(scope="session", autouse=True)
def ensure_helpers_compiled(request, tmp_path_factory):
"""Ensure helpers are compiled once with process-safe locking for pytest-xdist."""
# Get worker_id if running with pytest-xdist, otherwise use 'master'
worker_id = getattr(request.config, 'workerinput', {}).get('workerid', 'master')

helpers_dir = os.path.abspath(os.path.dirname(blended_utils_module.__file__))

# Quick check: if already compiled, all workers skip immediately
if _check_helpers_exists(helpers_dir):
logger.info(f"[{worker_id}] helpers.so already exists, using directly")
yield
return

# Single process mode - compile directly
if worker_id == "master":
_compile_helpers_safe(helpers_dir, worker_id)
yield
return

# Parallel mode - use file lock
lock_file = tmp_path_factory.getbasetemp().parent / "helpers_compile.lock"

if HAS_FILELOCK:
with FileLock(str(lock_file), timeout=300):
if not _check_helpers_exists(helpers_dir):
_compile_helpers_safe(helpers_dir, worker_id)
else:
# Fallback: simple atomic lock without filelock library
for _ in range(600): # 5 min timeout (600 * 0.5s)
try:
fd = os.open(str(lock_file), os.O_CREAT | os.O_EXCL | os.O_WRONLY)
try:
if not _check_helpers_exists(helpers_dir):
_compile_helpers_safe(helpers_dir, worker_id)
finally:
os.close(fd)
os.unlink(str(lock_file))
break
except FileExistsError:
time.sleep(0.5)
if _check_helpers_exists(helpers_dir):
break
else:
raise TimeoutError("Timeout waiting for helpers compilation")

yield

class DummyTokenizer:
"""A dummy tokenizer for testing purposes"""
@@ -166,7 +249,7 @@ def create_test_builder(config_kwargs=None, builder_kwargs=None):
class TestBlendedMegatronDatasetBuilder:
"""Test class for BlendedMegatronDatasetBuilder"""

@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_builder_initialization(self):
@@ -191,7 +274,7 @@ class TestBlendedMegatronDatasetBuilder:
assert builder.is_built_on_rank is is_built_on_rank_func
assert builder.config == config

@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_builder_initialization_assertion_error(self):
@@ -213,7 +296,7 @@ class TestBlendedMegatronDatasetBuilder:
config=config
)

@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_size_per_split_per_dataset(self):
@@ -236,7 +319,7 @@ class TestBlendedMegatronDatasetBuilder:
assert result[0] == expected_0
assert result[1] == expected_1

@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_build_generic_dataset(self):
@@ -270,7 +353,7 @@ class TestBlendedMegatronDatasetBuilder:
assert isinstance(dataset, DummyMegatronDataset)
assert len(dataset) == num_samples

@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_build_generic_dataset_with_oserror(self):
@@ -293,7 +376,7 @@ class TestBlendedMegatronDatasetBuilder:
True
)

@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_build_generic_dataset_distributed_rank_nonzero(self):
@@ -324,7 +407,7 @@ class TestBlendedMegatronDatasetBuilder:
assert isinstance(dataset, DummyMegatronDataset)
assert len(dataset) == 5

@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_build_generic_dataset_distributed_rank_zero_not_built(self):
@@ -354,7 +437,7 @@ class TestBlendedMegatronDatasetBuilder:

assert dataset is None

@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_build_method_mock_dataset(self):
@@ -380,7 +463,7 @@ class TestBlendedMegatronDatasetBuilder:
assert isinstance(datasets, list)
assert len(datasets) == len(Split)

@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_build_method_mock_dataset_failure(self):
@@ -427,7 +510,7 @@ class TestBlendedMegatronDatasetBuilder:
match="FailingDummyMegatronDataset failed to build as a mock data generator"):
builder.build()

@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_build_blended_dataset_single_prefix(self):
@@ -455,7 +538,7 @@ class TestBlendedMegatronDatasetBuilder:
assert isinstance(datasets, list)
assert len(datasets) == len(Split)

@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_build_with_blend_per_split(self):
@@ -484,7 +567,7 @@ class TestBlendedMegatronDatasetBuilder:
assert isinstance(datasets, list)
assert len(datasets) == len(Split)

@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_build_with_blend_per_split_single_prefix(self):
@@ -513,7 +596,7 @@ class TestBlendedMegatronDatasetBuilder:
assert isinstance(datasets, list)
assert len(datasets) == len(Split)

@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_build_with_blend_weights_and_size(self):
@@ -522,7 +605,6 @@ class TestBlendedMegatronDatasetBuilder:
Description: Test build method works with blend configuration having weights and size
Expectation: Method builds datasets correctly with weights processing
"""
compile_helpers()
config = create_test_config()
config.mock = False
config.blend = (["prefix1", "prefix2"], [0.3, 0.7])
@@ -542,7 +624,7 @@ class TestBlendedMegatronDatasetBuilder:
assert isinstance(datasets, list)
assert len(datasets) == len(Split)

@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_build_verification_logic(self):
@@ -592,7 +674,7 @@ class TestBlendedMegatronDatasetBuilder:
assert isinstance(datasets, list)
assert len(datasets) == 1

@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_build_verification_logic_cached_no_check(self):
@@ -637,7 +719,7 @@ class TestBlendedMegatronDatasetBuilder:
assert isinstance(datasets, list)
assert len(datasets) == 1

@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_build_parallel_datasets(self):
@@ -646,7 +728,6 @@ class TestBlendedMegatronDatasetBuilder:
Description: Test parallel building of megatron datasets
Expectation: Method builds datasets in parallel correctly
"""
compile_helpers()
config = create_test_config()
config.mock = False
config.blend = (["prefix1", "prefix2"], [0.5, 0.5])


+ 92
- 18
tests/st/test_ut/test_dataset/test_dataloader/test_gpt_dataset.py View File

@@ -15,8 +15,12 @@
"""test gpt dataset"""

import os
import subprocess
import tempfile
import time
import shutil
import glob

import pytest
import numpy as np

@@ -31,7 +35,87 @@ from mindformers.dataset.blended_datasets.gpt_dataset import (
_build_document_index,
_build_shuffle_index
)
from mindformers.dataset.blended_datasets.utils import Split, compile_helpers
from mindformers.dataset.blended_datasets.utils import Split
from mindformers.dataset.blended_datasets import utils as blended_utils_module
from mindformers.tools.logger import logger

try:
from filelock import FileLock

HAS_FILELOCK = True
except ImportError:
FileLock = None
HAS_FILELOCK = False


def _check_helpers_exists(helpers_dir):
"""Check if helpers.so exists and is valid."""
so_pattern = os.path.join(helpers_dir, "helpers*.so")
existing_so_files = glob.glob(so_pattern)
return existing_so_files and any(os.path.getsize(f) > 1000 for f in existing_so_files)


def _compile_helpers_safe(helpers_dir, worker_id):
"""Compile helpers if not already compiled."""
if _check_helpers_exists(helpers_dir):
return

logger.info(f"[{worker_id}] Starting compilation...")
result = subprocess.run(["make", "-C", helpers_dir], capture_output=True, text=True, check=False)

if result.returncode != 0 and not _check_helpers_exists(helpers_dir):
raise RuntimeError(f"Failed to compile helpers: {result.stderr}")

logger.info(f"[{worker_id}] Compilation completed")


@pytest.fixture(scope="session", autouse=True)
def ensure_helpers_compiled(request, tmp_path_factory):
"""Ensure helpers are compiled once with process-safe locking for pytest-xdist."""
# Get worker_id if running with pytest-xdist, otherwise use 'master'
worker_id = getattr(request.config, 'workerinput', {}).get('workerid', 'master')

helpers_dir = os.path.abspath(os.path.dirname(blended_utils_module.__file__))

# Quick check: if already compiled, all workers skip immediately
if _check_helpers_exists(helpers_dir):
logger.info(f"[{worker_id}] helpers.so already exists, using directly")
yield
return

# Single process mode - compile directly
if worker_id == "master":
_compile_helpers_safe(helpers_dir, worker_id)
yield
return

# Parallel mode - use file lock
lock_file = tmp_path_factory.getbasetemp().parent / "helpers_compile.lock"

if HAS_FILELOCK:
with FileLock(str(lock_file), timeout=300):
if not _check_helpers_exists(helpers_dir):
_compile_helpers_safe(helpers_dir, worker_id)
else:
# Fallback: simple atomic lock without filelock library
for _ in range(600): # 5 min timeout (600 * 0.5s)
try:
fd = os.open(str(lock_file), os.O_CREAT | os.O_EXCL | os.O_WRONLY)
try:
if not _check_helpers_exists(helpers_dir):
_compile_helpers_safe(helpers_dir, worker_id)
finally:
os.close(fd)
os.unlink(str(lock_file))
break
except FileExistsError:
time.sleep(0.5)
if _check_helpers_exists(helpers_dir):
break
else:
raise TimeoutError("Timeout waiting for helpers compilation")

yield


class DummyTokenizer:
@@ -108,12 +192,7 @@ def create_test_dataset(config_kwargs=None, dataset_kwargs=None):
class TestGPTDatasetInitialization:
"""Test GPT dataset initialization"""

@classmethod
def setup_class(cls):
"""Setup class"""
compile_helpers()

@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_gpt_dataset_real_initialization(self):
@@ -155,12 +234,7 @@ class TestGPTDatasetInitialization:
class TestMockGPTDatasetFunctionality:
"""Test Mock GPT dataset functionality"""

@classmethod
def setup_class(cls):
"""Setup class"""
compile_helpers()

@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_mock_gpt_dataset_configurations(self):
@@ -197,7 +271,7 @@ class TestMockGPTDatasetFunctionality:
assert len(padding_item) >= 4
assert np.all(padding_item[2] == 0)

@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_mock_gpt_dataset_advanced_features(self):
@@ -221,7 +295,7 @@ class TestMockGPTDatasetFunctionality:
class TestGPTDatasetComponents:
"""Test GPT dataset components"""

@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_mock_gpt_low_level_dataset(self):
@@ -239,7 +313,7 @@ class TestGPTDatasetComponents:
sliced_item = mock_dataset.get(0, offset=0, length=10)
assert len(sliced_item) == 10

@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_utility_functions(self):
@@ -289,7 +363,7 @@ class TestGPTDatasetComponents:
shuffle_idx = _build_shuffle_index(5, 5, numpy_random_state)
assert len(shuffle_idx) == 5

@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_config_validation(self):
@@ -306,7 +380,7 @@ class TestGPTDatasetComponents:
tokenizer=DummyTokenizer()
)

@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cacheability_logic(self):


+ 7
- 11
tests/st/test_ut/test_megatron_format_checkpoint/test_metadata.py View File

@@ -20,7 +20,7 @@ import pytest

import mindspore as ms

from mindformers.checkpoint.sharded_tensor import get_sharded_tensor_list_from_strategy_metadata
from mindformers.checkpoint.sharded_tensor import get_sharded_tensor_from_strategy_metadata
from mindformers.checkpoint.metadata import save_metadata, load_metadata
from mindformers.checkpoint.utils import (
get_checkpoint_iter_dir,
@@ -54,25 +54,21 @@ NOT_EXISTS = False
def save_metadata_without_npu(global_strategy_info, model_keys, user_prefix, metadata_file_path, save_optimizer):
"""Saving metadata.json without NPU ranks, using mock data."""
npu_nums = 2
sharded_tensor_metas = list()
param_file_mappings = list()
sharded_tensor_metas = {}
param_file_mappings = []

for cur_npu_rank in range(0, npu_nums):
org_cur_rank_strategy_layout = global_strategy_info[cur_npu_rank]
cur_rank_strategy_layout = [
dict([item])
for item in org_cur_rank_strategy_layout.items()
]
cur_rank_strategy_layout = global_strategy_info[cur_npu_rank]

# Get Sharded tensors from strategy metadata of current rank.
cur_rank_sharded_tensors = get_sharded_tensor_list_from_strategy_metadata(
cur_rank_sharded_tensors = get_sharded_tensor_from_strategy_metadata(
param_infos=cur_rank_strategy_layout,
cur_npu_rank=cur_npu_rank,
filter_func=(lambda x: x in list(model_keys)) if not save_optimizer else None
)

# Get mappings of parameter file of current rank.
for sharded_tensor in cur_rank_sharded_tensors:
for _, sharded_tensor in cur_rank_sharded_tensors.items():
if save_optimizer and sharded_tensor.key not in list(model_keys):
ckpt_name = get_checkpoint_name(None, user_prefix, cur_npu_rank, npu_nums, FileType.OPTIMIZER)
else:
@@ -85,7 +81,7 @@ def save_metadata_without_npu(global_strategy_info, model_keys, user_prefix, met
)
)

sharded_tensor_metas.append(cur_rank_sharded_tensors)
sharded_tensor_metas[cur_npu_rank] = cur_rank_sharded_tensors

save_metadata(sharded_tensor_metas, param_file_mappings, metadata_file_path)



+ 11
- 119
tests/st/test_ut/test_megatron_format_checkpoint/test_sharded_tensor.py View File

@@ -24,10 +24,7 @@ from mindformers.checkpoint.sharded_tensor import (
ShardedTensor,
build_sharded_tensor,
is_main_replica,
get_sharded_tensor_list_from_cell,
convert_sharded_tensor_list_to_dict,
get_value_type_from_layout,
get_param_name_from_layout,
get_sharded_tensor_from_cell,
get_strategy_info_from_sharded_tensor,
_rank_id_with_slice_id,
_alias_name_with_rank_id,
@@ -158,11 +155,11 @@ class SimpleNet(nn.Cell):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_sharded_tensor_list_from_cell():
def test_get_sharded_tensor_from_cell():
"""
Feature: get_sharded_tensor_list_from_cell function
Feature: get_sharded_tensor_from_cell function
Description: Extract sharded tensors from a neural network cell
Expectation: Returns list of ShardedTensor objects for cell parameters
Expectation: Returns dict of ShardedTensor objects for cell parameters
"""
net = SimpleNet()

@@ -170,41 +167,17 @@ def test_get_sharded_tensor_list_from_cell():
net.dense.weight.set_data(initializer(Normal(), net.dense.weight.shape, net.dense.weight.dtype))
net.dense.bias.set_data(initializer('zeros', net.dense.bias.shape, net.dense.bias.dtype))

sharded_tensors = get_sharded_tensor_list_from_cell(net)
sharded_tensors = get_sharded_tensor_from_cell(net)

assert len(sharded_tensors) >= 2 # Weight and bias

weight_tensor = next(t for t in sharded_tensors if 'weight' in t.key)
bias_tensor = next(t for t in sharded_tensors if 'bias' in t.key)
weight_tensor = next(t for t in sharded_tensors.values() if 'weight' in t.key)
bias_tensor = next(t for t in sharded_tensors.values() if 'bias' in t.key)

assert weight_tensor.local_shape == net.dense.weight.shape
assert bias_tensor.local_shape == net.dense.bias.shape


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_convert_sharded_tensor_list_to_dict():
"""
Feature: convert_sharded_tensor_list_to_dict function
Description: Convert list of ShardedTensor objects to dictionary
Expectation: Returns dictionary mapping tensor keys to ShardedTensor objects
"""
tensors = [
ShardedTensor(key=f"param_{i}", org_key="", dtype=ms.float32,
local_shape=(10,), global_shape=(100,),
global_offset=(i * 10,), axis_fragmentations=(10,))
for i in range(3)
]

tensor_dict = convert_sharded_tensor_list_to_dict(tensors)

assert len(tensor_dict) == 3
for i in range(3):
assert f"param_{i}" in tensor_dict
assert tensor_dict[f"param_{i}"].key == f"param_{i}"


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@@ -286,11 +259,11 @@ def test_is_main_replica_single_nonzero_element_tuple():
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_sharded_tensor_list_from_cell_with_optimizer():
def test_get_sharded_tensor_from_cell_with_optimizer():
"""
Feature: get_sharded_tensor_list_from_cell function with optimizer
Feature: get_sharded_tensor_from_cell function with optimizer
Description: Extract sharded tensors from a neural network cell and optimizer
Expectation: Returns list of ShardedTensor objects for both cell and optimizer parameters
Expectation: Returns dict of ShardedTensor objects for both cell and optimizer parameters
"""
net = SimpleNet()

@@ -301,49 +274,12 @@ def test_get_sharded_tensor_list_from_cell_with_optimizer():
# Create optimizer
optim = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

sharded_tensors = get_sharded_tensor_list_from_cell(net, optim)
sharded_tensors = get_sharded_tensor_from_cell(net, optim)

# Should have weight, bias from net and optimizer states (momentum, etc.)
assert len(sharded_tensors) >= 2


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_convert_empty_sharded_tensor_list_to_dict():
"""
Feature: convert_sharded_tensor_list_to_dict function
Description: Convert empty list of ShardedTensor objects to dictionary
Expectation: Returns empty dictionary
"""
tensor_dict = convert_sharded_tensor_list_to_dict([])
assert len(tensor_dict) == 0


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_convert_sharded_tensor_list_to_dict_duplicate_keys():
"""
Feature: convert_sharded_tensor_list_to_dict function
Description: Convert list of ShardedTensor objects with duplicate keys to dictionary
Expectation: Later tensor overwrites earlier one with same key
"""
tensors = [
ShardedTensor(key="param", org_key="", dtype=ms.float32,
local_shape=(10,), global_shape=(100,),
global_offset=(0,), axis_fragmentations=(10,)),
ShardedTensor(key="param", org_key="", dtype=ms.float32,
local_shape=(5,), global_shape=(50,),
global_offset=(10,), axis_fragmentations=(5,))
]

tensor_dict = convert_sharded_tensor_list_to_dict(tensors)

assert len(tensor_dict) == 1
assert tensor_dict["param"].local_shape == (5,)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@@ -458,47 +394,3 @@ def test_rank_id_with_slice_id():
assert isinstance(global_offset, tuple)
assert len(rank_slice_table) == 4 # 4 ranks
assert len(global_offset) == 4


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_param_name_from_layout():
"""
Feature: get_param_name_from_layout function
Description: Extract parameter names from layout information
Expectation: Returns list of parameter names
"""
param_infos = [
{
"weight": (None, None, None)
},
{
"bias": (None, None, None)
}
]

names = get_param_name_from_layout(param_infos)
assert names == ["weight", "bias"]


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_get_value_type_from_layout():
"""
Feature: get_value_type_from_layout function
Description: Extract parameter types from layout information
Expectation: Returns list of parameter types
"""
param_infos = [
{
"weight": (None, ms.float32, None)
},
{
"bias": (None, ms.float16, None)
}
]

types = get_value_type_from_layout(param_infos)
assert types == [ms.float32, ms.float16]

Loading…
Cancel
Save
Baidu
map