サンプルプログラムをいじろう その1
前回のあらすじ
多層パーセプトロンプログラムを読んだ。
やったこと
今回からサンプルプログラムを少しいじって、何かデータを解析してみる。
どうせやるなら心躍る題材をということで、株価をデータとして扱います。
2015年の株価のデータで学習して2016年のデータで答え合わせをします。
日付を入力して株価(その日の終値)を出力するようなものを想定。
株価のデータは、ここから拝借しました。
何と便利な世の中か。
株価データ・株主優待情報・先物データ・ランキングデータ・CSVダウンロード無料 | 株式投資メモ・株価データベース
早速、csvを読み込んで…と行きたいが、こういうのは、まず小さく動作が確認できる機能を作ってそこから大きくしたほうがよいだろう。
数日分のデータを読んで、数日分の株価を出力するプログラムを書こう。
というわけでデータセットだけを挿げ替えたプログラム。
public static void main(String[] args) throws Exception { // 変数定義 int seed = 123; // 乱数シード int iterations = 2000; // 学習の試行回数 int inputNum = 1; // 入力数 int middleNum = 10; // 隠れ層のニューロン数 int outputNum = 1; // 出力数 INDArray tIn = Nd4j.create( new float[]{ 20150105 , // 入力1 20150106 , // 入力2 20150107 , // 入力3 20150108 }, // 入力4 new int[]{ 4 , 1 } ); // サイズ INDArray tOut = Nd4j.create( new float[]{ 2512 , 2467 , 2469 , 2502} , // 出力1~4 new int[]{ 4 , 1 } ); // サイズ DataSet train = new DataSet( tIn , tOut ); // 入出力を対応付けたデータセット System.out.println( train ); // ニューラルネットワークを定義 MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder() .seed(seed) .iterations(iterations) .learningRate(0.01) .weightInit(WeightInit.SIZE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater( Updater.NONE ) .list() .layer(0, new DenseLayer.Builder() .nIn(inputNum) .nOut(middleNum) .activation("sigmoid").build()) .layer(1, new OutputLayer.Builder( LossFunctions.LossFunction.MSE ) .nIn(middleNum) .nOut(outputNum) .activation("sigmoid") .build()) .backprop(true).pretrain(false); // ニューラルネットワークを作成 MultiLayerConfiguration conf = builder.build(); MultiLayerNetwork perceptron = new MultiLayerNetwork(conf); perceptron.init(); // 確認用のリスナーを追加 perceptron.setListeners( new ScoreIterationListener(1) ); // 学習(fit) perceptron.fit( train ); // パーセプトロンの使用 for( int i=0 ; i<train.numExamples() ; i++ ) { // i個目のサンプルについて、 INDArray input = train.get(i).getFeatureMatrix(); INDArray answer = train.get(i).getLabels(); INDArray output = perceptron.output( input , false ); System.out.println( "result" + i ); System.out.println( " input : " + input ); System.out.println( " output : " + output ); System.out.println( " answer : " + answer ); System.out.flush(); } }
動作結果
result0 input : 20,150,104.00 output : 1.00 answer : 2,512.00 result1 input : 20,150,106.00 output : 1.00 answer : 2,467.00 result2 input : 20,150,108.00 output : 1.00 answer : 2,469.00 result3 input : 20,150,108.00 output : 1.00 answer : 2,502.00
出力すべて1。
なんかそんな気はしてた。
楽して億万長者には、なれないのだ。
検討はまた次回。
次回の予定
・サンプルプログラムをいじる
・courseraの動画を見る
・Deep Learning Javaプログラミングを読む
・プログラミングのための線形代数を読む
・Mavenのpomの読み方