機械学習開発に Go 言語を使用するにはどうすればよいですか?

WBOY
リリース: 2023-06-10 11:19:37
オリジナル
1234 人が閲覧しました

さまざまな分野で機械学習が広く応用されるにつれ、プログラマーは機械学習モデルを迅速かつ効果的に開発する方法にますます注目を集めています。 Python や R などの従来の機械学習言語は、機械学習の分野の標準ツールとなっていますが、Go 言語の同時実行性とパフォーマンスに魅了されるプログラマーが増えています。この記事では、機械学習開発に Go 言語を使用する方法について説明します。

  1. Go のインストール

まず、オペレーティング システムに Go をインストールする必要があります。 Go公式サイトからインストーラーをダウンロードしてインストールできます。インストールが完了したら、コマンド ラインで go version コマンドを実行して、Go が正しくインストールされているかどうかを確認します。

  1. 機械学習ライブラリをインストールする

Go には組み込みの機械学習ライブラリはありませんが、tensorflow、ゴルゴニア、ゴムルなど。ここでは、gorgonia を例として、Go を機械学習に使用する方法を紹介します。

コマンド ラインで次のコマンドを実行して、gorgonia をインストールします:

go get gorgonia.org/gorgonia
ログイン後にコピー

インストールが完了したら、次のコマンドを実行して、正しくインストールされているかどうかを確認できます:

package main

import "gorgonia.org/gorgonia"

func main() {
    gorgonia.NewGraph()
}
ログイン後にコピー

エラーが報告されない場合は、gorgonia が正常にインストールされたことを説明してください。

  1. Gorgonia の使用

次に、Gorgonia を使用して、手書き数字の画像を分類するための基本的なニューラル ネットワークを構築します。まず、データを準備する必要があります。 gorgenia には mnist データセットをダウンロードして解凍するために使用できる mnist パッケージがあります。

package main

import (
    "fmt"
    "gorgonia.org/datasets/mnist"
    "gorgonia.org/gorgonia"
)

func main() {
    // 下载和解压缩 mnist 数据集
    trainData, testData, err := mnist.Load(root)
    if err != nil {
        panic(err)
    }

    // 打印训练和测试数据及标签的形状
    fmt.Printf("train data shape: %v
", trainData.X.Shape())
    fmt.Printf("train labels shape: %v
", trainData.Y.Shape())
    fmt.Printf("test data shape: %v
", testData.X.Shape()) 
    fmt.Printf("test labels shape: %v
", testData.Y.Shape())
}
ログイン後にコピー

出力結果は次のとおりです。

train data shape: (60000, 28, 28, 1)
train labels shape: (60000, 10)
test data shape: (10000, 28, 28, 1)
test labels shape: (10000, 10)
ログイン後にコピー

トレーニング データには 60,000 個の 28x28 グレースケール画像が含まれており、テスト データには同じ形状の 10,000 個の画像が含まれています。各ラベルは、画像が属する番号を表す 10 次元のベクトルです。

次に、ニューラル ネットワークのアーキテクチャを定義します。 2 つの隠れ層を持つディープ ニューラル ネットワークを使用します。各隠れ層には 128 個のニューロンがあります。出力層で relu 活性化関数と Softmax 活性化関数を使用して画像を分類します。

dataShape := trainData.X.Shape()
dataSize := dataShape[0]
inputSize := dataShape[1] * dataShape[2] * dataShape[3]
outputSize := testData.Y.Shape()[1]

// 构建神经网络
g := gorgonia.NewGraph()
x := gorgonia.NewTensor(g, tensor.Float32, 4, gorgonia.WithShape(dataSize, dataShape[1], dataShape[2], dataShape[3]), gorgonia.WithName("x"))
y := gorgonia.NewMatrix(g, tensor.Float32, gorgonia.WithShape(dataSize, outputSize), gorgonia.WithName("y"))

hiddenSize := 128
hidden1 := gorgonia.Must(gorgonia.NodeFromAny(g, tensor.Zero(tensor.Float32, hiddenSize), gorgonia.WithName("hidden1")))
hidden2 := gorgonia.Must(gorgonia.NodeFromAny(g, tensor.Zero(tensor.Float32, hiddenSize), gorgonia.WithName("hidden2")))

w1 := gorgonia.NewMatrix(g, tensor.Float32, gorgonia.WithShape(inputSize, hiddenSize), gorgonia.WithName("w1"))
w2 := gorgonia.NewMatrix(g, tensor.Float32, gorgonia.WithShape(hiddenSize, hiddenSize), gorgonia.WithName("w2"))
w3 := gorgonia.NewMatrix(g, tensor.Float32, gorgonia.WithShape(hiddenSize, outputSize), gorgonia.WithName("w3"))

b1 := gorgonia.NewVector(g, tensor.Float32, gorgonia.WithShape(hiddenSize), gorgonia.WithName("b1"))
b2 := gorgonia.NewVector(g, tensor.Float32, gorgonia.WithShape(hiddenSize), gorgonia.WithName("b2"))
b3 := gorgonia.NewVector(g, tensor.Float32, gorgonia.WithShape(outputSize), gorgonia.WithName("b3"))

hidden1Dot, err1 := gorgonia.Mul(x, w1)
hidden1Add, err2 := gorgonia.BroadcastAdd(hidden1Dot, b1, []byte{0})
hidden1Activate := gorgonia.Must(gorgonia.Rectify(hidden1Add))

hidden2Dot, err3 := gorgonia.Mul(hidden1Activate, w2)
hidden2Add, err4 := gorgonia.BroadcastAdd(hidden2Dot, b2, []byte{0})
hidden2Activate := gorgonia.Must(gorgonia.Rectify(hidden2Add))

yDot, err5 := gorgonia.Mul(hidden2Activate, w3)
yAdd, err6 := gorgonia.BroadcastAdd(yDot, b3, []byte{0})
ySoftMax := gorgonia.Must(gorgonia.SoftMax(yAdd))
ログイン後にコピー

確率的勾配降下法 (SGD) 法を使用してモデルをトレーニングします。各エポックでは、トレーニング データをバッチに分割し、勾配を計算し、各バッチのパラメーターを更新します。

iterations := 10
batchSize := 32
learningRate := 0.01

// 定义代价函数(交叉熵)
cost := gorgonia.Must(gorgonia.Mean(gorgonia.Must(gorgonia.Neg(gorgonia.Must(gorgonia.HadamardProd(y, gorgonia.Must(gorgonia.Log(ySoftMax)))))))

// 定义优化器
optimizer := gorgonia.NewVanillaSolver(g, gorgonia.WithLearnRate(learningRate))

// 表示模型将进行训练
vm := gorgonia.NewTapeMachine(g)

// 进行训练
for i := 0; i < iterations; i++ {
    fmt.Printf("Epoch %d
", i+1)

    for j := 0; j < dataSize; j += batchSize {
        upperBound := j + batchSize
        if upperBound > dataSize {
            upperBound = dataSize
        }
        xBatch := trainData.X.Slice(s{j, upperBound})
        yBatch := trainData.Y.Slice(s{j, upperBound})

        if err := gorgonia.Let(x, xBatch); err != nil {
            panic(err)
        }
        if err := gorgonia.Let(y, yBatch); err != nil {
            panic(err)
        }

        if err := vm.RunAll(); err != nil {
            panic(err)
        }

        if err := optimizer.Step(gorgonia.NodesToValueGrads(w1, b1, w2, b2, w3, b3)); err != nil {
            panic(err)
        }
    }

    // 测试准确率
    xTest := testData.X
    yTest := testData.Y

    if err := gorgonia.Let(x, xTest); err != nil {
        panic(err)
    }
    if err := gorgonia.Let(y, yTest); err != nil {
        panic(err)
    }

    if err := vm.RunAll(); err != nil {
        panic(err)
    }

    predict := gorgonia.Must(gorgonia.Argmax(ySoftMax, 1))
    label := gorgonia.Must(gorgonia.Argmax(yTest, 1))

    correct := 0
    for i := range label.Data().([]float32) {
        if predict.Data().([]float32)[i] == label.Data().([]float32)[i] {
            correct++
        }
    }

    fmt.Printf("Accuracy: %v
", float32(correct)/float32(len(label.Data().([]float32))))
}
ログイン後にコピー

簡単な機械学習モデルの開発が完了しました。隠れ層を追加したり、さまざまなオプティマイザーを使用したりするなど、ニーズに応じて拡張および最適化できます。

  1. 概要

この記事では、機械学習開発に Go 言語を使用する方法について説明し、gorgonia と mnist データ セットを例として取り上げ、その方法を示しました。手書き数字の画像を分類するための基本的なニューラル ネットワークを構築します。 Go は機械学習の分野で選択される言語ではないかもしれませんが、同時実行性とパフォーマンスに優れた利点があり、シナリオによっては良い選択となる可能性があります。

以上が機械学習開発に Go 言語を使用するにはどうすればよいですか?の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

ソース:php.cn
このウェブサイトの声明
この記事の内容はネチズンが自主的に寄稿したものであり、著作権は原著者に帰属します。このサイトは、それに相当する法的責任を負いません。盗作または侵害の疑いのあるコンテンツを見つけた場合は、admin@php.cn までご連絡ください。
最新の問題
人気のチュートリアル
詳細>
最新のダウンロード
詳細>
ウェブエフェクト
公式サイト
サイト素材
フロントエンドテンプレート
私たちについて 免責事項 Sitemap
PHP中国語ウェブサイト:福祉オンライン PHP トレーニング,PHP 学習者の迅速な成長を支援します!