package us.hall.weka;

import java.awt.Color;
import java.awt.Dimension;
import java.awt.Graphics;
import java.awt.Graphics2D;
import java.awt.Point;
import java.awt.Rectangle;
import java.awt.RenderingHints;
import java.awt.Window;
import java.awt.event.WindowAdapter;
import java.awt.event.WindowEvent;
import java.awt.geom.AffineTransform;
import java.awt.image.BufferedImage;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.Random;
import javax.swing.JFrame;
import javax.swing.JPanel;
import weka.classifiers.Classifier;
import weka.core.Instances;
import weka.core.SerializationHelper;
import weka.classifiers.evaluation.Evaluation;

/**
 * Plots a learning curve for a serialized Weka classifier.
 * <p>
 * The data set is shuffled with a fixed seed and the classifier is rebuilt on
 * progressively larger prefixes of it. For every prefix the error on the
 * prefix itself (cross-validated when the step is large enough) and the error
 * on the remaining instances are handed to a {@link Plotter} shown in a frame.
 * <p>
 * Command line: {@code -l <model file> -t <data file>}.
 */
public class LearningCurve {

	private final static int FRAME_WIDTH = 1024;
	private final static int FRAME_HEIGHT = 768;
	private final static int PANEL_HEIGHT = 300;
	private final static int INSET_LEFT = 5;
	private final static int INSET_RIGHT = 5;
	private final static int PANEL_WIDTH = FRAME_WIDTH - INSET_LEFT - INSET_RIGHT;
	/** Number of roughly equal steps the data set is divided into. */
	private final static int STEPS = 25;
	private final JFrame PLOT = new JFrame("Learning Curve");
	
	/**
	 * Loads the Weka packages, parses the {@code -l} (model) and {@code -t}
	 * (data file) options and opens the learning curve window.
	 *
	 * @param args command line arguments
	 */
	public static void main(String[] args) {
		String model = null;
		String dataFile = null;
		
		try {
			weka.core.WekaPackageManager.loadPackages(false, true, false);
		}
		catch (Exception ex) { 
			ex.printStackTrace();
			return;
		}
		
		// Every option consumes the following argument as its value, so the last
		// argument can never be an option itself; stopping one short avoids
		// reading past the end of the array on a dangling "-l" or "-t".
		for (int i=0; i+1<args.length; i++) {
			if (args[i].equals("-l")) {
				model = args[++i];
			}
			else if (args[i].equals("-t")) {
				dataFile = args[++i];
			}
		}
		if (model == null || dataFile == null) {
			System.err.println("Usage: LearningCurve -l <model file> -t <data file>");
			return;
		}
		try {
			new LearningCurve(model,dataFile);
		}
		catch (Exception ex) { ex.printStackTrace(); }
	}
	
	/**
	 * Builds the window, loads the classifier and the data, and plots the curve.
	 * I/O and class-loading problems while reading the inputs are printed and
	 * swallowed; any other failure is propagated.
	 *
	 * @param model path of the serialized classifier
	 * @param dataFile path of the data set; the last attribute is used as class
	 * @throws Exception if loading the data, training or evaluation fails
	 */
	public LearningCurve(String model, String dataFile) throws Exception {	
		PLOT.setLocation(0,0);
		PLOT.setSize(new Dimension(FRAME_WIDTH,FRAME_HEIGHT));
		PLOT.setResizable(true);
		PLOT.addWindowListener(new WindowAdapter() {
			public void windowClosing(WindowEvent evt) {
				System.exit(0);
			}
		});
		try {
			Classifier classifier = loadClassifier(model);
			Instances data = new weka.core.converters.ConverterUtils.DataSource(dataFile).getDataSet();
				
			data.setClassIndex(data.numAttributes() - 1);
			data.randomize(new Random(1));
			int total = data.numInstances();
			// Never let the step drop to zero for data sets smaller than STEPS;
			// that would divide by zero below and never advance the loop.
			int stepSize = Math.max(1, total / STEPS);
			Plotter plotter = new Plotter(PANEL_WIDTH, PANEL_HEIGHT, total / stepSize);
			plotter.init();
			PLOT.getContentPane().add(plotter);
			PLOT.pack();
			// Center only after pack(): pack() changes the frame size.
			positionWindow(PLOT);
			PLOT.setVisible(true);	
			Random r = new Random(1);
			// NOTE(review): training starts at index 1, so instance 0 is never
			// used for training or testing - confirm whether that is intended.
			// The bound i+1 < total keeps the test subset non-empty; an empty
			// one would yield an undefined (0/0) error rate.
			for (int i = stepSize; i + 1 < total; i += stepSize ) {
				Instances subset = new weka.core.Instances(data, 1, i);
				classifier.buildClassifier(subset);
				Evaluation evaluationObject = new Evaluation(subset);
				if (stepSize > 20) {
					evaluationObject.crossValidateModel(classifier, subset, 10, r);
				}
				else {
					evaluationObject.evaluateModel(classifier, subset);
				}
				double trainError = evaluationObject.errorRate();
				Instances testSubset = new weka.core.Instances(data, i+1, total - (i+1));
				Evaluation evaluationObjectTest = new Evaluation(testSubset);
				// Evaluate exactly once; a second pass only repeats the work.
				evaluationObjectTest.evaluateModel(classifier, testSubset);
				double testError = evaluationObjectTest.errorRate();
				plotter.update(trainError,testError,i/stepSize);
			}
		}
		catch (IOException ioex) { ioex.printStackTrace(); }
		catch (ClassNotFoundException cnfe) { cnfe.printStackTrace(); }
	}
	
	/**
	 * Reads a serialized classifier from a file and closes the file afterwards.
	 * <p>
	 * NOTE(review): this is Java-native deserialization; only load model files
	 * from trusted sources.
	 *
	 * @param model path of the serialized classifier
	 * @return the deserialized classifier
	 * @throws Exception if the file cannot be read or does not hold a classifier
	 */
	private static Classifier loadClassifier(String model) throws Exception {
		FileInputStream serialFileIn = new FileInputStream(model);
		try {
			ObjectInputStream serializedIn = SerializationHelper.getObjectInputStream(serialFileIn);
			return (Classifier)serializedIn.readObject();
		}
		finally {
			serialFileIn.close();
		}
	}
	
	/**
	 * Centers a window on the screen, based on the window's current size.
	 *
	 * @param w the window to position
	 */
	public static void positionWindow(Window w)
	{
		Dimension sSize = w.getToolkit().getScreenSize();	// Position the window
		Dimension aSize = w.getSize();
		w.setLocation((sSize.width-aSize.width)/2,(sSize.height-aSize.height)/2);
	}
}

/**
 * Panel that plots a training and a test error curve, one line segment per
 * call to {@link #update}. Everything is drawn into an off-screen image of
 * the size given at construction; painting copies that image to the screen.
 * Errors are expected in the range 0-1, with 0 at the bottom of the image.
 */
class Plotter extends JPanel {

	static final Color paleGreen = new Color(0xccffcc);
	static final Color TRAIN_COLOR = Color.RED;
	static final Color TEST_COLOR = Color.BLUE;
	final private BufferedImage offscreen;
	final private Graphics2D og2;
	final private int width;
	final private int height;
	// Last plotted point of each curve ([0] train, [1] test); x == -1 marks
	// "nothing plotted yet".
	final private Point[] plotPts = new Point[2];
	final private long rounds;
	final private int y0;
	private final static String[] errors = 
		new String[] {"0.0","0.1","0.2","0.3","0.4","0.5","0.6","0.7","0.8","0.9","1.0"};	
		
	/**
	 * Creates the panel and its off-screen image with an empty background.
	 *
	 * @param width image and preferred panel width in pixels
	 * @param height image and preferred panel height in pixels
	 * @param rounds number of rounds spread across the width; values below 1
	 *               are treated as 1 so the x scaling never divides by zero
	 */
	public Plotter(int width, int height, int rounds) {
		this.width = width;
		this.height = height;
		this.rounds = Math.max(1, rounds);
		y0 = height / 2;
		setPreferredSize(new Dimension(width,height));
		setBackground(Color.white);
		offscreen = new BufferedImage(width,height,BufferedImage.TYPE_INT_ARGB);
		og2 = offscreen.createGraphics();
		og2.setRenderingHint(RenderingHints.KEY_ANTIALIASING,
			RenderingHints.VALUE_ANTIALIAS_ON);	
		og2.setColor(paleGreen);
		og2.fillRect(0,0,width,height);
		og2.setColor(Color.black);
		forgetLastPoints();
	}
	
	/**
	 * Draws the static decoration: tick labels, horizontal grid lines, the
	 * colour legend and the rotated "ERROR" axis label.
	 */
	public void init() {
		int ticks = errors.length - 1;
		// Lowest baseline at which a label is still fully inside the image.
		int minBaseline = og2.getFontMetrics().getAscent();
		og2.setColor(Color.black);
		// Tick values come from the index; summing 0.1 repeatedly drifts.
		for (int i = 0; i < errors.length; i++) {
			int baseline = toY((double)i / ticks) - 5;
			// The topmost label would otherwise sit above the image.
			og2.drawString(errors[i],12,Math.max(baseline,minBaseline));
		}
		og2.setColor(Color.green);
		for (int i = 0; i < errors.length; i++) {
			int y = toY((double)i / ticks);
			og2.drawLine(15,y,width,y);
		}
		og2.setColor(TRAIN_COLOR);
		og2.fill(new Rectangle(width-50,10,10,10));
		og2.setColor(Color.BLACK);
		og2.drawString("Train",width-35,20);
		og2.setColor(TEST_COLOR);
		og2.fill(new Rectangle(width-50,25,10,10));
		og2.setColor(Color.BLACK);
		og2.drawString("Test",width-35,34);
		
		// Vertical axis label. Move the origin to where the text should start
		// and rotate afterwards; rotating about the image origin and drawing at
		// a positive offset would place the text above the visible area.
		AffineTransform orig = og2.getTransform();
		int labelWidth = og2.getFontMetrics().stringWidth("ERROR");
		og2.translate(10,(height + labelWidth) / 2);
		og2.rotate(-Math.PI/2);
		og2.setColor(Color.BLACK);
		og2.drawString("ERROR",0,0);
		og2.setTransform(orig);
	}
	
	/**
	 * Extends both curves to the given round. The first call after
	 * construction or {@link #reset} only records the starting points.
	 *
	 * @param trainError training error, expected in 0-1
	 * @param testError test error, expected in 0-1
	 * @param round round number, mapped onto the x axis
	 */
	public void update(double trainError, double testError, int round) {
		Point train = new Point(toX(round),toY(trainError));
		Point test = new Point(toX(round),toY(testError));
		if (plotPts[0].x == -1) {
			plotPts[0] = train;
			plotPts[1] = test;
			return;
		}
		boolean updated = extend(0,train,TRAIN_COLOR);
		updated |= extend(1,test,TEST_COLOR);
		if (updated)
			repaint();
	}
	
	/** Draws a segment from a curve's last point to next; false if unchanged. */
	private boolean extend(int series, Point next, Color color) {
		Point last = plotPts[series];
		if (next.x == last.x && next.y == last.y)
			return false;
		og2.setColor(color);
		og2.drawLine(last.x, last.y, next.x, next.y);
		plotPts[series] = next;
		return true;
	}
	
	/** Marks both curves as having no plotted point yet. */
	private void forgetLastPoints() {
		for (int i=0;i<plotPts.length;i++) 
			plotPts[i] = new Point(-1,y0);
	}
	
	private int toX(long round) {
		return (int)(round*width/rounds);
	}
	
	/**
	 * Maps an error in 0-1 to a pixel row: 0 is the bottom row (height-1) and
	 * 1 the top row (0), so both extremes stay inside the image.
	 */
	private int toY(double error) {
		return (int)Math.round((1.0d-error)*(height-1));
	}
	
	/**
	 * Clears the image to the background colour and forgets the last plotted
	 * points, so the next {@link #update} starts new curves. The decoration
	 * drawn by {@link #init} is erased as well and is not redrawn here.
	 */
	public void reset() {
		og2.setColor(paleGreen);
		og2.fillRect(0,0,width,height);
		og2.setColor(Color.lightGray);
		// Restore the "no point yet" marker; any other x would make the next
		// update draw a line from a point that was never plotted.
		forgetLastPoints();
		repaint();
	}
	
	/**
	 * Paints the panel background and then the off-screen image on top of it.
	 */
	@Override
	protected void paintComponent(Graphics g) {
		// Let the panel clear itself first so areas outside the image are
		// painted when the panel is larger than the image.
		super.paintComponent(g);
		g.drawImage(offscreen,0,0,null);
	}
}