ホームページ > バックエンド開発 > Python チュートリアル > 通常の等変 CNN を構築するための原則

通常の等変 CNN を構築するための原則

王林
リリース: 2024-07-18 11:29:18
オリジナル
1145 人が閲覧しました

その 1 つの原則は単に「カーネルを回転させる」と述べられており、この記事ではそれをアーキテクチャにどのように適用できるかに焦点を当てます。

等変アーキテクチャにより、特定のグループアクションに無関心なモデルをトレーニングできます。

これが正確に何を意味するのかを理解するために、この単純な CNN モデルを MNIST データセット (0 ~ 9 の手書き数字のデータセット) でトレーニングしてみましょう。

class SimpleCNN(nn.Module):

    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.cl1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
        self.max_1 = nn.MaxPool2d(kernel_size=2)
        self.cl2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
        self.max_2 = nn.MaxPool2d(kernel_size=2)
        self.cl3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=7)
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor):
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)
        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)
        x = nn.functional.silu(self.cl3(x))
        x = x.view(len(x), -1)
        logits = self.dense(x)
        return logits
ログイン後にコピー
Accuracy on test Accuracy on 90-degree rotated test
97.3% 15.1%

表 1: SimpleCNN モデルのテスト精度

予想通り、テスト データセットでは 95% 以上の精度が得られましたが、画像を 90 度回転したらどうなるでしょうか?何も対策を適用しない場合、結果は推測よりもわずかに優れたものになります。このモデルは一般的なアプリケーションには役に立ちません。

対照的に、グループアクションが正確に 90 度回転する、同じ数のパラメーターを使用して同様の等変アーキテクチャをトレーニングしてみましょう。

Accuracy on test Accuracy on 90-degree rotated test
96.5% 96.5%

表 2: SimpleCNN モデルと同じ量のパラメーターを使用した EqCNN モデルのテスト精度

精度は同じままであり、データ拡張も選択しませんでした。

これらのモデルは 3D データを使用するとさらに印象的になりますが、核となるアイデアを探るためにこの例にこだわります。

自分でテストしてみたい場合は、Github-Repo から PyTorch と JAX の両方で書かれたすべてのコードに無料でアクセスでき、たった 2 つのコマンドで Docker または Podman を使用したトレーニングが可能です。

楽しんでください!

では、等分散とは何でしょうか?

等変アーキテクチャは、特定のグループアクションの下で機能の安定性を保証します。グループは単純な構造であり、グループ要素を結合したり、逆にしたり、何も行わなかったりすることができます。

興味があれば、Wikipedia で正式な定義を調べることができます。

私たちの目的としては、正方形の画像に作用する 90 度の回転のグループを考えることができます。画像を 90、180、270、または 360 度回転できます。アクションを逆にするには、それぞれ 270、180、90、または 0 度の回転を適用します。 で示されるグループを結合したり、逆にしたり、何もしないことができることは簡単にわかります。 C4C_4C4 。画像は、画像上のすべてのアクションを視覚化します。

Figure 1: Rotated MNIST image by 90°, 180°, 270°, 360°, respectively
図 1: それぞれ 90°、180°、270°、360° 回転させた MNIST 画像

Now, given an input image xxx , our CNN model classifier fθf_\thetafθ , and an arbitrary 90-degree rotation ggg , the equivariant property can be expressed as
fθ(rotate x by g)=fθ(x) f_\theta(\text{rotate } x \text{ by } g) = f_\theta(x) fθ(rotate x by g)=fθ(x)

Generally speaking, we want our image-based model to have the same outputs when rotated.

As such, equivariant models promise us architectures with baked-in symmetries. In the following section, we will see how our principle can be applied to achieve this property.

How to Make Our CNN Equivariant

The problem is the following: When the image rotates, the features rotate too. But as already hinted, we could also compute the features for each rotation upfront by rotating the kernel.
We could actually rotate the kernel, but it is much easier to rotate the feature map itself, thus avoiding interference with PyTorch's autodifferentiation algorithm altogether.

So, in code, our CNN kernel

x = nn.functional.silu(self.cl1(x))
ログイン後にコピー

now acts on all four rotated images:

x_0 = x
x_90 = torch.rot90(x, k=1, dims=(2, 3))
x_180 = torch.rot90(x, k=2, dims=(2, 3))
x_270 = torch.rot90(x, k=3, dims=(2, 3))

x_0 = nn.functional.silu(self.cl1(x_0))
x_90 = nn.functional.silu(self.cl1(x_90))
x_180 = nn.functional.silu(self.cl1(x_180))
x_270 = nn.functional.silu(self.cl1(x_270))
ログイン後にコピー

Or more compactly written as a 3D convolution:

self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
...
x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
x = nn.functional.silu(self.cl1(x))
ログイン後にコピー

The resulting equivariant model has just a few lines more compared to the version above:

class EqCNN(nn.Module):

    def __init__(self):
        super(EqCNN, self).__init__()
        self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
        self.max_1 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.cl2 = nn.Conv3d(in_channels=8, out_channels=16, kernel_size=(1, 3, 3))
        self.max_2 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.cl3 = nn.Conv3d(in_channels=16, out_channels=16, kernel_size=(1, 5, 5))
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor):
        x_0 = x
        x_90 = torch.rot90(x, k=1, dims=(2, 3))
        x_180 = torch.rot90(x, k=2, dims=(2, 3))
        x_270 = torch.rot90(x, k=3, dims=(2, 3))

        x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)

        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)

        x = nn.functional.silu(self.cl3(x))

        x = x.squeeze()
        x = torch.max(x, dim=-1).values
        logits = self.dense(x)
        return logits
ログイン後にコピー

But why is this equivariant to rotations?
First, observe that we get four copies of each feature map at each stage. At the end of the pipeline, we combine all of them with a max operation.

This is key, the max operation is indifferent to which place the rotated version of the feature ends up in.

To understand what is happening, let us plot the feature maps after the first convolution stage.

Figure 2: Feature maps for all four rotations
Figure 2: Feature maps for all four rotations

And now the same features after we rotate the input by 90 degrees.

Figure 3: Feature maps for all four rotations after the input image was rotated
図 3: 入力画像を回転した後の 4 つの回転すべての特徴マップ

対応するマップを色分けしました。各特徴マップは 1 つずつシフトされます。最後の max 演算子はこれらのシフトされた特徴マップに対して同じ結果を計算するため、同じ結果が得られます。

私のコードでは、カーネルが画像を 1 次元配列に圧縮するため、最後の畳み込み後に回転を戻しませんでした。この例をさらに詳しく説明したい場合は、この事実を考慮する必要があります。

グループ アクションまたは「カーネル ローテーション」を考慮することは、より洗練されたアーキテクチャの設計において重要な役割を果たします。

フリーランチですか?

いいえ、計算速度、帰納的バイアス、より複雑な実装で代償を支払います。

後者の点は、E3NN などのライブラリを使用するとある程度解決され、複雑な数学の大部分が抽象化されます。それにもかかわらず、アーキテクチャ設計時には多くのことを考慮する必要があります。

表面的な弱点の 1 つは、回転されたすべてのフィーチャ レイヤーを計算するのに 4 倍の計算コストがかかることです。ただし、大規模並列化を備えた最新のハードウェアは、この負荷に簡単に対処できます。対照的に、データ拡張を使用して単純な CNN をトレーニングすると、トレーニング時間はゆうに 10 倍を超えます。これは、すべての可能な回転を補正するためにデータ拡張に約 500 倍のトレーニング量が必要となる 3D 回転ではさらに悪化します。

全体として、安定した機能が必要な場合、等分散モデルの設計は支払う価値のある代償を払うことが多いです。

次は何ですか?

等価モデルの設計は近年爆発的に増加していますが、この記事では表面をなぞっただけです。実際、私たちはそのすべてを活用することさえできませんでした。 C4C_4C4 グループはまだ。完全な 3D カーネルを使用することもできました。ただし、私たちのモデルはすでに 95% 以上の精度を達成しているため、この例をさらに進める理由はほとんどありません。

CNN 以外にも、研究者たちはこれらの原則を以下のような継続的なグループにうまく応用しています。 SO(2) ソ(2)ソ(2) (平面内のすべての回転のグループ) および SE(3) SE(3)SE(3) (3D 空間内のすべての移動と回転のグループ)。

私の経験では、これらのモデルはまったく驚くべきものであり、ゼロからトレーニングした場合、数倍大きなデータセットでトレーニングされた基礎モデルのパフォーマンスに匹敵するパフォーマンスを達成します。

このトピックについてもっと書いてほしい場合は、お知らせください。

さらなる参考文献

このトピックへの正式な紹介が必要な場合は、機械学習における等変性の完全な歴史をカバーする優れた論文の編集版をここに示します。
あえん

私は実際に、このトピックに関する詳細な実践的なチュートリアルを作成する予定です。すでに私のメーリング リストにサインアップできます。フィードバックや Q&A のための直接チャネルとともに、無料バージョンを徐々に提供していきます。

また会いましょう :)

以上が通常の等変 CNN を構築するための原則の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

ソース:dev.to
このウェブサイトの声明
この記事の内容はネチズンが自主的に寄稿したものであり、著作権は原著者に帰属します。このサイトは、それに相当する法的責任を負いません。盗作または侵害の疑いのあるコンテンツを見つけた場合は、admin@php.cn までご連絡ください。
人気のチュートリアル
詳細>
最新のダウンロード
詳細>
ウェブエフェクト
公式サイト
サイト素材
フロントエンドテンプレート