#!/usr/bin/python
# coding=utf-8
'''
If there are Chinese comments in the code, add the following two lines at the
top of the file:
#!/usr/bin/python
# coding=utf-8

In the training environment:
(1) The code is automatically placed in the /tmp/code directory.
(2) The uploaded dataset is automatically placed in the /tmp/dataset directory.
    Note: the paths differ depending on whether a single dataset or multiple
    datasets are selected.
    - Single dataset: if MnistDataset_torch.zip is selected, the dataset
      directories are /tmp/dataset/train and /tmp/dataset/test.
      Dataset structure in the training image for a single dataset:
        tmp
        └── dataset
            ├── test
            └── train
    - Multiple datasets: if MnistDataset_torch.zip and checkpoint_epoch1_0.73.zip
      are selected, the dataset directories are
      /tmp/dataset/MnistDataset_torch/train, /tmp/dataset/MnistDataset_torch/test,
      and /tmp/dataset/checkpoint_epoch1_0.73/mnist_epoch1_0.73.pkl.
      Dataset structure in the training image for multiple datasets:
        tmp
        └── dataset
            ├── MnistDataset_torch
            │   ├── test
            │   └── train
            └── checkpoint_epoch1_0.73
                └── mnist_epoch1_0.73.pkl
(3) The model output path defaults to /tmp/output. Write model files to
    /tmp/output; the qizhi platform serves file downloads from that directory.
(4) If a pre-trained model file is selected, its path is passed in through the
    ckpt_url argument.
'''
from model import Model
import numpy as np
import torch
from torchvision.datasets import mnist
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import argparse
import os

# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
# The dataset is placed under /tmp/dataset
parser.add_argument('--traindata', default="/tmp/dataset/train", help='path to train dataset')
parser.add_argument('--testdata', default="/tmp/dataset/test", help='path to test dataset')
parser.add_argument('--epoch_size', type=int, default=10, help='how many epochs to train')
parser.add_argument('--batch_size', type=int, default=256, help='batch size per step')
# Name of the selected pre-trained model file
parser.add_argument('--ckpt_url', default="", help='pretrain model path')

# Parameter declarations
WORKERS = 0  # number of dataloader worker processes
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Model().to(device)
optimizer = SGD(model.parameters(), lr=1e-1)
cost = CrossEntropyLoss()

# Model training
def train(model, train_loader, epoch):
    model.train()
    train_loss = 0
    for i, data in enumerate(train_loader, 0):
        x, y = data
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        y_hat = model(x)
        loss = cost(y_hat, y)
        loss.backward()
        optimizer.step()
        # .item() detaches the loss so the running sum does not retain the graph
        train_loss += loss.item()
    loss_mean = train_loss / (i + 1)
    print('Train Epoch: {}\t Loss: {:.6f}'.format(epoch, loss_mean))

# Model evaluation
def test(model, test_loader, test_data):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for i, data in enumerate(test_loader, 0):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            y_hat = model(x)
            test_loss += cost(y_hat, y).item()
            pred = y_hat.max(1, keepdim=True)[1]
            correct += pred.eq(y.view_as(pred)).sum().item()
    test_loss /= len(test_loader)
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_data), 100. * correct / len(test_data)))
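
# The header above documents two dataset mount layouts: a single dataset at
# /tmp/dataset/{train,test}, or multiple datasets at /tmp/dataset/<name>/... .
# Below is a minimal sketch of how a script could detect which layout is in
# use. The helper name and its default dataset name are illustrative
# assumptions, not part of the qizhi platform API; this script itself relies
# on the --traindata/--testdata arguments instead.
def resolve_data_dirs(base='/tmp/dataset', name='MnistDataset_torch'):
    multi_root = os.path.join(base, name)
    if os.path.isdir(multi_root):
        # Multiple-dataset layout: /tmp/dataset/<name>/{train,test}
        return os.path.join(multi_root, 'train'), os.path.join(multi_root, 'test')
    # Single-dataset layout: /tmp/dataset/{train,test}
    return os.path.join(base, 'train'), os.path.join(base, 'test')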
def main():
    # If a saved checkpoint exists, load it and resume training on top of it
    if os.path.exists(args.ckpt_url):
        checkpoint = torch.load(args.ckpt_url)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        print('Loaded epoch {} weights successfully!'.format(start_epoch))
    else:
        start_epoch = 0
        print('No saved model, training from scratch!')
    os.makedirs('/tmp/output', exist_ok=True)  # ensure the output directory exists
    # Epochs are 1-indexed, so a checkpoint saved at epoch N resumes at N+1
    for epoch in range(start_epoch + 1, epochs + 1):
        train(model, train_loader, epoch)
        test(model, test_loader, test_dataset)
        # Save a checkpoint; the platform serves downloads from /tmp/output
        state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
        torch.save(state, '/tmp/output/mnist_epoch{}.pkl'.format(epoch))

if __name__ == '__main__':
    args, unknown = parser.parse_known_args()
    # Log whether CUDA is available
    print('cuda is available:{}'.format(torch.cuda.is_available()))
    batch_size = args.batch_size
    epochs = args.epoch_size
    train_dataset = mnist.MNIST(root=args.traindata, train=True, transform=ToTensor(), download=False)
    test_dataset = mnist.MNIST(root=args.testdata, train=False, transform=ToTensor(), download=False)
    # Shuffle the training set each epoch; WORKERS controls dataloader processes
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=WORKERS)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=WORKERS)
    main()
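
# Usage sketch: a checkpoint written by main() can later be restored for
# inference, e.g. by passing its path back in via --ckpt_url. The epoch number
# in the file name below is an example, not a guaranteed artifact of a run:
#
#   checkpoint = torch.load('/tmp/output/mnist_epoch10.pkl', map_location=device)
#   model.load_state_dict(checkpoint['model'])
#   model.eval()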