#!/usr/bin/python
# coding=utf-8
"""PyTorch MNIST training example for the c2net platform.

If the code contains Chinese comments or strings, keep this header at the
very top of the file:

    #!/usr/bin/python
    #coding=utf-8

The example uses the dataset ``MnistDataset_torch.zip``, laid out as::

    MnistDataset_torch.zip
    ├── test
    └── train

The pre-trained model folder is laid out as::

    Torch_MNIST_Example_Model
    ├── mnist_epoch1.pkl

NOTE(review): the code below actually looks in a folder named
``GCU_MNIST_Example_Model`` for a file named ``mnist_epoch1_0.81.pkl``;
confirm which names the platform really provides.
"""
import torch
from model import Model
import numpy as np
from torchvision.datasets import mnist
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import argparse
import os
# Import the c2net package.
# NOTE(review): upload_output is imported but never called in this script;
# confirm whether training outputs must be uploaded explicitly.
from c2net.context import prepare, upload_output
import importlib.util

# Checkpoint file looked up inside the pre-trained model folder to resume training.
PRETRAINED_CHECKPOINT_NAME = "mnist_epoch1_0.81.pkl"


def is_torch_dtu_available():
    """Return True when ``torch_dtu``, ``torch_dtu.core`` and
    ``torch_dtu.core.dtu_model`` can all be found by the import system."""
    if importlib.util.find_spec("torch_dtu") is None:
        return False
    if importlib.util.find_spec("torch_dtu.core") is None:
        return False
    return importlib.util.find_spec("torch_dtu.core.dtu_model") is not None


# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--epoch_size', type=int, default=1,
                    help='how much epoch to train')
parser.add_argument('--batch_size', type=int, default=256,
                    help='how much batch_size in epoch')


def _select_device():
    """Choose the compute device.

    Uses the DTU device when the ``torch_dtu`` plugin is importable, otherwise
    ``cuda:0`` if CUDA is available, otherwise the CPU.

    Returns:
        ``(device, dm)`` where ``dm`` is the ``torch_dtu.core.dtu_model`` module
        on DTU and ``None`` otherwise (callers use it to pick the optimizer step).
    """
    if is_torch_dtu_available():
        # The extra torch_dtu imports are kept exactly as in the original script
        # even though only ``dm`` is used below.
        import torch_dtu  # noqa: F401
        import torch_dtu.distributed as dist  # noqa: F401
        import torch_dtu.core.dtu_model as dm
        from torch_dtu.nn.parallel import DistributedDataParallel as torchDDP  # noqa: F401
        print('dtu is available: True')
        return dm.dtu_device(), dm
    print('dtu is available: False')
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), None


def _load_checkpoint(model, optimizer, checkpoint_dir):
    """Resume ``model`` and ``optimizer`` from a saved checkpoint, if one exists.

    Args:
        model: the model that will actually be trained.
        optimizer: the optimizer that will actually step ``model``'s parameters.
        checkpoint_dir: folder searched for ``PRETRAINED_CHECKPOINT_NAME``.

    Returns:
        The epoch index to start training from (0 when no checkpoint is found).
    """
    ckpt_path = os.path.join(checkpoint_dir, PRETRAINED_CHECKPOINT_NAME)
    if not os.path.exists(ckpt_path):
        print('无保存模型,将从头开始训练!')
        return 0
    # map_location="cpu" lets a checkpoint written on an accelerator be read on
    # any host; load_state_dict then copies the tensors into the parameters on
    # the model's own device.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    start_epoch = checkpoint['epoch']
    print('加载 epoch {} 权重成功!'.format(start_epoch))
    return start_epoch


def _train_one_epoch(model, optimizer, loss_fn, loader, device, dm):
    """Run one pass over ``loader``, logging the loss every 10 batches.

    ``dm`` is the DTU helper module (or ``None``); on DTU the optimizer step goes
    through ``dm.optimizer_step`` with a barrier, as in the original script.
    """
    model.train()
    for idx, (train_x, train_label) in enumerate(loader):
        train_x = train_x.to(device)
        train_label = train_label.to(device)
        optimizer.zero_grad()
        predict_y = model(train_x.float())
        loss = loss_fn(predict_y, train_label.long())
        if idx % 10 == 0:
            print('idx: {}, loss: {}'.format(idx, loss.item()))
        loss.backward()
        if dm is not None:
            dm.optimizer_step(optimizer, barrier=True)
        else:
            optimizer.step()


def _evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model`` over ``loader`` (0.0 if it is empty)."""
    model.eval()
    correct = 0
    total = 0
    # No gradients are needed for evaluation; skip building autograd graphs.
    with torch.no_grad():
        for test_x, test_label in loader:
            logits = model(test_x.to(device).float())
            predicted = logits.argmax(dim=-1).cpu()
            correct += (predicted == test_label).sum().item()
            total += test_label.shape[0]
    return correct / total if total else 0.0


def main():
    """Train and evaluate the MNIST model, saving a checkpoint after every epoch."""
    args, _unknown = parser.parse_known_args()
    # Initialise the container: stage the dataset and pre-trained model.
    c2net_context = prepare()
    # Dataset path.
    mnist_dataset_path = os.path.join(c2net_context.dataset_path, "MnistDataset_torch")
    # Pre-trained model path.
    pretrain_model_path = os.path.join(c2net_context.pretrain_model_path,
                                       "GCU_MNIST_Example_Model")
    # Output path.
    output_path = c2net_context.output_path

    device, dm = _select_device()

    # Build the model and its optimizer exactly once, so the optimizer state
    # restored from (and saved to) a checkpoint belongs to the trained model.
    model = Model().to(device)
    optimizer = SGD(model.parameters(), lr=1e-1)
    loss_fn = CrossEntropyLoss()

    batch_size = args.batch_size
    train_dataset = mnist.MNIST(root=os.path.join(mnist_dataset_path, "train"),
                                train=True, transform=ToTensor(), download=False)
    test_dataset = mnist.MNIST(root=os.path.join(mnist_dataset_path, "test"),
                               train=False, transform=ToTensor(), download=False)
    train_loader = DataLoader(train_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    epochs = args.epoch_size
    print('epoch_size is:{}'.format(epochs))

    # If a saved model exists, load it and continue training from it.
    start_epoch = _load_checkpoint(model, optimizer, pretrain_model_path)

    for epoch in range(start_epoch, epochs):
        print('the {} epoch_size begin'.format(epoch + 1))
        _train_one_epoch(model, optimizer, loss_fn, train_loader, device, dm)
        accuracy = _evaluate(model, test_loader, device)
        print('accuracy: {:.2f}'.format(accuracy))
        state = {'model': model.state_dict(),
                 'optimizer': optimizer.state_dict(),
                 'epoch': epoch + 1}
        torch.save(state, os.path.join(
            output_path, 'mnist_epoch{}_{:.2f}.pkl'.format(epoch + 1, accuracy)))
    print(os.listdir(output_path))


if __name__ == '__main__':
    main()