首頁 > 科技週邊 > 人工智慧 > 深入分析Pytorch核心要點,CNN解密!

深入分析Pytorch核心要點,CNN解密!

王林
發布: 2024-01-04 19:18:16
轉載
1321 人瀏覽過

哈嘍,我是小壯!

初學者對於創建卷積神經網路(CNN)可能不太熟悉,下面我們以一個完整的案例來進行說明。

CNN是廣泛應用於影像分類、目標偵測、影像生成等任務的深度學習模型。它透過卷積層和池化層自動提取影像的特徵,並透過全連接層進行分類。這種模型的關鍵在於利用捲積和池化的操作,有效地捕捉影像中的局部特徵,並透過多層網路進行組合,從而實現對影像的高級特徵提取和分類。

原理

1.卷積層(Convolutional Layer):

#卷積層透過卷積操作來提取輸入影像中的特徵。這個操作涉及一個可學習的捲積核,它在輸入影像上滑動併計算滑動視窗下的點積。這個過程有助於提取局部特徵,從而增強網路對平移不變性的感知能力。

公式:

突破Pytorch核心点,CNN !!!

其中,x是輸入,w是卷積核,b是偏移。

2.池化層(Pooling Layer):

池化層是一種常用的降維技術,其作用是減少資料的空間維度,從而降低計算量,並提取出最顯著的特徵。其中,最大池化是一種常見的池化方式,它會在每個視窗中選擇最大的值作為代表。透過最大池化,我們可以在保留重要資訊的同時,減少資料的複雜度,提高模型的運算效率。

公式(最大池化):

突破Pytorch核心点,CNN !!!

3.全連接層(Fully Connected Layer):

#全連接層在神經網路中扮演著重要的角色,它將捲積和池化層提取的特徵映射連接到輸出類別。全連接層的每個神經元都與前一層的所有神經元相連,這樣可以實現特徵的綜合和分類。

實戰步驟與詳解

1.步驟

  • 匯入必要的函式庫和模組。
  • 定義網路結構:使用nn.Module定義一個繼承自它的自訂神經網路類,定義卷積層、激活函數、池化層和全連接層。
  • 定義損失函數和最佳化器。
  • 載入和預處理資料。
  • 訓練網路:使用訓練資料迭代訓練網路參數。
  • 測試網路:使用測試資料評估模型效能。

2.程式碼實作

import torchimport torch.nn as nnimport torch.optim as optimfrom torchvision import datasets, transforms# 定义卷积神经网络类class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()# 卷积层1self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)self.relu = nn.ReLU()self.pool = nn.MaxPool2d(kernel_size=2, stride=2)# 卷积层2self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)# 全连接层self.fc1 = nn.Linear(32 * 7 * 7, 10)# 输入大小根据数据调整def forward(self, x):x = self.conv1(x)x = self.relu(x)x = self.pool(x)x = self.conv2(x)x = self.relu(x)x = self.pool(x)x = x.view(-1, 32 * 7 * 7)x = self.fc1(x)return x# 定义损失函数和优化器net = SimpleCNN()criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.001)# 加载和预处理数据transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)# 训练网络num_epochs = 5for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):optimizer.zero_grad()outputs = net(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()if (i+1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item()}')# 测试网络net.eval()with torch.no_grad():correct = 0total = 0for images, labels in test_loader:outputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = correct / totalprint('Accuracy on the test set: {}%'.format(100 * accuracy))
登入後複製

這個範例展示了一個簡單的CNN模型,使用MNIST資料集進行訓練和測試。

接下來,咱們加入視覺化步驟,更直觀地了解模型的表現和訓練過程。

視覺化

1.導入matplotlib

import matplotlib.pyplot as plt
登入後複製

2.在訓練過程中記錄損失和準確率:

在訓練循環中,記錄每個epoch的損失和準確率。

# 在训练循环中添加以下代码train_loss_list = []accuracy_list = []for epoch in range(num_epochs):running_loss = 0.0correct = 0total = 0for i, (images, labels) in enumerate(train_loader):optimizer.zero_grad()outputs = net(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()if (i+1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item()}')epoch_loss = running_loss / len(train_loader)accuracy = correct / totaltrain_loss_list.append(epoch_loss)accuracy_list.append(accuracy)
登入後複製

3.視覺化損失和準確率:

# 在训练循环后,添加以下代码plt.figure(figsize=(12, 4))# 可视化损失plt.subplot(1, 2, 1)plt.plot(range(1, num_epochs + 1), train_loss_list, label='Training Loss')plt.title('Training Loss')plt.xlabel('Epochs')plt.ylabel('Loss')plt.legend()# 可视化准确率plt.subplot(1, 2, 2)plt.plot(range(1, num_epochs + 1), accuracy_list, label='Accuracy')plt.title('Accuracy')plt.xlabel('Epochs')plt.ylabel('Accuracy')plt.legend()plt.tight_layout()plt.show()
登入後複製

這樣,咱們就可以在訓練過程結束後看到訓練損失和準確率的變化。

匯入程式碼後,大家可以依照需求調整視覺化的內容和格式。

以上是深入分析Pytorch核心要點,CNN解密!的詳細內容。更多資訊請關注PHP中文網其他相關文章!

相關標籤:
來源:51cto.com
本網站聲明
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn
熱門教學
更多>
最新下載
更多>
網站特效
網站源碼
網站素材
前端模板