• 技术文章 >后端开发 >Python教程

    Pytorch入门之mnist分类实例

    不言不言2018-04-14 16:00:57原创3556
    这篇文章主要为大家详细介绍了Pytorch入门之mnist分类实例,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

    本文实例为大家分享了Pytorch入门之mnist分类的具体代码,供大家参考,具体内容如下

    #!/usr/bin/env python
    # -*- coding: utf-8 -*-
    __author__ = 'denny'
    __time__ = '2017-9-9 9:03'
    
    import torch
    import torchvision
    from torch.autograd import Variable
    import torch.utils.data.dataloader as Data
    
    train_data = torchvision.datasets.MNIST(
     './mnist', train=True, transform=torchvision.transforms.ToTensor(), download=True
    )
    test_data = torchvision.datasets.MNIST(
     './mnist', train=False, transform=torchvision.transforms.ToTensor()
    )
    print("train_data:", train_data.train_data.size())
    print("train_labels:", train_data.train_labels.size())
    print("test_data:", test_data.test_data.size())
    
    train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
    test_loader = Data.DataLoader(dataset=test_data, batch_size=64)
    
    
    class Net(torch.nn.Module):
     def __init__(self):
     super(Net, self).__init__()
     self.conv1 = torch.nn.Sequential(
      torch.nn.Conv2d(1, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
     self.conv2 = torch.nn.Sequential(
      torch.nn.Conv2d(32, 64, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2)
     )
     self.conv3 = torch.nn.Sequential(
      torch.nn.Conv2d(64, 64, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2)
     )
     self.dense = torch.nn.Sequential(
      torch.nn.Linear(64 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
     )
    
     def forward(self, x):
     conv1_out = self.conv1(x)
     conv2_out = self.conv2(conv1_out)
     conv3_out = self.conv3(conv2_out)
     res = conv3_out.view(conv3_out.size(0), -1)
     out = self.dense(res)
     return out
    
    
    model = Net()
    print(model)
    
    optimizer = torch.optim.Adam(model.parameters())
    loss_func = torch.nn.CrossEntropyLoss()
    
    for epoch in range(10):
     print('epoch {}'.format(epoch + 1))
     # training-----------------------------
     train_loss = 0.
     train_acc = 0.
     for batch_x, batch_y in train_loader:
     batch_x, batch_y = Variable(batch_x), Variable(batch_y)
     out = model(batch_x)
     loss = loss_func(out, batch_y)
     train_loss += loss.data[0]
     pred = torch.max(out, 1)[1]
     train_correct = (pred == batch_y).sum()
     train_acc += train_correct.data[0]
     optimizer.zero_grad()
     loss.backward()
     optimizer.step()
     print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
     train_data)), train_acc / (len(train_data))))
    
     # evaluation--------------------------------
     model.eval()
     eval_loss = 0.
     eval_acc = 0.
     for batch_x, batch_y in test_loader:
     batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
     out = model(batch_x)
     loss = loss_func(out, batch_y)
     eval_loss += loss.data[0]
     pred = torch.max(out, 1)[1]
     num_correct = (pred == batch_y).sum()
     eval_acc += num_correct.data[0]
     print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
     test_data)), eval_acc / (len(test_data))))

    相关推荐:

    python如何读取二进制mnist实例详解

    一篇不错的Python入门教程_python

    以上就是Pytorch入门之mnist分类实例的详细内容,更多请关注php中文网其它相关文章!

    声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn核实处理。
    专题推荐:Pytorch mnist 实例
    上一篇:Python简单实现控制电脑的方法 下一篇:Python简单计算文件MD5值的方法示例
    VIP课程(WEB全栈开发)

    相关文章推荐

    • 【腾讯云】年中优惠,「专享618元」优惠券!• Python自动化实践之筛选简历• 图文详解Python冒泡排序算法• Python 3.11中的最佳新功能和功能修复• Python接口自动化测试必备基础之http协议详解• 归纳总结Python函数进阶的使用方法
    1/1

    PHP中文网