/*
 * Decompiled with CFR 0.152.
 */
package jmaxent;

import java.io.PrintWriter;
import jmaxent.Feature;
import jmaxent.Model;
import jmaxent.Observation;
import org.riso.numerical.LBFGS;

/**
 * Trains the weight vector {@code model.lambda} by repeatedly evaluating a
 * penalised log-likelihood and its gradient over {@code model.data.trnData}
 * and handing them to the L-BFGS routine.
 *
 * <p>Typical use: assign {@link #model}, then call {@link #doTrain(PrintWriter)}.
 * The working buffers are package-visible fields that are (re)allocated by
 * {@link #init()} on every training run.
 */
public class Train {
    /** Model being trained; must be assigned before {@link #init()} or {@link #doTrain(PrintWriter)}. */
    public Model model = null;
    /** Number of labels, read from {@code model.data} by {@link #init()}. */
    public int numLabels = 0;
    /** Number of features, read from {@code model.feaGen} by {@link #init()}. */
    public int numFeatures = 0;
    /** Alias of {@code model.lambda}; updated in place during training. */
    double[] lambda = null;
    /** Snapshot of the weights that gave the best evaluation accuracy so far. */
    double[] tempLambda = null;
    /** Gradient buffer (negated in place before it is passed to L-BFGS). */
    double[] gradLogLi = null;
    /** Diagonal buffer handed to L-BFGS. */
    double[] diag = null;
    /** Per-label scratch buffer used while processing a single observation. */
    double[] temp = null;
    /** Workspace sized from {@code mForHessian}; allocated here but not read anywhere in this class. */
    double[] ws = null;
    /** Two-element print-control array handed to L-BFGS, filled from {@code debugLevel}. */
    int[] iprint = null;
    /** One-element status flag exchanged with L-BFGS. */
    int[] iflag = null;

    /**
     * Reads the label/feature counts from the model and allocates all working
     * buffers. If either count is not positive, a message is printed to stdout
     * and the method returns without allocating anything.
     */
    public void init() {
        this.numLabels = this.model.data.numLabels();
        this.numFeatures = this.model.feaGen.numFeatures();
        if (this.numLabels <= 0 || this.numFeatures <= 0) {
            System.out.println("Invalid number of labels or features");
            return;
        }
        this.lambda = this.model.lambda;
        this.tempLambda = new double[this.numFeatures];
        this.gradLogLi = new double[this.numFeatures];
        this.diag = new double[this.numFeatures];
        this.temp = new double[this.numLabels];
        int m = this.model.option.mForHessian;
        this.ws = new double[this.numFeatures * (2 * m + 1) + 2 * m];
        this.iprint = new int[2];
        this.iflag = new int[1];
    }

    /**
     * Computes the Euclidean (L2) norm of a vector.
     *
     * @param vect the vector; must not be {@code null}
     * @return the square root of the sum of squared components (0.0 for an empty vector)
     */
    public static double norm(double[] vect) {
        double sumOfSquares = 0.0;
        for (double component : vect) {
            sumOfSquares += component * component;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Runs the training loop.
     *
     * <p>Every weight is reset to {@code option.initLambdaVal}. Each iteration
     * evaluates the objective, negates value and gradient, and calls L-BFGS,
     * which updates {@code lambda} in place. The loop repeats while
     * {@code iflag[0]} is non-zero and fewer than {@code option.numIterations}
     * iterations have run, and stops early if L-BFGS throws.
     *
     * <p>When {@code option.evaluateDuringTraining} is set, the model is
     * evaluated on {@code model.data.tstData} after every iteration; with
     * {@code option.saveBestModel} the weights of the most accurate iteration
     * are snapshotted and restored into {@code lambda} at the end.
     *
     * @param fout log writer; used only when {@code option.isLogging} is set
     *             (it is also passed to {@code model.evaluation.evaluate})
     */
    public void doTrain(PrintWriter fout) {
        this.init();
        if (this.numLabels <= 0 || this.numFeatures <= 0) {
            // init() has already reported the problem and skipped allocation;
            // continuing would dereference null (or stale) working buffers.
            return;
        }
        final double xtol = 1.0E-16;
        int numIter = 0;
        this.iprint[0] = this.model.option.debugLevel - 2;
        this.iprint[1] = this.model.option.debugLevel - 1;
        this.iflag[0] = 0;
        for (int i = 0; i < this.numFeatures; ++i) {
            this.lambda[i] = this.model.option.initLambdaVal;
        }
        System.out.println("Start to train ...");
        if (this.model.option.isLogging) {
            this.model.option.writeOptions(fout);
            fout.println("Start to train ...");
        }
        long startTrain = System.currentTimeMillis();
        double maxAccuracy = 0.0;
        int maxAccuracyIter = -1;
        // True once tempLambda holds a real snapshot of the weights.
        boolean bestSaved = false;
        do {
            long startIter = System.currentTimeMillis();
            // The objective and gradient are sign-flipped before being handed
            // to L-BFGS (presumably because the routine minimises).
            double f = -this.computeLogLiGradient(this.lambda, this.gradLogLi, numIter + 1, fout);
            for (int i = 0; i < this.numFeatures; ++i) {
                this.gradLogLi[i] = -this.gradLogLi[i];
            }
            try {
                new LBFGS().lbfgs(this.numFeatures, this.model.option.mForHessian, this.lambda, f, this.gradLogLi, false, this.diag, this.iprint, this.model.option.epsForConvergence, xtol, this.iflag);
            }
            catch (LBFGS.ExceptionWithIflag e) {
                this.report("L-BFGS failed!", fout);
                break;
            }
            ++numIter;
            long elapsedIter = System.currentTimeMillis() - startIter;
            this.report("\tIteration elapsed: " + Double.toString((double)elapsedIter / 1000.0) + " seconds", fout);
            if (this.model.option.evaluateDuringTraining) {
                this.model.doInference(this.model.data.tstData);
                double accuracy = this.model.evaluation.evaluate(fout);
                if (accuracy > maxAccuracy) {
                    maxAccuracy = accuracy;
                    maxAccuracyIter = numIter;
                    if (this.model.option.saveBestModel) {
                        System.arraycopy(this.lambda, 0, this.tempLambda, 0, this.numFeatures);
                        bestSaved = true;
                    }
                }
                this.report("\tCurrent max accuracy: " + Double.toString(maxAccuracy) + " (at iteration " + Integer.toString(maxAccuracyIter) + ")", fout);
                elapsedIter = System.currentTimeMillis() - startIter;
                this.report("\tIteration elapsed (including testing & evaluation): " + Double.toString((double)elapsedIter / 1000.0) + " seconds", fout);
                if (this.model.option.isLogging) {
                    fout.flush();
                }
            }
        } while (this.iflag[0] != 0 && numIter < this.model.option.numIterations);
        long elapsedTrain = System.currentTimeMillis() - startTrain;
        this.report("\tThe training process elapsed: " + Double.toString((double)elapsedTrain / 1000.0) + " seconds", fout);
        // Restore the best weights only if a snapshot was actually taken.
        // Otherwise tempLambda is still all zeros (accuracy never exceeded 0.0,
        // or L-BFGS failed before the first evaluation) and copying it back
        // would wipe out the trained weights.
        if (bestSaved) {
            System.arraycopy(this.tempLambda, 0, this.lambda, 0, this.numFeatures);
        }
    }

    /**
     * Prints a message to stdout and, when {@code option.isLogging} is set,
     * also to the log writer.
     */
    private void report(String message, PrintWriter fout) {
        System.out.println(message);
        if (this.model.option.isLogging) {
            fout.println(message);
        }
    }

    /**
     * Computes the penalised log-likelihood of the training data and writes
     * its gradient into {@code gradLogLi}.
     *
     * <p>The penalty contributes {@code -lambda[i]^2 / (2 * sigmaSquare)} to
     * the value and {@code -lambda[i] / sigmaSquare} to the gradient. For each
     * observation the per-label scores are accumulated in {@link #temp}, the
     * normaliser is evaluated with the log-sum-exp trick (subtracting the
     * largest score) so that large scores do not overflow {@code Math.exp},
     * and {@link #temp} is then overwritten with the label probabilities.
     *
     * <p>The value, gradient norm and weight norm are printed to stdout and,
     * when {@code option.isLogging} is set, to {@code fout}.
     *
     * @param lambda    current weights (read only)
     * @param gradLogLi output buffer receiving the gradient; at least {@code numFeatures} long
     * @param numIter   iteration number, used only in the printed report
     * @param fout      log writer; used only when {@code option.isLogging} is set
     * @return the penalised log-likelihood
     */
    public double computeLogLiGradient(double[] lambda, double[] gradLogLi, int numIter, PrintWriter fout) {
        double logLi = 0.0;
        for (int i = 0; i < this.numFeatures; ++i) {
            gradLogLi[i] = -1.0 * lambda[i] / this.model.option.sigmaSquare;
            logLi -= lambda[i] * lambda[i] / (2.0 * this.model.option.sigmaSquare);
        }
        for (int ii = 0; ii < this.model.data.trnData.size(); ++ii) {
            Observation obsr = (Observation)this.model.data.trnData.get(ii);
            for (int i = 0; i < this.numLabels; ++i) {
                this.temp[i] = 0.0;
            }
            // First scan: empirical feature counts and per-label scores.
            double obsrLogLi = 0.0;
            this.model.feaGen.startScanFeatures(obsr);
            while (this.model.feaGen.hasNextFeature()) {
                Feature f = this.model.feaGen.nextFeature();
                double weighted = lambda[f.idx] * (double)f.val;
                if (f.label == obsr.humanLabel) {
                    gradLogLi[f.idx] += (double)f.val;
                    obsrLogLi += weighted;
                }
                this.temp[f.label] += weighted;
            }
            // log Zx = max + log(sum(exp(score - max))): every exponent is <= 0,
            // so exp() cannot overflow.
            double maxScore = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < this.numLabels; ++i) {
                if (this.temp[i] > maxScore) {
                    maxScore = this.temp[i];
                }
            }
            double sumExp = 0.0;
            for (int i = 0; i < this.numLabels; ++i) {
                sumExp += Math.exp(this.temp[i] - maxScore);
            }
            double logZx = maxScore + Math.log(sumExp);
            // Replace scores by probabilities once per label rather than
            // recomputing exp() for every feature in the second scan.
            for (int i = 0; i < this.numLabels; ++i) {
                this.temp[i] = Math.exp(this.temp[i] - logZx);
            }
            // Second scan: subtract the expected feature counts.
            this.model.feaGen.scanReset();
            while (this.model.feaGen.hasNextFeature()) {
                Feature f = this.model.feaGen.nextFeature();
                gradLogLi[f.idx] -= (double)f.val * this.temp[f.label];
            }
            logLi += obsrLogLi - logZx;
        }
        double gradLogLiNorm = Train.norm(gradLogLi);
        double lambdaNorm = Train.norm(lambda);
        System.out.println();
        System.out.println("Iteration: " + Integer.toString(numIter));
        System.out.println("\tLog-likelihood                 = " + Double.toString(logLi));
        System.out.println("\tNorm (log-likelihood gradient) = " + Double.toString(gradLogLiNorm));
        System.out.println("\tNorm (lambda)                  = " + Double.toString(lambdaNorm));
        if (this.model.option.isLogging) {
            fout.println();
            fout.println("Iteration: " + Integer.toString(numIter));
            fout.println("\tLog-likelihood                 = " + Double.toString(logLi));
            fout.println("\tNorm (log-likelihood gradient) = " + Double.toString(gradLogLiNorm));
            fout.println("\tNorm (lambda)                  = " + Double.toString(lambdaNorm));
        }
        return logLi;
    }
}

