3 Commits

1 changed files with 112 additions and 0 deletions
Split View
  1. +112
    -0
      inference_for_c2net.py

+ 112
- 0
inference_for_c2net.py View File

@@ -0,0 +1,112 @@
#!/usr/bin/python
#coding=utf-8
'''
If there are Chinese comments in the code,please add at the beginning:
#!/usr/bin/python
#coding=utf-8

In the inference environment,
(1)the code will be automatically placed in the /tmp/code directory,
(2)the uploaded dataset will be automatically placed in the /tmp/dataset directory
Note: the paths are different when selecting a single dataset and multiple datasets.
(1)If it is a single dataset: if MnistDataset_torch.zip is selected,
the dataset directory is /tmp/dataset/train, /dataset/test;

The dataset structure of the single dataset in the training image in this example:
tmp
├──dataset
├── test
└── train

If multiple datasets are selected, such as MnistDataset_torch.zip and checkpoint_epoch1_0.73.zip,
the dataset directory is /tmp/dataset/MnistDataset_torch/train, /tmp/dataset/MnistDataset_torch/test
and /tmp/dataset/checkpoint_epoch1_0.73/mnist_epoch1_0.73.pkl
The dataset structure in the training image for multiple datasets in this example:
tmp
├──dataset
├── MnistDataset_torch
| ├── test
| └── train
└── checkpoint_epoch1_0.73
├── mnist_epoch1_0.73.pkl
(3)the ouput download path is under /tmp/output by default, please specify the output location to /tmp/output,
openi platform will provide file downloads under the /tmp/output directory.
(4)If the pre-training model file is selected, the selected pre-training model will be
automatically placed in the /tmp/pretrainmodel directory.
for example:
If the model file is selected, the calling method is: '/pretrainmodel/' + args.pretrainmodelname

In addition, if you want to get the model file after each training, you can call the uploader_for_gpu tool,
which is written as:
import os
os.system("cd /tmp/script_for_grampus/ &&./uploader_for_gpu " + "/tmp/output/")
'''


from model import Model
import numpy as np
import torch
from torchvision.datasets import mnist
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import argparse
import os

# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
#The dataset location is placed under /dataset
parser.add_argument('--traindata', default="/tmp/dataset/train" ,help='path to train dataset')
parser.add_argument('--testdata', default="/tmp/dataset/test" ,help='path to test dataset')
parser.add_argument('--epoch_size', type=int, default=10, help='how much epoch to train')
parser.add_argument('--batch_size', type=int, default=256, help='how much batch_size in epoch')
#获取模型文件名称
parser.add_argument('--ckpt_url', default="", help='pretrain model path')

# 参数声明
WORKERS = 0 # dataloder线程数
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Model().to(device)
optimizer = SGD(model.parameters(), lr=1e-1)
cost = CrossEntropyLoss()

# 模型测试
def test(model, test_loader, data_length):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for i, data in enumerate(test_loader, 0):
x, y = data
x = x.to(device)
y = y.to(device)
y_hat = model(x)
test_loss += cost(y_hat, y).item()
pred = y_hat.max(1, keepdim=True)[1]
correct += pred.eq(y.view_as(pred)).sum().item()
test_loss /= (i+1)

# 结果写入输出文件夹
filename = 'result.txt'
file_path = os.path.join('/tmp/output', filename)
with open(file_path, 'w') as file:
file.write('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, data_length, 100. * correct / data_length))


if __name__ == '__main__':
args, unknown = parser.parse_known_args()
#log output
print('cuda is available:{}'.format(torch.cuda.is_available()))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = args.batch_size
epochs = args.epoch_size
test_dataset = mnist.MNIST(root=args.testdata, train=False, transform=ToTensor(),download=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
model = Model().to(device)
checkpoint = torch.load(args.ckpt_url)
model.load_state_dict(checkpoint['model'])
test(model,test_loader,len(test_dataset))
os.system("cd /tmp/script_for_grampus/ &&./uploader_for_gpu " + "/tmp/output/")

Loading…
Cancel
Save
Baidu
map