Penjelasan terperinci model LSTM dalam Python

王林
Lepaskan: 2023-06-10 12:57:24
asal
6054 orang telah melayarinya

LSTM ialah rangkaian neural berulang (RNN) khas yang boleh memproses dan meramal data siri masa. LSTM digunakan secara meluas dalam bidang seperti pemprosesan bahasa semula jadi, analisis audio dan ramalan siri masa. Artikel ini akan memperkenalkan prinsip asas dan butiran pelaksanaan model LSTM, dan cara menggunakan LSTM dalam Python.

1. Prinsip asas LSTM

Model LSTM terdiri daripada unit LSTM Setiap unit LSTM mempunyai tiga get: get input, get forget dan gate output, serta keadaan output. Input LSTM termasuk input pada saat semasa dan keadaan output pada saat sebelumnya. Tiga get dan keadaan keluaran dikira dan dikemas kini seperti berikut:

(1) Gerbang lupa: Kawal keadaan keluaran momen sebelumnya yang akan dilupakan Formula khusus adalah seperti berikut:

$f_t =sigma(W_f[h_{t-1},x_t]+b_f)$

Di mana, $h_{t-1}$ ialah keadaan keluaran momen sebelumnya, $x_t$ ialah input momen semasa , $W_f$ dan $b_f$ ialah pemberat dan pincang bagi get lupa, dan $sigma$ ialah fungsi sigmoid. $f_t$ ialah nilai dari 0 hingga 1, yang menunjukkan keadaan keluaran momen sebelumnya harus dilupakan.

(2) Gerbang input: Kawal input pada saat semasa akan ditambah kepada keadaan output Formula khusus adalah seperti berikut:

$i_t=sigma(W_i[h_{t. -1},x_t] +b_i)$

$ ilde{C_t}= anh(W_C[h_{t-1},x_t]+b_C)$

di mana, $i_t$ adalah dari 0 hingga 1 Nilai, menunjukkan input pada saat semasa harus ditambah pada keadaan output, $ilde{C_t}$ ialah keadaan memori sementara bagi input pada saat semasa.

(3) Keadaan kemas kini: Kira keadaan output dan keadaan sel pada saat semasa berdasarkan get lupa, get input dan keadaan memori sementara Formula khusus adalah seperti berikut:

$ C_t=f_t·C_{t -1}+i_t· ilde{C_t}$

$o_t=sigma(W_o[h_{t-1},x_t]+b_o)$

$h_t=o_t· anh( C_t)$

Di mana, $C_t$ ialah keadaan sel pada saat semasa, $o_t$ ialah nilai dari 0 hingga 1, menunjukkan keadaan sel mana yang sepatutnya dikeluarkan, $ h_t$ ialah keadaan keluaran pada saat semasa dan Nilai fungsi tanh bagi keadaan sel.

2. Butiran pelaksanaan LSTM

Model LSTM mempunyai banyak butiran pelaksanaan, termasuk pemulaan, fungsi kehilangan, pengoptimum, penormalan kelompok, pemberhentian awal, dsb.

(1) Permulaan: Parameter model LSTM perlu dimulakan dan anda boleh menggunakan nombor rawak atau parameter model pra-latihan. Parameter model LSTM termasuk berat dan berat sebelah, serta parameter lain seperti kadar pembelajaran, saiz kelompok dan bilangan lelaran.

(2) Fungsi kehilangan: Model LSTM biasanya menggunakan fungsi kehilangan entropi silang, yang mengukur perbezaan antara output model dan label sebenar.

(3) Pengoptimum: Model LSTM menggunakan kaedah penurunan kecerunan untuk mengoptimumkan fungsi kehilangan yang biasa digunakan termasuk kaedah keturunan kecerunan stokastik (RMSprop) dan pengoptimum Adam.

(4) Normalisasi kelompok: Model LSTM boleh menggunakan teknologi normalisasi kelompok untuk mempercepatkan penumpuan dan meningkatkan prestasi model.

(5) Pemberhentian awal: Model LSTM boleh menggunakan teknologi berhenti awal untuk menghentikan latihan apabila fungsi kehilangan tidak lagi bertambah baik pada set latihan dan set pengesahan untuk mengelakkan pemasangan berlebihan.

3. Pelaksanaan model LSTM dalam Python

Anda boleh menggunakan rangka kerja pembelajaran mendalam seperti Keras atau PyTorch untuk melaksanakan model LSTM dalam Python.

(1) Keras melaksanakan model LSTM

Keras ialah rangka kerja pembelajaran mendalam yang ringkas dan mudah digunakan yang boleh digunakan untuk membina dan melatih model LSTM. Berikut ialah contoh kod yang menggunakan Keras untuk melaksanakan model LSTM:

from keras.models import Sequential
from keras.layers import LSTM, Dense
from keras.utils import np_utils

model = Sequential()
model.add(LSTM(units=128, input_shape=(X.shape[1], X.shape[2]), return_sequences=True))
model.add(LSTM(units=64, return_sequences=True))
model.add(LSTM(units=32))
model.add(Dense(units=y.shape[1], activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam')
model.fit(X_train, y_train, epochs=100, batch_size=256, validation_data=(X_test, y_test))
Salin selepas log masuk

(2) PyTorch melaksanakan model LSTM

PyTorch ialah rangka kerja pembelajaran mendalam untuk graf pengkomputeran dinamik yang boleh digunakan untuk membina dan melatih model LSTM. Berikut ialah contoh kod yang menggunakan PyTorch untuk melaksanakan model LSTM:

import torch
import torch.nn as nn

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        
    def forward(self, x):
        out, _ = self.lstm(x)
        out = self.fc(out[:, -1, :])
        return out

model = LSTM(input_size=X.shape[2], hidden_size=128, output_size=y.shape[1])
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 100
for epoch in range(num_epochs):
    outputs = model(X_train)
    loss = criterion(outputs, y_train.argmax(dim=1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
Salin selepas log masuk

4. Kesimpulan

LSTM ialah model rangkaian saraf berulang yang berkuasa yang boleh memproses dan meramal data siri masa dan secara meluas digunakan. Anda boleh menggunakan rangka kerja pembelajaran mendalam seperti Keras atau PyTorch untuk melaksanakan model LSTM dalam Python Dalam aplikasi praktikal, anda perlu memberi perhatian kepada butiran pelaksanaan seperti pemulaan parameter, fungsi kehilangan, pengoptimum, penormalan kelompok dan pemberhentian awal model.

Atas ialah kandungan terperinci Penjelasan terperinci model LSTM dalam Python. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!

Label berkaitan:
sumber:php.cn
Kenyataan Laman Web ini
Kandungan artikel ini disumbangkan secara sukarela oleh netizen, dan hak cipta adalah milik pengarang asal. Laman web ini tidak memikul tanggungjawab undang-undang yang sepadan. Jika anda menemui sebarang kandungan yang disyaki plagiarisme atau pelanggaran, sila hubungi admin@php.cn
Tutorial Popular
Lagi>
Muat turun terkini
Lagi>
kesan web
Kod sumber laman web
Bahan laman web
Templat hujung hadapan