Rumah > pembangunan bahagian belakang > Tutorial Python > Rincian untuk Membina CNN Setara Biasa

Rincian untuk Membina CNN Setara Biasa

王林
Lepaskan: 2024-07-18 11:29:18
asal
1145 orang telah melayarinya

Satu prinsip hanya dinyatakan sebagai 'Biarkan kernel berputar' dan kami akan menumpukan dalam artikel ini tentang cara anda boleh menerapkannya dalam seni bina anda.

Seni bina setara membolehkan kami melatih model yang tidak peduli dengan tindakan kumpulan tertentu.

Untuk memahami maksud ini sebenarnya, mari kita latih model CNN mudah ini pada set data MNIST (set data digit tulisan tangan dari 0-9).

class SimpleCNN(nn.Module):

    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.cl1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
        self.max_1 = nn.MaxPool2d(kernel_size=2)
        self.cl2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
        self.max_2 = nn.MaxPool2d(kernel_size=2)
        self.cl3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=7)
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor):
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)
        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)
        x = nn.functional.silu(self.cl3(x))
        x = x.view(len(x), -1)
        logits = self.dense(x)
        return logits
Salin selepas log masuk
Accuracy on test Accuracy on 90-degree rotated test
97.3% 15.1%

Jadual 1: Uji ketepatan model SimpleCNN

Seperti yang dijangkakan, kami mendapat lebih 95% ketepatan pada set data ujian, tetapi bagaimana jika kami memutarkan imej sebanyak 90 darjah? Tanpa sebarang tindakan balas yang dikenakan, keputusan menurun kepada hanya lebih baik sedikit daripada meneka. Model ini tidak berguna untuk aplikasi umum.

Sebaliknya, mari kita latih seni bina setara yang serupa dengan bilangan parameter yang sama, di mana tindakan kumpulan adalah tepat putaran 90 darjah.

Accuracy on test Accuracy on 90-degree rotated test
96.5% 96.5%

Jadual 2: Uji ketepatan model EqCNN dengan jumlah parameter yang sama seperti model SimpleCNN

Ketepatan tetap sama dan kami tidak memilih untuk menambah data.

Model ini menjadi lebih mengagumkan dengan data 3D, tetapi kami akan tetap menggunakan contoh ini untuk meneroka idea teras.

Sekiranya anda ingin mengujinya sendiri, anda boleh mengakses semua kod yang ditulis dalam kedua-dua PyTorch dan JAX secara percuma di bawah Github-Repo, dan latihan dengan Docker atau Podman boleh dilakukan dengan hanya dua arahan.

Selamat mencuba!

Jadi Apakah Kesetaraan?

Senibina setara menjamin kestabilan ciri di bawah tindakan kumpulan tertentu. Kumpulan ialah struktur ringkas di mana elemen kumpulan boleh digabungkan, diterbalikkan atau tidak melakukan apa-apa.

Anda boleh mencari definisi rasmi di Wikipedia jika anda berminat.

Untuk tujuan kami, anda boleh memikirkan sekumpulan putaran 90 darjah yang bertindak pada imej segi empat sama. Kita boleh memutarkan imej sebanyak 90, 180, 270 atau 360 darjah. Untuk membalikkan tindakan, kami menggunakan putaran 270, 180, 90 atau 0 darjah masing-masing. Adalah mudah untuk melihat bahawa kita boleh menggabungkan, membalikkan atau melakukan apa-apa dengan kumpulan yang dilambangkan sebagai C4C_4C4 . Imej menggambarkan semua tindakan pada imej.

Figure 1: Rotated MNIST image by 90°, 180°, 270°, 360°, respectively
Rajah 1: Imej MNIST diputar sebanyak 90°, 180°, 270°, 360°, masing-masing

Now, given an input image xxx , our CNN model classifier fθf_\thetafθ , and an arbitrary 90-degree rotation ggg , the equivariant property can be expressed as
fθ(rotate x by g)=fθ(x) f_\theta(\text{rotate } x \text{ by } g) = f_\theta(x) fθ(rotate x by g)=fθ(x)

Generally speaking, we want our image-based model to have the same outputs when rotated.

As such, equivariant models promise us architectures with baked-in symmetries. In the following section, we will see how our principle can be applied to achieve this property.

How to Make Our CNN Equivariant

The problem is the following: When the image rotates, the features rotate too. But as already hinted, we could also compute the features for each rotation upfront by rotating the kernel.
We could actually rotate the kernel, but it is much easier to rotate the feature map itself, thus avoiding interference with PyTorch's autodifferentiation algorithm altogether.

So, in code, our CNN kernel

x = nn.functional.silu(self.cl1(x))
Salin selepas log masuk

now acts on all four rotated images:

x_0 = x
x_90 = torch.rot90(x, k=1, dims=(2, 3))
x_180 = torch.rot90(x, k=2, dims=(2, 3))
x_270 = torch.rot90(x, k=3, dims=(2, 3))

x_0 = nn.functional.silu(self.cl1(x_0))
x_90 = nn.functional.silu(self.cl1(x_90))
x_180 = nn.functional.silu(self.cl1(x_180))
x_270 = nn.functional.silu(self.cl1(x_270))
Salin selepas log masuk

Or more compactly written as a 3D convolution:

self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
...
x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
x = nn.functional.silu(self.cl1(x))
Salin selepas log masuk

The resulting equivariant model has just a few lines more compared to the version above:

class EqCNN(nn.Module):

    def __init__(self):
        super(EqCNN, self).__init__()
        self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
        self.max_1 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.cl2 = nn.Conv3d(in_channels=8, out_channels=16, kernel_size=(1, 3, 3))
        self.max_2 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.cl3 = nn.Conv3d(in_channels=16, out_channels=16, kernel_size=(1, 5, 5))
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor):
        x_0 = x
        x_90 = torch.rot90(x, k=1, dims=(2, 3))
        x_180 = torch.rot90(x, k=2, dims=(2, 3))
        x_270 = torch.rot90(x, k=3, dims=(2, 3))

        x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)

        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)

        x = nn.functional.silu(self.cl3(x))

        x = x.squeeze()
        x = torch.max(x, dim=-1).values
        logits = self.dense(x)
        return logits
Salin selepas log masuk

But why is this equivariant to rotations?
First, observe that we get four copies of each feature map at each stage. At the end of the pipeline, we combine all of them with a max operation.

This is key, the max operation is indifferent to which place the rotated version of the feature ends up in.

To understand what is happening, let us plot the feature maps after the first convolution stage.

Figure 2: Feature maps for all four rotations
Figure 2: Feature maps for all four rotations

And now the same features after we rotate the input by 90 degrees.

Figure 3: Feature maps for all four rotations after the input image was rotated
Rajah 3: Peta ciri untuk keempat-empat putaran selepas imej input diputar

Saya mengekod warna peta yang sepadan. Setiap peta ciri dianjakkan oleh satu. Apabila operator maks akhir mengira hasil yang sama untuk peta ciri yang dialih ini, kami memperoleh hasil yang sama.

Dalam kod saya, saya tidak berputar kembali selepas lilitan terakhir, kerana kernel saya memekatkan imej kepada tatasusunan satu dimensi. Jika anda ingin mengembangkan contoh ini, anda perlu mengambil kira fakta ini.

Perakaunan untuk tindakan kumpulan atau "putaran kernel" memainkan peranan penting dalam reka bentuk seni bina yang lebih canggih.

Adakah ia Makan Tengah Hari Percuma?

Tidak, kami membayar dalam kelajuan pengiraan, bias induktif dan pelaksanaan yang lebih kompleks.

Perkara terakhir agak diselesaikan dengan perpustakaan seperti E3NN, di mana kebanyakan matematik berat diabstrakkan. Walau bagaimanapun, seseorang perlu mengambil kira banyak semasa reka bentuk seni bina.

Satu kelemahan cetek ialah 4x kos pengiraan untuk mengira semua lapisan ciri yang diputar. Walau bagaimanapun, perkakasan moden dengan paralelisasi jisim boleh mengatasi beban ini dengan mudah. Sebaliknya, melatih CNN mudah dengan penambahan data dengan mudah akan melebihi 10x dalam masa latihan. Ini menjadi lebih teruk lagi untuk putaran 3D yang mana penambahan data memerlukan kira-kira 500x jumlah latihan untuk mengimbangi semua putaran yang mungkin.

Secara keseluruhannya, reka bentuk model kesetaraan lebih kerap daripada bukan harga yang patut dibayar jika seseorang mahukan ciri yang stabil.

Apakah Seterusnya?

Reka bentuk model setara telah meletup dalam beberapa tahun kebelakangan ini, dan dalam artikel ini, kami hampir tidak mencalarkan permukaan. Malah, kami tidak mengeksploitasi sepenuhnya C4C_4C4 kumpulan lagi. Kami boleh menggunakan kernel 3D penuh. Walau bagaimanapun, model kami sudah mencapai ketepatan lebih 95%, jadi tiada sebab untuk pergi lebih jauh dengan contoh ini.

Selain CNN, penyelidik telah berjaya menterjemahkan prinsip ini kepada kumpulan berterusan, termasuk SO(2) JADI(2)JADI(2) (kumpulan semua putaran dalam satah) dan SE(3) SE(3)SE(3) (kumpulan semua terjemahan dan putaran dalam ruang 3D).

Menurut pengalaman saya, model ini benar-benar mengagumkan dan mencapai prestasi, apabila dilatih dari awal, setanding dengan prestasi model asas yang dilatih pada set data berbilang kali ganda lebih besar.

Beri tahu saya jika anda mahu saya menulis lebih lanjut mengenai topik ini.

Rujukan Lanjut

Sekiranya anda mahukan pengenalan rasmi kepada topik ini, berikut ialah kompilasi kertas kerja yang sangat baik, merangkumi sejarah lengkap kesetaraan dalam Pembelajaran Mesin.
AEN

Saya sebenarnya bercadang untuk membuat tutorial mendalam dan praktikal mengenai topik ini. Anda sudah boleh mendaftar untuk senarai mel saya dan saya akan memberikan anda versi percuma dari semasa ke semasa, bersama-sama saluran terus untuk maklum balas dan Soal Jawab.

Jumpa lagi :)

Atas ialah kandungan terperinci Rincian untuk Membina CNN Setara Biasa. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!

sumber:dev.to
Kenyataan Laman Web ini
Kandungan artikel ini disumbangkan secara sukarela oleh netizen, dan hak cipta adalah milik pengarang asal. Laman web ini tidak memikul tanggungjawab undang-undang yang sepadan. Jika anda menemui sebarang kandungan yang disyaki plagiarisme atau pelanggaran, sila hubungi admin@php.cn
Tutorial Popular
Lagi>
Muat turun terkini
Lagi>
kesan web
Kod sumber laman web
Bahan laman web
Templat hujung hadapan