Python: Classifying the Fashion-MNIST Dataset with PyTorch (78)

Classifying the Fashion-MNIST Dataset

Now it's your turn to build a neural network. You'll be working with the Fashion-MNIST dataset, a drop-in replacement for MNIST. The original MNIST dataset is fairly trivial for neural networks: you can easily reach accuracies above 97%. Fashion-MNIST, by contrast, is a set of 28x28 grayscale images of clothing. It's more complex than MNIST, so it gives you a better measure of your network's performance and is closer to the kinds of datasets you'll encounter in the real world.

[Image: sample Fashion-MNIST clothing images]

In this notebook, you'll build your own neural network. For the most part, you could simply copy and paste the code from Part 3, but you wouldn't learn much that way. It's important that you write the code yourself and get it to run; still, feel free to consult the previous notebooks as you work through this one.

First, let's load the dataset through torchvision.

import torch
from torchvision import datasets, transforms
import helper

# Define a transform to normalize the data
# (Fashion-MNIST images are single-channel, so Normalize takes one mean and one std)
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])
# Download and load the training data
trainset = datasets.FashionMNIST('F_MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Download and load the test data
testset = datasets.FashionMNIST('F_MNIST_data/', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)
print('trainloader-', trainloader)

print('trainloader-iter-', iter(trainloader))
trainloader- <torch.utils.data.dataloader.DataLoader object at 0x7f4816527c50>
trainloader-iter- <torch.utils.data.dataloader._DataLoaderIter object at 0x7f4818715710>
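Before moving on, it can help to sanity-check what was just loaded. The sketch below is my own addition; the class-name ordering follows the published Fashion-MNIST label mapping.

# Quick sanity check of the loaded datasets (illustrative addition)
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

print('Training examples:', len(trainset))   # 60000
print('Test examples:', len(testset))        # 10000
print('Label 0 means:', class_names[0])      # T-shirt/top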

Here we can see one of the images.

image, label = next(iter(trainloader))

print('image0=', image[0,:])
print('label=', label)
image0= tensor([[[-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
         ...,
         [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
         [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000]]])
(full 1x28x28 tensor output truncated; after Normalize the pixel values lie in [-1, 1])
label= tensor([ 7,  9,  9,  3,  6,  4,  8,  8,  7,  5,  4,  7,  5,  0,
         7,  7,  2,  6,  8,  0,  2,  0,  2,  9,  7,  4,  3,  2,
         8,  6,  9,  5,  5,  8,  3,  2,  6,  1,  0,  1,  9,  3,
         6,  6,  3,  2,  9,  7,  3,  6,  9,  5,  2,  5,  6,  2,
         5,  4,  7,  5,  8,  2,  4,  8])
# show images
helper.imshow(image[0,:]);

[Image: the first image in the batch, rendered by helper.imshow]
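If the course-provided helper module isn't available, a minimal matplotlib equivalent (my own sketch) displays the same image:

import matplotlib.pyplot as plt

# Drop the channel dimension and show the 28x28 grayscale image
plt.imshow(image[0].numpy().squeeze(), cmap='Greys_r')
plt.axis('off')
plt.show()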

With the data loaded, let's import the rest of the packages we'll need.

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt
import numpy as np
import time

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms

import helper

Building the network

Here you should define your network. As with MNIST, each image is 28x28 pixels, so there are 784 input pixels and 10 output classes. You should include at least one hidden layer. We suggest using ReLU activations for the hidden layers and returning the logits from the forward pass. It's up to you how many layers you add and how large they are.

# TODO: Define your network architecture here
class Network(nn.Module):
    def __init__(self):
        # Initialize the parent class (nn.Module)
        super().__init__()
        # Define the hidden layers: 200 and 50 units
        self.fc1 = nn.Linear(784, 200)
        self.fc2 = nn.Linear(200, 50)

        # Output layer, 10 units - one for each class
        self.fc3 = nn.Linear(50, 10)

    def forward(self, x):
        ''' Forward pass through the network, returns the output logits '''

        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)

        return x

    def predict(self, x):
        ''' Predict class probabilities by applying softmax to the logits '''
        logits = self.forward(x)
        return F.softmax(logits, dim=1)
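As a quick check that the architecture is wired up correctly, the small sketch below (my own addition) pushes a random batch through the untrained network and confirms the output shapes:

# Sanity check (illustrative): a random batch of 64 flattened "images"
model = Network()
dummy = torch.randn(64, 784)

logits = model(dummy)
probs = model.predict(dummy)

print(logits.shape)          # torch.Size([64, 10])
print(probs.shape)           # torch.Size([64, 10])
print(probs.sum(dim=1)[:3])  # each row sums to 1 after softmax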

Training the network

Now you should create your network and train it. First you'll need to define the criterion (something like nn.CrossEntropyLoss) and the optimizer (such as optim.SGD or optim.Adam).

Then write your training code. Remember that a training pass is a fairly straightforward process:

  • Make a forward pass through the network to get the logits
  • Use the logits to calculate the loss
  • Make a backward pass through the network with loss.backward() to calculate the gradients
  • Take a step with the optimizer to update the weights

By adjusting the hyperparameters (hidden units, learning rate, and so on), you should be able to get the training loss below 0.4.

# TODO: Create the network, define the criterion and optimizer
net = Network()
criterion = nn.CrossEntropyLoss()
# lr = learning rate
optimizer = optim.SGD(net.parameters(), lr=0.01)
# TODO: Train the network here
print('Initial weights - ', net.fc1.weight)

dataiter = iter(trainloader)
images, labels = next(dataiter)

# Flatten the batch of images into 784-long vectors
inputs = images.view(images.shape[0], -1)

# Clear any gradients accumulated on the parameters
optimizer.zero_grad()

# Forward pass, then backward pass, then update the weights
output = net(inputs)
loss = criterion(output, labels)
loss.backward()
print('Gradient -', net.fc1.weight.grad)
optimizer.step()
Initial weights -  Parameter containing:
tensor([[-6.5417e-03, -2.4802e-02, -1.5146e-02,  ..., -1.2971e-02,
         -1.6070e-02,  2.6286e-02],
        [ 5.7576e-03, -1.0399e-02,  4.4446e-03,  ..., -2.5082e-03,
         -4.3815e-03,  1.6680e-02],
        [-3.4878e-02, -2.0759e-02,  1.6003e-02,  ...,  8.6531e-04,
         -1.9558e-02,  2.1282e-02],
        ...,
        [ 2.6933e-02,  2.3026e-02, -3.3443e-02,  ..., -2.4827e-02,
         -3.5710e-02, -6.9900e-03],
        [ 3.5409e-03, -3.0244e-02,  9.5727e-03,  ..., -1.0316e-02,
         -1.9417e-02,  1.2862e-04],
        [-3.3738e-03, -3.0613e-02,  1.6543e-02,  ...,  1.7032e-02,
          2.3136e-02, -1.5136e-02]])
Gradient - tensor([[ 1.9819e-03,  1.9818e-03,  1.9818e-03,  ...,  1.9833e-03,
          1.9834e-03,  1.9818e-03],
        [-1.0067e-03, -1.0061e-03, -1.0094e-03,  ..., -1.0024e-03,
         -1.0067e-03, -1.0061e-03],
        [-1.8131e-03, -1.8103e-03, -1.8119e-03,  ..., -1.7803e-03,
         -1.8102e-03, -1.8103e-03],
        ...,
        [ 5.4580e-04,  5.4580e-04,  5.4289e-04,  ...,  5.5021e-04,
          5.4945e-04,  5.4580e-04],
        [ 1.3846e-04,  1.3887e-04,  1.3865e-04,  ...,  1.3844e-04,
          1.3593e-04,  1.3887e-04],
        [-7.1335e-04, -7.1335e-04, -7.1335e-04,  ..., -7.0985e-04,
         -7.1335e-04, -7.1335e-04]])
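To confirm that optimizer.step() actually changed the parameters, you can print the weights once more after the update (this extra check is my own addition):

# The weights should now differ slightly from the initial values printed above
print('Updated weights - ', net.fc1.weight)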
# The actual training loop
epochs = 5
steps = 0
running_loss = 0
print_every = 40

for e in range(epochs):
    for images, labels in trainloader:
        steps += 1
        # Flatten the images into a batch of 784-long vectors
        inputs = images.view(images.shape[0], -1)

        optimizer.zero_grad()

        # Forward pass, compute the loss, backward pass, then update the weights
        output = net(inputs)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if steps % print_every == 0:
            # Measure accuracy on the test set; gradients aren't needed here
            accuracy = 0
            with torch.no_grad():
                for ii, (images, labels) in enumerate(testloader):
                    inputs = images.view(images.shape[0], -1)

                    predicted = net.predict(inputs)
                    equality = (labels == predicted.max(1)[1])
                    accuracy += equality.float().mean().item()

            print("Epoch: {}/{}".format(e+1, epochs),
                  "Loss: {:.4f}".format(running_loss/print_every),
                  "Test accuracy: {:.4f}".format(accuracy/(ii+1)))
            running_loss = 0

Epoch: 1/5 Loss: 2.2696 Test accuracy: 0.2491
Epoch: 1/5 Loss: 2.1425 Test accuracy: 0.4081
Epoch: 1/5 Loss: 1.9469 Test accuracy: 0.4537
Epoch: 1/5 Loss: 1.7148 Test accuracy: 0.5521
Epoch: 1/5 Loss: 1.4649 Test accuracy: 0.6200
Epoch: 1/5 Loss: 1.2773 Test accuracy: 0.6135
Epoch: 1/5 Loss: 1.1315 Test accuracy: 0.6715
Epoch: 1/5 Loss: 1.0100 Test accuracy: 0.6924
Epoch: 1/5 Loss: 0.9619 Test accuracy: 0.7131
Epoch: 1/5 Loss: 0.8913 Test accuracy: 0.7143
Epoch: 1/5 Loss: 0.8496 Test accuracy: 0.7067
Epoch: 1/5 Loss: 0.8156 Test accuracy: 0.7199
Epoch: 1/5 Loss: 0.7611 Test accuracy: 0.7322
Epoch: 1/5 Loss: 0.7401 Test accuracy: 0.7388
Epoch: 1/5 Loss: 0.7371 Test accuracy: 0.7433
Epoch: 1/5 Loss: 0.7028 Test accuracy: 0.7266
Epoch: 1/5 Loss: 0.7179 Test accuracy: 0.7473
Epoch: 1/5 Loss: 0.6788 Test accuracy: 0.7420
Epoch: 1/5 Loss: 0.6824 Test accuracy: 0.7507
Epoch: 1/5 Loss: 0.6545 Test accuracy: 0.7603
Epoch: 1/5 Loss: 0.6479 Test accuracy: 0.7602
Epoch: 1/5 Loss: 0.6444 Test accuracy: 0.7579
Epoch: 1/5 Loss: 0.6262 Test accuracy: 0.7652
Epoch: 2/5 Loss: 0.6556 Test accuracy: 0.7607
Epoch: 2/5 Loss: 0.6199 Test accuracy: 0.7681
Epoch: 2/5 Loss: 0.5855 Test accuracy: 0.7698
Epoch: 2/5 Loss: 0.6072 Test accuracy: 0.7705
Epoch: 2/5 Loss: 0.5928 Test accuracy: 0.7747
Epoch: 2/5 Loss: 0.6003 Test accuracy: 0.7717
Epoch: 2/5 Loss: 0.5904 Test accuracy: 0.7787
Epoch: 2/5 Loss: 0.5780 Test accuracy: 0.7787
Epoch: 2/5 Loss: 0.5626 Test accuracy: 0.7831
Epoch: 2/5 Loss: 0.5845 Test accuracy: 0.7889
Epoch: 2/5 Loss: 0.5900 Test accuracy: 0.7892
Epoch: 2/5 Loss: 0.5319 Test accuracy: 0.7854
Epoch: 2/5 Loss: 0.5623 Test accuracy: 0.7902
Epoch: 2/5 Loss: 0.5561 Test accuracy: 0.7937
Epoch: 2/5 Loss: 0.5558 Test accuracy: 0.7892
Epoch: 2/5 Loss: 0.5508 Test accuracy: 0.7948
Epoch: 2/5 Loss: 0.5395 Test accuracy: 0.7956
Epoch: 2/5 Loss: 0.5332 Test accuracy: 0.7911
Epoch: 2/5 Loss: 0.5422 Test accuracy: 0.8009
Epoch: 2/5 Loss: 0.5310 Test accuracy: 0.8002
Epoch: 2/5 Loss: 0.5016 Test accuracy: 0.8025
Epoch: 2/5 Loss: 0.5284 Test accuracy: 0.8026
Epoch: 2/5 Loss: 0.5133 Test accuracy: 0.7999
Epoch: 3/5 Loss: 0.5076 Test accuracy: 0.8068
Epoch: 3/5 Loss: 0.4983 Test accuracy: 0.8050
Epoch: 3/5 Loss: 0.4966 Test accuracy: 0.8049
Epoch: 3/5 Loss: 0.4893 Test accuracy: 0.7986
Epoch: 3/5 Loss: 0.5068 Test accuracy: 0.8085
Epoch: 3/5 Loss: 0.5018 Test accuracy: 0.8138
Epoch: 3/5 Loss: 0.5217 Test accuracy: 0.8073
Epoch: 3/5 Loss: 0.5191 Test accuracy: 0.8148
Epoch: 3/5 Loss: 0.5063 Test accuracy: 0.8132
Epoch: 3/5 Loss: 0.4712 Test accuracy: 0.8124
Epoch: 3/5 Loss: 0.4820 Test accuracy: 0.8118
Epoch: 3/5 Loss: 0.4893 Test accuracy: 0.8147
Epoch: 3/5 Loss: 0.5146 Test accuracy: 0.8163
Epoch: 3/5 Loss: 0.5124 Test accuracy: 0.8161
Epoch: 3/5 Loss: 0.4974 Test accuracy: 0.8133
Epoch: 3/5 Loss: 0.5093 Test accuracy: 0.8194
Epoch: 3/5 Loss: 0.4760 Test accuracy: 0.8176
Epoch: 3/5 Loss: 0.4960 Test accuracy: 0.8195
Epoch: 3/5 Loss: 0.4649 Test accuracy: 0.8154
Epoch: 3/5 Loss: 0.4778 Test accuracy: 0.8169
Epoch: 3/5 Loss: 0.5091 Test accuracy: 0.8137
Epoch: 3/5 Loss: 0.4302 Test accuracy: 0.8176
Epoch: 3/5 Loss: 0.4675 Test accuracy: 0.8230
Epoch: 3/5 Loss: 0.4825 Test accuracy: 0.8195
Epoch: 4/5 Loss: 0.4620 Test accuracy: 0.8243
Epoch: 4/5 Loss: 0.4741 Test accuracy: 0.8249
Epoch: 4/5 Loss: 0.4340 Test accuracy: 0.8180
Epoch: 4/5 Loss: 0.4573 Test accuracy: 0.8262
Epoch: 4/5 Loss: 0.4699 Test accuracy: 0.8250
Epoch: 4/5 Loss: 0.4633 Test accuracy: 0.8204
Epoch: 4/5 Loss: 0.4966 Test accuracy: 0.8244
Epoch: 4/5 Loss: 0.4854 Test accuracy: 0.8264
Epoch: 4/5 Loss: 0.4844 Test accuracy: 0.8253
Epoch: 4/5 Loss: 0.4672 Test accuracy: 0.8208
Epoch: 4/5 Loss: 0.4508 Test accuracy: 0.8267
Epoch: 4/5 Loss: 0.4514 Test accuracy: 0.8280
Epoch: 4/5 Loss: 0.4458 Test accuracy: 0.8267
Epoch: 4/5 Loss: 0.4318 Test accuracy: 0.8271
Epoch: 4/5 Loss: 0.4639 Test accuracy: 0.8262
Epoch: 4/5 Loss: 0.4509 Test accuracy: 0.8305
Epoch: 4/5 Loss: 0.4320 Test accuracy: 0.8266
Epoch: 4/5 Loss: 0.4579 Test accuracy: 0.8284
Epoch: 4/5 Loss: 0.4521 Test accuracy: 0.8237
Epoch: 4/5 Loss: 0.4405 Test accuracy: 0.8318
Epoch: 4/5 Loss: 0.4559 Test accuracy: 0.8295
Epoch: 4/5 Loss: 0.4785 Test accuracy: 0.8279
Epoch: 4/5 Loss: 0.4291 Test accuracy: 0.8318
Epoch: 5/5 Loss: 0.4580 Test accuracy: 0.8288
Epoch: 5/5 Loss: 0.4441 Test accuracy: 0.8292
Epoch: 5/5 Loss: 0.4358 Test accuracy: 0.8358
Epoch: 5/5 Loss: 0.4435 Test accuracy: 0.8337
Epoch: 5/5 Loss: 0.4557 Test accuracy: 0.8332
Epoch: 5/5 Loss: 0.4531 Test accuracy: 0.8322
Epoch: 5/5 Loss: 0.4062 Test accuracy: 0.8346
Epoch: 5/5 Loss: 0.4480 Test accuracy: 0.8331
Epoch: 5/5 Loss: 0.4449 Test accuracy: 0.8346
Epoch: 5/5 Loss: 0.4486 Test accuracy: 0.8310
Epoch: 5/5 Loss: 0.4481 Test accuracy: 0.8369
Epoch: 5/5 Loss: 0.4624 Test accuracy: 0.8359
Epoch: 5/5 Loss: 0.4464 Test accuracy: 0.8340
Epoch: 5/5 Loss: 0.4372 Test accuracy: 0.8350
Epoch: 5/5 Loss: 0.4079 Test accuracy: 0.8349
Epoch: 5/5 Loss: 0.3984 Test accuracy: 0.8368
Epoch: 5/5 Loss: 0.4247 Test accuracy: 0.8350
Epoch: 5/5 Loss: 0.4390 Test accuracy: 0.8332
Epoch: 5/5 Loss: 0.4108 Test accuracy: 0.8367
Epoch: 5/5 Loss: 0.4279 Test accuracy: 0.8362
Epoch: 5/5 Loss: 0.4078 Test accuracy: 0.8381
Epoch: 5/5 Loss: 0.4241 Test accuracy: 0.8380
Epoch: 5/5 Loss: 0.4210 Test accuracy: 0.8366
Epoch: 5/5 Loss: 0.4117 Test accuracy: 0.8317
# Test out your network!

dataiter = iter(testloader)
images, labels = next(dataiter)
img = images[0]
# Convert the 2D image to a 1D vector
img = img.view(1, 784)

# TODO: Calculate the class probabilities (softmax) for img
with torch.no_grad():
    ps = net.predict(img)

# Plot the image and probabilities
helper.view_classify(img.view(1, 28, 28), ps)

[Image: the test image alongside its predicted class probabilities]

Once your network is trained, you'll want to save it so you can load it later instead of training it again. Obviously, it's impractical to retrain a network every time you need it. In practice, you'll save the model after training, then load it back for further training or for making predictions. In the next part, I'll show you how to save and load trained models.
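As a quick preview, here is a minimal sketch of the usual approach with torch.save and load_state_dict (the filename checkpoint.pth is just an illustrative choice):

# Save only the learned parameters (the state dict), not the whole object
torch.save(net.state_dict(), 'checkpoint.pth')

# Later: rebuild the same architecture, then load the saved parameters
net2 = Network()
net2.load_state_dict(torch.load('checkpoint.pth'))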

Those who act will often succeed; those who keep walking will often arrive.