前回のあらすじ
学習データを書き換えてうまくいくかと思ったが、そんなに甘くなかった。
やったこと
なぜ、出力がすべて1になったのか検討した。
まずはシグモイド関数を使ってるのに、出力に1以上の値を使っている点を直すべきだろう。
シグモイド関数は任意の実数を入力に、0から1の値を出力する。
出力の値を1000で割って1以下にしてみよう。
で、次が出力結果。
result0 input : 20,150,104.00 output : 0.25 answer : 0.25 result1 input : 20,150,106.00 output : 0.25 answer : 0.25 result2 input : 20,150,108.00 output : 0.25 answer : 0.25 result3 input : 20,150,108.00 output : 0.25 answer : 0.25
小数点以下2ケタまでしかなぜか出なくて詳しくはわからんが1以外の値が出た。
うまくいってるのかこれ。
どうやらgetDoubleやgetFloatでちゃんと出る模様。
result0 input : 2.0150104E7 output : 0.24874994158744812 answer : 0.25119999051094055 result1 input : 2.0150106E7 output : 0.24874994158744812 answer : 0.2467000037431717 result2 input : 2.0150108E7 output : 0.24874994158744812 answer : 0.24690000712871552 result3 input : 2.0150108E7 output : 0.24874994158744812 answer : 0.2502000033855438
毎回同じ数字やんけ。
学習データの出力を少し変えて、出力に変化がないか見てみる。
INDArray tIn = Nd4j.create( new double[]{ 20150105 , // 入力1 20150106 , // 入力2 20150107 , // 入力3 20150108 }, // 入力4 new int[]{ 4 , 1 } ); // サイズ INDArray tOut = Nd4j.create( new double[]{ 0.1512 , 0.8467 , 0.5469 , 0.7502} , // 出力1~4
result0 input : 2.0150104E7 output : 0.6056897640228271 answer : 0.15119999647140503 result1 input : 2.0150106E7 output : 0.6056897640228271 answer : 0.8467000126838684 result2 input : 2.0150108E7 output : 0.6056897640228271 answer : 0.5468999743461609 result3 input : 2.0150108E7 output : 0.6056897640228271 answer : 0.7501999735832214
同じやんけ。
入力値の違いが微々たるもの過ぎて、同じ数値だと勘違いしてるんかな?
営業日換算で1からカウントアップしてみよう。
INDArray tIn = Nd4j.create( new double[]{ 1 , // 入力1 2 , // 入力2 3 , // 入力3 4 }, // 入力4 new int[]{ 4 , 1 } ); // サイズ INDArray tOut = Nd4j.create( new double[]{ 0.2512 , 0.2467 , 0.2469 , 0.2502} , // 出力1~4
result0 input : 1.0 output : 0.2483229637145996 answer : 0.25119999051094055 result1 input : 2.0 output : 0.24831531941890717 answer : 0.2467000037431717 result2 input : 3.0 output : 0.24835054576396942 answer : 0.24690000712871552 result3 input : 4.0 output : 0.24836908280849457 answer : 0.2502000033855438
お、若干だけど数値が変わった。
変化が微妙すぎて違いが判らなかったということか?
ちなみに、学習回数を10倍の20000回にした結果。
result0 input : 1.0 output : 0.24907267093658447 answer : 0.25119999051094055 result1 input : 2.0 output : 0.24896696209907532 answer : 0.2467000037431717 result2 input : 3.0 output : 0.2489119917154312 answer : 0.24690000712871552 result3 input : 4.0 output : 0.2488856166601181 answer : 0.2502000033855438
それなりになってきたのかな?
学習回数こそパワーなのだろうか。
それって、脳みそ筋肉で良いと思います。
csvを読み込んでみる
それらしく動き出したので、扱うデータを増やしてみる。
csvを読み込むには次にようにするとよいらしい。
(手元の技術書から早くも仕様が変わった模様…)
// csvファイルから入力 RecordReader recordReader = new CSVRecordReader(0, ","); // 0行目から,をデミリタにして読む recordReader.initialize(new FileSplit(new ClassPathResource("./9887_2015.csv").getFile())); // ファイルパスからcsvファイルを取得して初期化 int labelIndex = 0; // 1カラム目が入力 int numClasses = 244; // データ種類? int batchSize = 50; // バッチサイズ DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex , numClasses); DataSet csvData = iterator.next(); System.out.println("label: " + csvData.getLabels().toString()); System.out.println("features: " + csvData.getFeatures().toString());
実行結果
label: [[0.00, 1.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00], ~長いので略~ features: [2,512.00, 2,467.00, 2,469.00, 2,502.00, 2,540.00, 2,542.00, 2,528.00, 2,539.00, 2,507.00, 2,509.00, 2,507.00, 2,507.00, 2,500.00, 2,457.00, 2,388.00, 2,420.00, 2,446.00, 2,471.00, 2,471.00, 2,472.00, 2,441.00, 2,467.00, 2,441.00, 2,448.00, 2,451.00, 2,462.00, 2,472.00, 2,450.00, 2,456.00, 2,458.00, 2,460.00, 2,495.00, 2,499.00, 2,499.00, 2,514.00, 2,504.00, 2,515.00, 2,500.00, 2,491.00, 2,460.00, 2,452.00, 2,454.00, 2,447.00, 2,453.00, 2,451.00, 2,453.00, 2,457.00, 2,453.00, 2,450.00, 2,443.00]
何かは読み込まれているが、これで合っているのだろうか…
labelIndexには入力の列数を入れるらしい。値は0から始まる。
numClassesには、データの分類数を指定せよとのことなので、今回はcsvの行数(244)を指定した。
batchSizeには、バッチサイズを指定。データがあまりにも大きい場合はバッチ処理としていくつかの区切りで学習を区切るらしい。
この値は適当。
データをさしかえて実行。
Exception in thread "main" org.deeplearning4j.exception.DL4JInvalidInputException: Labels array numColumns (size(1) = 244) does not match output layer number of outputs (nOut = 1)
あっ、あっ… 次回へ続く。
次回の予定
・サンプルプログラムをいじる
・courseraの動画を見る
・Deep Learning Javaプログラミングを読む
・プログラミングのための線形代数を読む
・Mavenのpomの読み方