読者です 読者をやめる 読者になる 読者になる

takeda_san’s blog

JavaFXとDeeplearning4jを頑張る方向。

パーセプトロンプログラムの理解 その3

機械学習 Deeplearning4j Java

前回のあらすじ

学習データの生成方法部分を読んだ。

takeda-san.hatenablog.com

やったこと

なんとも本丸感がある、このワンライナー。 落ち着いてひとずつ確認していきます。

        // ニューラルネットワークを定義
        MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(iterations)
                .learningRate(0.01)
                .weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater( Updater.NONE )
                .list()
                .layer(0, new OutputLayer.Builder( LossFunctions.LossFunction.MSE )
                        .nIn(inputNum)
                        .nOut(outputNum)
                        .activation("sigmoid")
                        .build())
                .backprop(true).pretrain(false);

        // ニューラルネットワークを作成
        MultiLayerConfiguration conf        = builder.build();
        MultiLayerNetwork       perceptron  = new MultiLayerNetwork(conf);
        perceptron.init();

大きな流れとしては以下のようになっているようだ。

  1. NeuralNetConfiguration.Builder()でニューラルネットワークの設定し、buildメソッドで設定オブジェクトを生成
  2. 1.で作った設定を使用してモデルのインスタンスを生成
  3. 2.で作ったモデルを初期化する

MultiLayerConfiguration.Builderは、モデルの全体的な構造を定義するビルダークラス。
NeuralNetConfiguration.Builder()は、ニューラルネットワークのビルダーメソッド。
seed()は、乱数の種を指定するメソッド。
んー、なぜ乱数を生成する必要があるんだろうか…
別途調べよう。

深層学習(ディープラーニング)は自然現象であると考えると捉えやすい - WirelessWire News(ワイヤレスワイヤーニュース)

iterations()は、繰り返し回数を指定するメソッド。今回は1000回学習する。
learningRate()は、学習率を指定するメソッド。今回の学習率は0.01。
学習率とはモデルのパラメータを更新する幅らしい。
これも別途調べる。

weightInit()は、重みを指定するメソッド。
重みとは個々の神経細胞同士の繋がりの強さ(伝達物質の伝導しやすさ)を示す。らしい
これも…別途調べる。

optimizationAlgo()は、学習アルゴリズムを指定するメソッド。
確率的勾配降下法を使用する。
これも…・・・別途調べる。

updater()は、学習率の最適化など学習アルゴリズムの更新に使う。今回はNONEなので特に更新なし。
list()は、ニューラルネットワークの層の数を指定する。今回は1層なので指定なし。
 →入力層は数えない
layer()は、ニューラルネットワークのレイヤーを指定する。
 →活性化関数にシグモイド関数、誤差関数にMSE、入力が2つ、出力1つ(ORなので)
backprop()は、誤差逆伝搬の有無を指定するメソッド。
pretrain()は、事前学習の有無を指定するメソッド。

一通り技術書&JavaDocとにらめっこして調べてましたけど、下記記事に全部乗ってましたね。
ま、まぁ自分で調べることに意味があるから…

Java DeepLearning4j パラメータの設定|軽Lab

次回の予定

・不明点の確認(乱数を指定する意味、学習率とは、重みとは、確率的勾配降下法とは)
・プログラミングのための線形代数を読む
・サンプルプログラム(Perceptron/LenetMnistExample)を読む
Mavenのpomの読み方