package us.hall.weka;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import weka.core.*;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.NeuralNetwork;

/**
 * Class to run an evaluation or do predictions for NeuralNetwork
 * https://github.com/amten/NeuralNetwork/releases/latest
 * written by Johannes Amten
 **/
public class RunNeuralNetwork {
	private Instances trainingData = null, testData = null;
	private Classifier classifier = null;
	private File test = null, training = null;
	private boolean evaluate = false;
	
	/**
	 * Command line entry point.
	 * <p>
	 * Recognized arguments:
	 * <ul>
	 * <li><code>-t file</code> training ARFF dataset (always required)</li>
	 * <li><code>-T file</code> test ARFF dataset (required unless <code>-e</code> is given)</li>
	 * <li><code>-e</code> cross-validate on the training data instead of predicting</li>
	 * </ul>
	 * Any other argument is ignored. Problems with the arguments are reported
	 * on standard output and the method returns without processing.
	 *
	 * @param args the command line arguments
	 */
	public static void main(String[] args) {
		File training = null;
		File test = null;
		boolean evaluate = false;
		
		for (int i = 0; i < args.length; i++) {
			String arg = args[i];
			if (arg.equals("-t") || arg.equals("-T")) {
				// Both options consume the following argument; make sure there is one
				// rather than indexing past the end of the array.
				if (i + 1 >= args.length) {
					System.out.println("Option " + arg + " requires an arff file argument");
					return;
				}
				i++;
				if (arg.equals("-t")) {			// Q. training ARFF?
					training = new File(args[i]);
				}
				else {					// Q. test ARFF?
					test = new File(args[i]);
				}
			}
			else if (arg.equals("-e")) {		// Evaluate instead of predict?
				evaluate = true;
			}
		}
		if (training == null || !training.exists()) {
			System.out.println("Training arff dataset " + training + " is invalid or missing");
			return;
		}
		// The test dataset is only read when predicting, so it is not demanded
		// for an evaluation run.
		if (!evaluate && (test == null || !test.exists())) {
			System.out.println("Test arff dataset " + test + " is invalid or missing");
			return;
		}
		RunNeuralNetwork rnn = new RunNeuralNetwork(training,test,evaluate);
		try {
			rnn.process();
		}
		catch (Exception ex) { ex.printStackTrace(); }
	}
	
	/**
	 * Creates a runner around a new NeuralNetwork classifier.
	 * <p>
	 * If the classifier cannot be instantiated the stack trace is printed and
	 * the failure is reported again, as an IllegalStateException, when
	 * processing starts.
	 *
	 * @param training the training ARFF dataset
	 * @param test the test ARFF dataset; only read when predicting
	 * @param evaluate true to cross-validate on the training data, false to
	 *                 build the classifier and predict the test data
	 */
	public RunNeuralNetwork(File training,File test,boolean evaluate) {
		this.training = training;
		this.test = test;
		this.evaluate = evaluate;
		try {
			NeuralNetwork nn = new NeuralNetwork();			// Get the NeuralNetwork classifier
			classifier = nn;
		}
		catch (Exception ex) { ex.printStackTrace(); }
	}
	
	/**
	 * Loads the training data and then either evaluates the classifier or
	 * builds it and predicts the test data, depending on the evaluate flag.
	 *
	 * @throws IllegalStateException if the classifier could not be created
	 * @throws Exception if a dataset cannot be read or the classifier fails
	 */
	private void process() throws Exception {
		if (classifier == null) {
			// Fail with a clear message instead of a NullPointerException further down.
			throw new IllegalStateException("NeuralNetwork classifier could not be created");
		}
		trainingData = loadArff(training);
		if (evaluate) {
			evaluate();
		}
		else {
			build();
			predict();
		}
	}
	
	/**
	 * Reads an ARFF file and marks its last attribute as the class attribute.
	 * The reader is always closed, whether or not reading succeeds.
	 *
	 * @param arff the ARFF file to read
	 * @return the instances read from the file
	 * @throws Exception if the file cannot be opened or read
	 */
	private static Instances loadArff(File arff) throws Exception {
		BufferedReader rdr = new BufferedReader(new FileReader(arff));
		try {
			Instances data = new Instances(rdr);
			data.setClassIndex(data.numAttributes() - 1);
			return data;
		}
		finally {
			rdr.close();
		}
	}
	
	/**
	 * Builds the classifier from the training data and then drops the
	 * reference to the training data.
	 *
	 * @throws Exception if the classifier cannot be built
	 */
	private void build() throws Exception {
		classifier.buildClassifier(trainingData);
		trainingData = null;
	}
	
	/**
	 * Runs a 10-fold cross-validation of the classifier on the training data,
	 * using a random generator seeded with 1, and prints the summary, the
	 * confusion matrix and the per-class details to standard output.
	 *
	 * @throws Exception if the cross-validation fails
	 */
	private void evaluate() throws Exception {
		Evaluation eval = new Evaluation(trainingData);
		eval.crossValidateModel(classifier, trainingData, 10, new java.util.Random(1));
		System.out.println(eval.toSummaryString("\nResults\n======\n", false));
		System.out.println("");
		System.out.println(eval.toMatrixString());
		System.out.println("");
		System.out.println(eval.toClassDetailsString());
	}
	
	/**
	 * Loads the test data and prints one prediction per test instance to
	 * standard output.
	 * <p>
	 * Each printed value is the result of classifyInstance truncated to an int
	 * plus one. NOTE(review): presumably this turns a 0-based class index into
	 * a 1-based label; confirm against the datasets in use.
	 *
	 * @throws Exception if the test data cannot be read or classification fails
	 */
	private void predict() throws Exception {
		testData = loadArff(test);
		for (int i = 0; i < testData.numInstances(); i++) {
			int pred = (int)classifier.classifyInstance(testData.instance(i)) + 1;
			System.out.println(pred);
		}
	}
}