サンプルプログラムをいじろう その3
前回のあらすじ
csvからのデータ読み込みができない。
やったこと
参考になる情報がないなら公式にいけばいいじゃない。
dl4j-examples/CSVExample.java at master · deeplearning4j/dl4j-examples · GitHub
というわけで、csvを読み込んでいる、サンプルコードを見てみる。
//Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network int labelIndex = 4; //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row int numClasses = 3; //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2 int batchSize = 150; //Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets) DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses); DataSet allData = iterator.next();
iris.txtというのはよく使うサンプルデータセットらしい。
https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
3種類(Iris-setosa、Iris-versicolor、Iris-virginica)のデータが150行格納されている。
サンプルコードと照らし合わせてみよう。
int labelIndex = 4; //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
iris.txtには1行あたり5つの値が含まれている。
5つめの値がラベルなので、0開始の5番目の値である4を指定する。
int numClasses = 3; //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2
iris.txtには3種類の値が含まれる。
この種類のことをクラスと呼ぶらしい。
なのでここの値は3。(これは0から数えないんだ…配列じゃないから?)
int batchSize = 150; //Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets)
iris.txtには150行のデータが格納されている。
batchSizeには一度に読み込むデータを指定できる。
今回はデータの数が少ないので一度にすべて読み込む。
なので、150。
なるほど、公式のサンプルが一番わかりやすい。
次から各種クラスの使い方に困ったら見にこよう。
つづいて、このルールに沿って前回のプログラムを修正してみる。
int labelIndex = 0; // 1カラム目がラベル int numClasses = 244; // データの種類は244 int batchSize = 244; // バッチサイズ。一度にすべて読み込む
あれ?バッチサイズ以外は前回と変わらないような。
Exception in thread "main" org.deeplearning4j.exception.DL4JInvalidInputException: Invalid classification data: expect label value (at label index column = 0) to be in range 0 to 243 inclusive (0 to numClasses-1, with numClasses=244); got label value of 244
うーん。問題は別の場所にありそう。
とりあえず、CSVExample.javaを丸コピーで動かしてみよう。
Exception in thread "main" java.lang.NumberFormatException: For input string: "Iris-setosa"
データは数値型しか受け付けないようだ。
リポジトリをあさってみると、文字列を数値に置換したデータがあった。
3種類の文字列を0~2に置き換えているようだ。
これもなぜか0から始めないといけない模様。
deeplearning4j/iris.dat at master · deeplearning4j/deeplearning4j · GitHub
これにデータを置き換えて、再度実行。
===========INPUT=================== [[5.10, 3.50, 1.40, 0.20], ~長いので略~ [5.90, 3.00, 5.10, 1.80]] =================OUTPUT================== [[1.00, 0.00, 0.00], ~長いので略~ [0.00, 0.00, 1.00]]
お、いけてるっぽい。
次にこの謎の出力の読み方を考えてみる。
まずは、前回のプログラム上にデータセットを書いていたころの出力。
===========INPUT=================== [1.00, 2.00, 3.00, 4.00] =================OUTPUT================== [0.25, 0.25, 0.25, 0.25]
値が丸まっていて出力データが区別がつかないが、入力と出力が対応していることがわかる。
つぎに今回のirisデータセット。
===========INPUT=================== [[5.10, 3.50, 1.40, 0.20], [4.90, 3.00, 1.40, 0.20], [4.70, 3.20, 1.30, 0.20], [4.60, 3.10, 1.50, 0.20], [5.00, 3.60, 1.40, 0.20], ~略~ =================OUTPUT================== [[1.00, 0.00, 0.00], [1.00, 0.00, 0.00], [1.00, 0.00, 0.00], [1.00, 0.00, 0.00], [1.00, 0.00, 0.00],
この5行と対応するcsvの行が次のとおり。
5.1,3.5,1.4,0.2,0 4.9,3.0,1.4,0.2,0 4.7,3.2,1.3,0.2,0 4.6,3.1,1.5,0.2,0 5.0,3.6,1.4,0.2,0
なるほど、1~4列目までが入力になっていると。
つまりCSVExample.javaでやりたいことは、1~4列目の4つの入力を使って、どのラベルが出力されるかを見ていたということらしい。
となると、今回の指定は入力と出力が逆となっていたようだ。
csvファイルが、1列目に営業日、2列目に株価なのでラベルは2列目ということになる。
ここまでくると、一度コードを全部捨ててCSVExample.javaを元にやり直したほうが良い気がしてきた。 一からの再出発は、また次回。
次回の予定
・サンプルプログラムをいじる
・courseraの動画を見る
・Deep Learning Javaプログラミングを読む
・プログラミングのための線形代数を読む
・Mavenのpomの読み方