takeda_san’s blog

JavaFXとDeeplearning4jを頑張る方向。

サンプルプログラムをいじろう その2

前回のあらすじ

学習データを書き換えてうまくいくかと思ったが、そんなに甘くなかった。

takeda-san.hatenablog.com

やったこと

なぜ、出力がすべて1になったのか検討した。
まずはシグモイド関数を使ってるのに、出力に1以上の値を使っている点を直すべきだろう。
シグモイド関数は任意の実数を入力に、0から1の値を出力する。
出力の値を1000で割って1以下にしてみよう。
で、次が出力結果。

result0
 input  : 20,150,104.00
 output : 0.25
 answer : 0.25
result1
 input  : 20,150,106.00
 output : 0.25
 answer : 0.25
result2
 input  : 20,150,108.00
 output : 0.25
 answer : 0.25
result3
 input  : 20,150,108.00
 output : 0.25
 answer : 0.25

小数点以下2ケタまでしかなぜか出なくて詳しくはわからんが1以外の値が出た。
うまくいってるのかこれ。
どうやらgetDoubleやgetFloatでちゃんと出る模様。

result0
 input  : 2.0150104E7
 output : 0.24874994158744812
 answer : 0.25119999051094055
result1
 input  : 2.0150106E7
 output : 0.24874994158744812
 answer : 0.2467000037431717
result2
 input  : 2.0150108E7
 output : 0.24874994158744812
 answer : 0.24690000712871552
result3
 input  : 2.0150108E7
 output : 0.24874994158744812
 answer : 0.2502000033855438

毎回同じ数字やんけ。
学習データの出力を少し変えて、出力に変化がないか見てみる。

        INDArray    tIn     = Nd4j.create( new double[]{ 20150105 ,             // 入力1
                                                           20150106 ,             // 入力2
                                                           20150107 ,             // 入力3
                                                           20150108  },           // 入力4
                                           new int[]{ 4 , 1 } );            // サイズ
        INDArray    tOut    = Nd4j.create( new double[]{ 0.1512 , 0.8467 , 0.5469 , 0.7502} ,    // 出力1~4
result0
 input  : 2.0150104E7
 output : 0.6056897640228271
 answer : 0.15119999647140503
result1
 input  : 2.0150106E7
 output : 0.6056897640228271
 answer : 0.8467000126838684
result2
 input  : 2.0150108E7
 output : 0.6056897640228271
 answer : 0.5468999743461609
result3
 input  : 2.0150108E7
 output : 0.6056897640228271
 answer : 0.7501999735832214

同じやんけ。
入力値の違いが微々たるもの過ぎて、同じ数値だと勘違いしてるんかな?
営業日換算で1からカウントアップしてみよう。

        INDArray    tIn     = Nd4j.create( new double[]{ 1 ,             // 入力1
                                                           2 ,             // 入力2
                                                           3 ,             // 入力3
                                                           4  },           // 入力4
                                           new int[]{ 4 , 1 } );            // サイズ
        INDArray    tOut    = Nd4j.create( new double[]{ 0.2512 , 0.2467 , 0.2469 , 0.2502} ,    // 出力1~4
result0
 input  : 1.0
 output : 0.2483229637145996
 answer : 0.25119999051094055
result1
 input  : 2.0
 output : 0.24831531941890717
 answer : 0.2467000037431717
result2
 input  : 3.0
 output : 0.24835054576396942
 answer : 0.24690000712871552
result3
 input  : 4.0
 output : 0.24836908280849457
 answer : 0.2502000033855438

お、若干だけど数値が変わった。
変化が微妙すぎて違いが判らなかったということか?
ちなみに、学習回数を10倍の20000回にした結果。

result0
 input  : 1.0
 output : 0.24907267093658447
 answer : 0.25119999051094055
result1
 input  : 2.0
 output : 0.24896696209907532
 answer : 0.2467000037431717
result2
 input  : 3.0
 output : 0.2489119917154312
 answer : 0.24690000712871552
result3
 input  : 4.0
 output : 0.2488856166601181
 answer : 0.2502000033855438

それなりになってきたのかな?
学習回数こそパワーなのだろうか。
それって、脳みそ筋肉で良いと思います。

csvを読み込んでみる

それらしく動き出したので、扱うデータを増やしてみる。
csvを読み込むには次にようにするとよいらしい。 (手元の技術書から早くも仕様が変わった模様…)

// csvファイルから入力
RecordReader recordReader = new CSVRecordReader(0, ","); // 0行目から,をデミリタにして読む
recordReader.initialize(new FileSplit(new ClassPathResource("./9887_2015.csv").getFile())); // ファイルパスからcsvファイルを取得して初期化

int labelIndex = 0;  // 1カラム目が入力
int numClasses = 244;  // データ種類?
int batchSize = 50;  // バッチサイズ

DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex , numClasses);
DataSet csvData = iterator.next();

System.out.println("label: " + csvData.getLabels().toString());
System.out.println("features: " + csvData.getFeatures().toString());

実行結果

label: [[0.00, 1.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
~長いので略~
features: [2,512.00, 2,467.00, 2,469.00, 2,502.00, 2,540.00, 2,542.00, 2,528.00, 2,539.00, 2,507.00, 2,509.00, 2,507.00, 2,507.00, 2,500.00, 2,457.00, 2,388.00, 2,420.00, 2,446.00, 2,471.00, 2,471.00, 2,472.00, 2,441.00, 2,467.00, 2,441.00, 2,448.00, 2,451.00, 2,462.00, 2,472.00, 2,450.00, 2,456.00, 2,458.00, 2,460.00, 2,495.00, 2,499.00, 2,499.00, 2,514.00, 2,504.00, 2,515.00, 2,500.00, 2,491.00, 2,460.00, 2,452.00, 2,454.00, 2,447.00, 2,453.00, 2,451.00, 2,453.00, 2,457.00, 2,453.00, 2,450.00, 2,443.00]

何かは読み込まれているが、これで合っているのだろうか…
labelIndexには入力の列数を入れるらしい。値は0から始まる。
numClassesには、データの分類数を指定せよとのことなので、今回はcsvの行数(244)を指定した。
batchSizeには、バッチサイズを指定。データがあまりにも大きい場合はバッチ処理としていくつかの区切りで学習を区切るらしい。
この値は適当。

データをさしかえて実行。

Exception in thread "main" org.deeplearning4j.exception.DL4JInvalidInputException: Labels array numColumns (size(1) = 244) does not match output layer number of outputs (nOut = 1)

あっ、あっ… 次回へ続く。

次回の予定

・サンプルプログラムをいじる
・courseraの動画を見る
・Deep Learning Javaプログラミングを読む
・プログラミングのための線形代数を読む
Mavenのpomの読み方