takeda_san’s blog

JavaFXと機械学習を頑張る方向。

サンプルプログラムをいじろう その3

前回のあらすじ

csvからのデータ読み込みができない。

takeda-san.hatenablog.com

やったこと

参考になる情報がないなら公式にいけばいいじゃない。

dl4j-examples/CSVExample.java at master · deeplearning4j/dl4j-examples · GitHub

というわけで、csvを読み込んでいる、サンプルコードを見てみる。

//Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network
int labelIndex = 4;     //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
int numClasses = 3;     //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2
int batchSize = 150;    //Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets)

DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);
DataSet allData = iterator.next();

iris.txtというのはよく使うサンプルデータセットらしい。

https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data

3種類(Iris-setosa、Iris-versicolor、Iris-virginica)のデータが150行格納されている。
サンプルコードと照らし合わせてみよう。

int labelIndex = 4;     //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row

iris.txtには1行あたり5つの値が含まれている。
5つめの値がラベルなので、0開始の5番目の値である4を指定する。

int numClasses = 3;     //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2

iris.txtには3種類の値が含まれる。
この種類のことをクラスと呼ぶらしい。
なのでここの値は3。(これは0から数えないんだ…配列じゃないから?)

int batchSize = 150;    //Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets)

iris.txtには150行のデータが格納されている。
batchSizeには一度に読み込むデータを指定できる。
今回はデータの数が少ないので一度にすべて読み込む。
なので、150。

なるほど、公式のサンプルが一番わかりやすい。
次から各種クラスの使い方に困ったら見にこよう。

つづいて、このルールに沿って前回のプログラムを修正してみる。

int labelIndex = 0;  // 1カラム目がラベル
int numClasses = 244;  //    データの種類は244
int batchSize = 244;  // バッチサイズ。一度にすべて読み込む

あれ?バッチサイズ以外は前回と変わらないような。

Exception in thread "main" org.deeplearning4j.exception.DL4JInvalidInputException: Invalid classification data: expect label value (at label index column = 0) to be in range 0 to 243 inclusive (0 to numClasses-1, with numClasses=244); got label value of 244

うーん。問題は別の場所にありそう。
とりあえず、CSVExample.javaを丸コピーで動かしてみよう。

Exception in thread "main" java.lang.NumberFormatException: For input string: "Iris-setosa"

データは数値型しか受け付けないようだ。
リポジトリをあさってみると、文字列を数値に置換したデータがあった。
3種類の文字列を0~2に置き換えているようだ。
これもなぜか0から始めないといけない模様。

deeplearning4j/iris.dat at master · deeplearning4j/deeplearning4j · GitHub

これにデータを置き換えて、再度実行。

===========INPUT===================
[[5.10, 3.50, 1.40, 0.20],
~長いので略~
 [5.90, 3.00, 5.10, 1.80]]
=================OUTPUT==================
[[1.00, 0.00, 0.00],
~長いので略~
 [0.00, 0.00, 1.00]]

お、いけてるっぽい。
次にこの謎の出力の読み方を考えてみる。
まずは、前回のプログラム上にデータセットを書いていたころの出力。

===========INPUT===================
[1.00, 2.00, 3.00, 4.00]
=================OUTPUT==================
[0.25, 0.25, 0.25, 0.25]

値が丸まっていて出力データが区別がつかないが、入力と出力が対応していることがわかる。
つぎに今回のirisデータセット

===========INPUT===================
[[5.10, 3.50, 1.40, 0.20],
 [4.90, 3.00, 1.40, 0.20],
 [4.70, 3.20, 1.30, 0.20],
 [4.60, 3.10, 1.50, 0.20],
 [5.00, 3.60, 1.40, 0.20],
~略~
=================OUTPUT==================
[[1.00, 0.00, 0.00],
 [1.00, 0.00, 0.00],
 [1.00, 0.00, 0.00],
 [1.00, 0.00, 0.00],
 [1.00, 0.00, 0.00],

この5行と対応するcsvの行が次のとおり。

5.1,3.5,1.4,0.2,0
4.9,3.0,1.4,0.2,0
4.7,3.2,1.3,0.2,0
4.6,3.1,1.5,0.2,0
5.0,3.6,1.4,0.2,0

なるほど、1~4列目までが入力になっていると。
つまりCSVExample.javaでやりたいことは、1~4列目の4つの入力を使って、どのラベルが出力されるかを見ていたということらしい。
となると、今回の指定は入力と出力が逆となっていたようだ。
csvファイルが、1列目に営業日、2列目に株価なのでラベルは2列目ということになる。

ここまでくると、一度コードを全部捨ててCSVExample.javaを元にやり直したほうが良い気がしてきた。 一からの再出発は、また次回。

次回の予定

・サンプルプログラムをいじる
・courseraの動画を見る
・Deep Learning Javaプログラミングを読む
・プログラミングのための線形代数を読む
Mavenのpomの読み方