package org.openstreetmap.josm.plugins.osmrec.personalization;

import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Model;
import de.bwaldvogel.liblinear.Parameter;
import de.bwaldvogel.liblinear.Problem;
import de.bwaldvogel.liblinear.SolverType;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.lucene.analysis.shingle.ShingleFilter;
import org.openstreetmap.josm.plugins.osmrec.container.OSMWay;
import org.openstreetmap.josm.plugins.osmrec.core.AbstractTrainWorker;
import org.openstreetmap.josm.plugins.osmrec.extractor.LanguageDetector;
import org.openstreetmap.josm.plugins.osmrec.features.ClassFeatures;
import org.openstreetmap.josm.plugins.osmrec.features.GeometryFeatures;
import org.openstreetmap.josm.plugins.osmrec.features.OSMClassification;
import org.openstreetmap.josm.plugins.osmrec.features.TextualFeatures;
import org.openstreetmap.josm.plugins.osmrec.parsers.Mapper;
import org.openstreetmap.josm.plugins.osmrec.parsers.Ontology;

/* loaded from: input_file:org/openstreetmap/josm/plugins/osmrec/personalization/TrainByUser.class */
/**
 * Background worker that trains OSMRec liblinear models on the ways edited by a single user.
 *
 * <p>Two models are written to {@code modelDirectory}: one built from geometry + textual features
 * and one that additionally contains class features. When {@code validateFlag} is set, the
 * regularisation parameter C is chosen by 5-fold cross validation over {@code wayList}; otherwise
 * the user-supplied {@code cParameterFromUser} is used.
 */
public class TrainByUser extends AbstractTrainWorker {

    /** Feature-id offset reserved for class features in the "with classes" model. */
    private static final int CLASS_FEATURE_COUNT = 1422;

    /** liblinear solver id passed to {@code SolverType.getById}. */
    private static final int SOLVER_ID = 2;

    /** Stopping tolerance handed to the liblinear solver. */
    private static final double SOLVER_EPSILON = 0.001d;

    /** Number of equally sized slices {@code wayList} is cut into for cross validation. */
    private static final int FOLD_COUNT = 5;

    /**
     * Cross-validation folds as {trainStartSlice, trainEndSlice, testStartSlice, testEndSlice,
     * skipTestSlice (0/1)}; slice boundaries are multiples of {@code wayList.size() / FOLD_COUNT}.
     */
    private static final int[][] CV_FOLDS = {
        {0, 4, 4, 5, 0},
        {1, 5, 0, 1, 0},
        {0, 5, 1, 2, 1},
        {0, 5, 2, 3, 1},
        {0, 5, 3, 4, 1}
    };

    /** Largest k for which top-k precision is measured during validation. */
    private static final int MAX_TOP_K = 10;

    private final String username;

    /**
     * Creates the worker.
     *
     * @param str2 name of the user whose ways are trained on; embedded in model file names
     * @param list the user's ways; stored in the (shared) {@code AbstractTrainWorker.wayList}
     * All other arguments are forwarded unchanged to {@link AbstractTrainWorker}.
     */
    public TrainByUser(String str, String str2, boolean z, double d, int i, int i2, boolean z2, LanguageDetector languageDetector, List<OSMWay> list) {
        super(str, z, d, i, i2, z2, languageDetector);
        this.username = str2;
        // NOTE(review): this writes a field shared by every AbstractTrainWorker, not per-instance state.
        AbstractTrainWorker.wayList = list;
    }

    /**
     * Runs the full training pipeline: loads auxiliary files, optionally cross-validates C, then
     * trains and saves both models while reporting progress.
     *
     * <p>NOTE(review): the decompiler renamed this from {@code doInBackground}; unless
     * {@code AbstractTrainWorker} routes to it, SwingWorker will not invoke this method — confirm.
     *
     * @return always {@code null}
     */
    /* renamed from: doInBackground, reason: merged with bridge method [inline-methods] */
    public Void m798doInBackground() throws Exception {
        extractTextualList();
        parseFiles();
        if (this.validateFlag) {
            firePropertyChange("progress", Integer.valueOf(getProgress()), 5);
            validateLoop();
            firePropertyChange("progress", Integer.valueOf(getProgress()), 40);
            System.out.println("Training model with the best c: " + this.bestConfParam);
            clearDataset();
            trainModel(this.bestConfParam);
            firePropertyChange("progress", Integer.valueOf(getProgress()), 60);
            clearDataset();
            trainModelWithClasses(this.bestConfParam);
            firePropertyChange("progress", Integer.valueOf(getProgress()), 100);
            setProgress(100);
        } else {
            clearDataset();
            firePropertyChange("progress", Integer.valueOf(getProgress()), 10);
            trainModel(this.cParameterFromUser);
            firePropertyChange("progress", Integer.valueOf(getProgress()), 60);
            clearDataset();
            firePropertyChange("progress", Integer.valueOf(getProgress()), 65);
            trainModelWithClasses(this.cParameterFromUser);
            firePropertyChange("progress", Integer.valueOf(getProgress()), 100);
            setProgress(100);
            System.out.println("done.");
        }
        System.out.println("Train by user process complete.");
        return null;
    }

    /**
     * Loads the tag mappings, the ontology and the textual list, and records the number of
     * training instances. Every stream opened here is closed before returning.
     */
    private void parseFiles() {
        Mapper mapper = new Mapper();
        try (InputStream mapStream = getClass().getResourceAsStream("/resources/files/Map")) {
            mapper.parseFile(mapStream);
        } catch (IOException e) {
            Logger.getLogger(Mapper.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
        }
        mappings = mapper.getMappings();
        mapperWithIDs = mapper.getMappingsWithIDs();

        try (InputStream owlStream = getClass().getResourceAsStream("/resources/files/owl.xml")) {
            Ontology ontology = new Ontology(owlStream);
            ontology.parseOntology();
            System.out.println("ontology parsed ");
            indirectClasses = ontology.getIndirectClasses();
            indirectClassesWithIDs = ontology.getIndirectClassesIDs();
        } catch (IOException e) {
            Logger.getLogger(getClass().getName()).log(Level.SEVERE, (String) null, (Throwable) e);
        }

        // If the textual list cannot be opened it is logged and skipped, instead of handing a
        // null stream to readTextualFromDefaultList as before.
        try (FileInputStream textualStream = new FileInputStream(new File(this.textualListFilePath))) {
            readTextualFromDefaultList(textualStream);
        } catch (IOException e) {
            Logger.getLogger(getClass().getName()).log(Level.SEVERE, (String) null, (Throwable) e);
        }

        this.numberOfTrainingInstances = wayList.size();
        System.out.println("number of instances: " + this.numberOfTrainingInstances);
        System.out.println("end of parsing files.");
        if (this.numberOfTrainingInstances == 0) {
            System.out.println("This user has not edited the loaded area. Cannot train a model!");
        }
    }

    /**
     * Cross-validates each candidate C over {@link #CV_FOLDS} and stores the C with the highest
     * summed top-1 precision in {@code bestConfParam}.
     *
     * <p>NOTE(review): {@code bestScore} is an inherited field and is not reset here, so repeated
     * calls compare against the best score of earlier runs.
     */
    public void validateLoop() {
        // The original list contained 2^-10 twice; the duplicate only repeated identical work and
        // could never win the strict '<' comparison below.
        double[] candidateCs = {
            Math.pow(2.0d, -10.0d),
            Math.pow(2.0d, -5.0d),
            Math.pow(2.0d, -3.0d),
            Math.pow(2.0d, -1.0d),
            Math.pow(2.0d, 0.0d)
        };
        double folds = CV_FOLDS.length;
        double bestC = Math.pow(2.0d, -10.0d);
        for (double c : candidateCs) {
            foldScore1 = 0.0d;
            foldScore5 = 0.0d;
            foldScore10 = 0.0d;
            System.out.println("\n\n\nrunning for C = " + c);
            for (int f = 0; f < CV_FOLDS.length; f++) {
                int[] fold = CV_FOLDS[f];
                clearDataset();
                System.out.println("fold" + (f + 1));
                crossValidateFold(fold[0], fold[1], fold[2], fold[3], fold[4] != 0, c);
                foldScore1 += score1;
                foldScore5 += score5;
                foldScore10 += score10;
            }
            System.out.println("\n\nC=" + c + ", average score 1-5-10: " + (foldScore1 / folds) + ShingleFilter.TOKEN_SEPARATOR + (foldScore5 / folds) + ShingleFilter.TOKEN_SEPARATOR + (foldScore10 / folds));
            if (bestScore < foldScore1) {
                bestScore = foldScore1;
                bestC = c;
            }
        }
        this.bestConfParam = bestC;
        System.out.println("best c param= " + bestC + ", score: " + (bestScore / folds));
    }

    /**
     * Trains on one cross-validation split and stores top-1/5/10 precision in
     * {@code score1}/{@code score5}/{@code score10}.
     *
     * <p>{@code wayList} is cut into {@value #FOLD_COUNT} slices of {@code size / FOLD_COUNT}
     * ways. Training uses slices {@code [trainStartFold, trainEndFold)}, excluding the test slices
     * when {@code skipTestFold} is set. Testing uses slices {@code [testStartFold, testEndFold)}.
     * Precision is the number of hits divided by the number of test ways that have at least one
     * class (0 when there are none).
     *
     * @param c liblinear regularisation parameter
     */
    public void crossValidateFold(int trainStartFold, int trainEndFold, int testStartFold, int testEndFold, boolean skipTestFold, double c) {
        System.out.println("Starting cross validation");
        int foldSize = wayList.size() / FOLD_COUNT;
        int testStart = testStartFold * foldSize;
        int testEnd = testEndFold * foldSize;

        List<OSMWay> trainList = new ArrayList<>();
        for (int index = trainStartFold * foldSize; index < trainEndFold * foldSize; index++) {
            if (skipTestFold && index >= testStart && index < testEnd) {
                continue;
            }
            trainList.add(wayList.get(index));
        }
        System.out.println("trainList size: " + trainList.size());

        // Reset first so a skipped fold never contributes stale scores from a previous fold.
        score1 = 0.0d;
        score5 = 0.0d;
        score10 = 0.0d;

        Problem problem = buildProblem(trainList, false);
        if (problem.l == 0) {
            System.out.println("no classified instances in training fold, skipping fold..");
            return;
        }
        Model model = trainSilently(problem, c);

        File modelFile = new File(this.modelDirectory.getAbsolutePath() + "/user_" + this.username + "_model_geometries_textual_c=" + c);
        if (modelFile.exists()) {
            modelFile.delete();
        }
        try {
            model.save(modelFile);
            System.out.println("model saved at: " + modelFile);
        } catch (IOException e) {
            Logger.getLogger(getClass().getName()).log(Level.SEVERE, (String) null, (Throwable) e);
        }

        List<OSMWay> testList = new ArrayList<>(wayList.subList(testStart, testEnd));
        System.out.println("testList size: " + testList.size());
        // Evaluate the persisted model; fall back to the in-memory one if it cannot be reloaded.
        try {
            model = Model.load(modelFile);
        } catch (IOException e) {
            Logger.getLogger(getClass().getName()).log(Level.SEVERE, (String) null, (Throwable) e);
        }
        int[] labels = model.getLabels();

        int classifiedTestCount = 0;
        for (OSMWay way : testList) {
            new OSMClassification().calculateClasses(way, mappings, mapperWithIDs, indirectClasses, indirectClassesWithIDs);
            if (!way.getClassIDs().isEmpty()) {
                classifiedTestCount++;
            }
        }

        int hits1 = 0;
        int hits5 = 0;
        int hits10 = 0;
        for (OSMWay way : testList) {
            FeatureNode[] features = extractFeatures(way, false);
            double[] decisionValues = new double[labels.length];
            Linear.predictValues(model, features, decisionValues);
            int rank = firstCorrectRank(way, labels, decisionValues);
            if (rank < 1) {
                hits1++;
            }
            if (rank < 5) {
                hits5++;
            }
            if (rank < MAX_TOP_K) {
                hits10++;
            }
        }

        System.out.println("Succeeded " + hits1 + " of " + testList.size() + " total (1 class prediction)");
        double precision1 = precision(hits1, classifiedTestCount);
        score1 = precision1;
        System.out.println(precision1);
        System.out.println("Succeeded " + hits5 + " of " + testList.size() + " total (5 class prediction)");
        double precision5 = precision(hits5, classifiedTestCount);
        score5 = precision5;
        System.out.println(precision5);
        System.out.println("Succeeded " + hits10 + " of " + testList.size() + " total (10 class prediction)");
        double precision10 = precision(hits10, classifiedTestCount);
        score10 = precision10;
        System.out.println(precision10);
    }

    /**
     * Trains the geometry + textual model on all of {@code wayList}. It is saved both as
     * {@code best_model} and under a parameter-specific name ending in {@code .0}.
     *
     * @param d liblinear regularisation parameter
     */
    private void trainModel(double d) {
        int size = wayList.size();
        System.out.println("trainList size: " + size);
        if (size == 0) {
            System.out.println("aborting training process..");
            return;
        }
        Problem problem = buildProblem(wayList, false);
        if (problem.l == 0) {
            System.out.println("no classified instances, aborting training process..");
            return;
        }
        Model model = trainSilently(problem, d);
        File bestModelFile = new File(this.modelDirectory.getAbsolutePath() + "/best_model");
        saveModel(model, bestModelFile, customModelFile(d, ".0"), "best model saved at: ", "custom model saved at: ");
    }

    /**
     * Trains the model that also includes class features on all of {@code wayList}. It is saved
     * both as {@code model_with_classes} and under a parameter-specific name ending in {@code .1}.
     *
     * @param d liblinear regularisation parameter
     */
    private void trainModelWithClasses(double d) {
        int size = wayList.size();
        System.out.println("trainList size: " + size);
        if (size == 0) {
            System.out.println("aborting training process with classes..");
            return;
        }
        Problem problem = buildProblem(wayList, true);
        if (problem.l == 0) {
            System.out.println("no classified instances, aborting training process with classes..");
            return;
        }
        Model model = trainSilently(problem, d);
        File classesModelFile = new File(this.modelDirectory.getAbsolutePath() + "/model_with_classes");
        saveModel(model, classesModelFile, customModelFile(d, ".1"), "model with classes saved at: ", "custom model with classes saved at: ");
    }

    /**
     * Classifies every way, extracts its features, and builds a liblinear problem. The problem
     * has one row per (way, class) pair; ways without classes contribute no rows.
     *
     * @param ways the ways to train on
     * @param withClassFeatures whether class features are added and geometry ids offset by
     *     {@link #CLASS_FEATURE_COUNT}
     */
    private Problem buildProblem(List<OSMWay> ways, boolean withClassFeatures) {
        int rowCount = 0;
        for (OSMWay way : ways) {
            new OSMClassification().calculateClasses(way, mappings, mapperWithIDs, indirectClasses, indirectClassesWithIDs);
            rowCount += way.getClassIDs().size();
        }

        double[] targets = new double[rowCount];
        // Inner arrays are assigned per row below, so none are pre-allocated.
        FeatureNode[][] rows = new FeatureNode[rowCount][];
        int row = 0;
        for (OSMWay way : ways) {
            FeatureNode[] features = extractFeatures(way, withClassFeatures);
            // A way with several classes is duplicated once per class, sharing the feature array.
            for (Integer classId : way.getClassIDs()) {
                rows[row] = features;
                targets[row] = classId.intValue();
                row++;
            }
        }

        Problem problem = new Problem();
        problem.l = rowCount;
        problem.n = withClassFeatures ? numberOfFeatures + CLASS_FEATURE_COUNT : numberOfFeatures;
        problem.x = rows;
        problem.y = targets;
        return problem;
    }

    /** Populates {@code way}'s feature list (optionally with class features) and returns it as an array. */
    private FeatureNode[] extractFeatures(OSMWay way, boolean withClassFeatures) {
        if (withClassFeatures) {
            new ClassFeatures().createClassFeatures(way, mappings, mapperWithIDs, indirectClasses, indirectClassesWithIDs);
        }
        GeometryFeatures geometryFeatures = new GeometryFeatures(withClassFeatures ? CLASS_FEATURE_COUNT : 1);
        geometryFeatures.createGeometryFeatures(way);
        int lastGeometryId = geometryFeatures.getLastID();
        new TextualFeatures(lastGeometryId, namesList, languageDetector).createTextualFeatures(way);
        List<FeatureNode> featureNodes = way.getFeatureNodeList();
        return featureNodes.toArray(new FeatureNode[featureNodes.size()]);
    }

    /** Trains a model with C = {@code c}, muting liblinear's console output and always restoring System.out. */
    private static Model trainSilently(Problem problem, double c) {
        Parameter parameter = new Parameter(SolverType.getById(SOLVER_ID), c, SOLVER_EPSILON);
        long start = System.nanoTime();
        System.out.println("training...");
        PrintStream originalOut = System.out;
        System.setOut(new PrintStream(new OutputStream() {
            @Override
            public void write(int b) {
                // Discard everything written while training.
            }
        }));
        Model model;
        try {
            model = Linear.train(problem, parameter);
        } finally {
            System.setOut(originalOut);
        }
        long elapsed = System.nanoTime() - start;
        System.out.println("training process completed in: " + TimeUnit.NANOSECONDS.toSeconds(elapsed) + " seconds.");
        return model;
    }

    /**
     * Returns the parameter-specific model file. Its name includes the input file name, C, the
     * feature-selection setting (top-K or max frequency), the user name, and {@code suffix}.
     */
    private File customModelFile(double c, String suffix) {
        String base = this.modelDirectory.getAbsolutePath() + "/" + this.inputFileName + "_model_c" + c;
        String selection = this.topKIsSelected ? "_topK" + this.topK : "_maxF" + this.frequency;
        return new File(base + selection + "user" + this.username + suffix);
    }

    /** Replaces {@code primary} and {@code custom} with {@code model}, logging (not throwing) I/O failures. */
    private void saveModel(Model model, File primary, File custom, String primaryMessage, String customMessage) {
        if (primary.exists()) {
            primary.delete();
        }
        if (custom.exists()) {
            custom.delete();
        }
        try {
            model.save(primary);
            model.save(custom);
            System.out.println(primaryMessage + primary);
            System.out.println(customMessage + custom);
        } catch (IOException e) {
            Logger.getLogger(getClass().getName()).log(Level.SEVERE, (String) null, (Throwable) e);
        }
    }

    /**
     * Returns the 0-based rank of the highest-scoring label that is one of {@code way}'s classes.
     * Labels are ranked by descending decision value, and only the first {@link #MAX_TOP_K}
     * labels (or fewer, if the model has fewer) are examined. Returns {@link Integer#MAX_VALUE}
     * if no examined label matches. Ties rank the lower label index first.
     */
    private static int firstCorrectRank(OSMWay way, int[] labels, double[] decisionValues) {
        int candidates = Math.min(MAX_TOP_K, decisionValues.length);
        boolean[] taken = new boolean[decisionValues.length];
        for (int rank = 0; rank < candidates; rank++) {
            int best = -1;
            for (int j = 0; j < decisionValues.length; j++) {
                if (!taken[j] && (best < 0 || decisionValues[j] > decisionValues[best])) {
                    best = j;
                }
            }
            taken[best] = true;
            if (way.getClassIDs().contains(Integer.valueOf(labels[best]))) {
                return rank;
            }
        }
        return Integer.MAX_VALUE;
    }

    /** Floating-point ratio {@code hits / total}, or 0 when {@code total} is 0. */
    private static double precision(int hits, int total) {
        return total == 0 ? 0.0d : (double) hits / total;
    }

    /** Marks progress complete once the background work finishes; failures are printed. */
    @Override
    protected void done() {
        try {
            get();
            firePropertyChange("progress", Integer.valueOf(getProgress()), 100);
            setProgress(100);
        } catch (InterruptedException e) {
            // Preserve the interrupt for code further up the EDT call stack.
            Thread.currentThread().interrupt();
            System.out.println("Exception: " + e);
        } catch (ExecutionException e) {
            System.out.println("Exception: " + e);
        }
    }
}
