package com.datumbox.framework.core.machinelearning.classification;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.concurrency.ForkJoinStream;
import com.datumbox.framework.common.concurrency.StreamMethods;
import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.dataobjects.Dataframe;
import com.datumbox.framework.common.dataobjects.Record;
import com.datumbox.framework.common.dataobjects.TypeInference;
import com.datumbox.framework.common.persistentstorage.interfaces.BigMap;
import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClassifier;
import com.datumbox.framework.core.machinelearning.common.interfaces.PredictParallelizable;
import com.datumbox.framework.core.machinelearning.common.interfaces.TrainParallelizable;
import com.datumbox.framework.core.machinelearning.common.validators.ClassifierValidator;
import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;

/* loaded from: input_file:com/datumbox/framework/core/machinelearning/classification/MaximumEntropy.class */
/**
 * Multinomial Maximum Entropy classifier.
 *
 * <p>The model keeps one weight (lambda) per {@code (feature, class)} pair, keyed by
 * {@code Arrays.asList(feature, class)}. The score of a class for a record is the sum of the lambdas of
 * the record's active features for that class. Class probabilities are obtained from those scores via
 * {@link Descriptives#normalizeExp}.
 *
 * <p>Training runs {@code totalIterations} rounds of iterative scaling (see {@code IIS}). Each round
 * adds {@code log(observed / model) / Cmax} to every weight. Here {@code Cmax} is the largest number of
 * features with a strictly positive value in any training record.
 *
 * <p>NOTE(review): a feature counts as "active" when its value is {@code > 0} for {@code Cmax} and the
 * observed expectations, but when it is {@code != 0} for scoring and the model expectations. Negative
 * feature values are therefore treated inconsistently; this is kept as-is to preserve behavior.
 */
public class MaximumEntropy extends AbstractClassifier<MaximumEntropy.ModelParameters, MaximumEntropy.TrainingParameters, MaximumEntropy.ValidationMetrics> implements PredictParallelizable, TrainParallelizable {

    /** Whether training and prediction run their stream operations in parallel. */
    private boolean parallelized;

    /** Executor used to run the (possibly parallel) stream operations. */
    protected final ForkJoinStream streamExecutor;

    /** Learned parameters: the known classes (inherited) and one lambda per (feature, class) pair. */
    public static class ModelParameters extends AbstractClassifier.AbstractModelParameters {
        private static final long serialVersionUID = 1;

        /** Weights keyed by {@code Arrays.asList(feature, class)}. */
        @BigMap(mapType = DatabaseConnector.MapType.HASHMAP, storageHint = DatabaseConnector.StorageHint.IN_MEMORY, concurrent = true)
        private Map<List<Object>, Double> lambdas;

        /**
         * @param databaseConnector connector passed to the superclass
         */
        protected ModelParameters(DatabaseConnector databaseConnector) {
            super(databaseConnector);
        }

        /**
         * @return the map of (feature, class) pairs to their weights
         */
        public Map<List<Object>, Double> getLambdas() {
            return this.lambdas;
        }

        /**
         * @param map the map of (feature, class) pairs to their weights
         */
        protected void setLambdas(Map<List<Object>, Double> map) {
            this.lambdas = map;
        }
    }

    /** Training configuration. */
    public static class TrainingParameters extends AbstractTrainer.AbstractTrainingParameters {
        private static final long serialVersionUID = 1;

        /** Number of iterative-scaling rounds; defaults to 100. */
        private int totalIterations = 100;

        /**
         * @return the number of iterative-scaling rounds
         */
        public int getTotalIterations() {
            return this.totalIterations;
        }

        /**
         * @param i the number of iterative-scaling rounds
         */
        public void setTotalIterations(int i) {
            this.totalIterations = i;
        }
    }

    /** Validation metrics; adds nothing to the generic classifier metrics. */
    public static class ValidationMetrics extends AbstractClassifier.AbstractValidationMetrics {
        private static final long serialVersionUID = 1;
    }

    /**
     * Creates a new classifier. Parallel execution is enabled by default.
     *
     * @param str passed through to the superclass (presumably the database/model name — TODO confirm)
     * @param configuration framework configuration; its concurrency settings drive {@link #streamExecutor}
     */
    public MaximumEntropy(String str, Configuration configuration) {
        super(str, configuration, ModelParameters.class, TrainingParameters.class, ValidationMetrics.class, new ClassifierValidator());
        this.parallelized = true;
        this.streamExecutor = new ForkJoinStream(kb().getConf().getConcurrencyConfig());
    }

    @Override
    public boolean isParallelized() {
        return this.parallelized;
    }

    @Override
    public void setParallelized(boolean z) {
        this.parallelized = z;
    }

    /**
     * Predicts every record of the dataframe using a temporary on-disk result buffer.
     * The buffer is always dropped, even if prediction fails.
     *
     * @param dataframe the records to predict
     */
    @Override
    public void _predictDataset(Dataframe dataframe) {
        DatabaseConnector dbc = kb().getDbc();
        Map<Integer, PredictParallelizable.Prediction> resultsBuffer = dbc.getBigMap("tmp_resultsBuffer", DatabaseConnector.MapType.HASHMAP, DatabaseConnector.StorageHint.IN_DISK, true, true);
        try {
            _predictDatasetParallel(dataframe, resultsBuffer, kb().getConf().getConcurrencyConfig());
        } finally {
            dbc.dropBigMap("tmp_resultsBuffer", resultsBuffer);
        }
    }

    /**
     * Scores every known class for a single record and returns the selected class together with the
     * normalized class probabilities.
     *
     * @param record the record to classify
     * @return the prediction (selected class and per-class probabilities)
     */
    @Override
    public PredictParallelizable.Prediction _predictRecord(Record record) {
        Set<Object> classes = ((ModelParameters) kb().getModelParameters()).getClasses();
        AssociativeArray classScores = new AssociativeArray();
        for (Object theClass : classes) {
            classScores.put(theClass, calculateClassScore(record.getX(), theClass));
        }
        // Select the winner from the raw scores before normalizeExp rewrites the array in place.
        Object selectedClass = getSelectedClassFromClassScores(classScores);
        Descriptives.normalizeExp(classScores);
        return new PredictParallelizable.Prediction(selectedClass, classScores);
    }

    /**
     * Trains the model.
     *
     * <p>The method first collects the class set and {@code Cmax}, the maximum number of strictly
     * positive features in any record. It then initializes every (feature, class) lambda and observed
     * expectation to zero. Next it accumulates the observed expectations (1/n per record in which a
     * feature is positive for the record's class). Finally it runs {@code IIS}.
     *
     * @param dataframe the training data
     */
    @Override
    protected void _fit(Dataframe dataframe) {
        ModelParameters modelParameters = (ModelParameters) kb().getModelParameters();
        int n = modelParameters.getN();
        Map<List<Object>, Double> lambdas = modelParameters.getLambdas();
        Set<Object> classes = modelParameters.getClasses();

        double maxActiveFeatures = 0.0;
        Iterator<Record> records = dataframe.iterator();
        while (records.hasNext()) {
            Record record = records.next();
            classes.add(record.getY());
            int activeFeatures = (int) record.getX().values().stream()
                    .filter(value -> value != null && TypeInference.toDouble(value) > 0.0)
                    .count();
            if (activeFeatures > maxActiveFeatures) {
                maxActiveFeatures = activeFeatures;
            }
        }

        DatabaseConnector dbc = kb().getDbc();
        Map<List<Object>, Double> observedExpectations = dbc.getBigMap("tmp_EpFj_observed", DatabaseConnector.MapType.HASHMAP, DatabaseConnector.StorageHint.IN_MEMORY, true, true);
        try {
            // Every (feature, class) pair starts with a zero weight and zero observed expectation.
            streamExecutor.forEach(StreamMethods.stream(dataframe.getXDataTypes().keySet().stream(), isParallelized()), feature -> {
                for (Object theClass : classes) {
                    List<Object> featureClass = Arrays.asList(feature, theClass);
                    observedExpectations.put(featureClass, 0.0);
                    lambdas.put(featureClass, 0.0);
                }
            });

            double increment = 1.0 / n;
            streamExecutor.forEach(StreamMethods.stream(dataframe.stream(), isParallelized()), record -> {
                Object y = record.getY();
                for (Map.Entry<Object, Object> entry : record.getX().entrySet()) {
                    Double value = TypeInference.toDouble(entry.getValue());
                    if (value != null && value > 0.0) {
                        List<Object> featureClass = Arrays.asList(entry.getKey(), y);
                        // Read-modify-write must be atomic across worker threads.
                        synchronized (observedExpectations) {
                            observedExpectations.put(featureClass, observedExpectations.get(featureClass) + increment);
                        }
                    }
                }
            });

            IIS(dataframe, observedExpectations, maxActiveFeatures);
        } finally {
            dbc.dropBigMap("tmp_EpFj_observed", observedExpectations);
        }
    }

    /**
     * Runs {@code totalIterations} rounds of iterative scaling on the lambdas.
     *
     * @param dataframe the training data
     * @param observedExpectations observed expectation per (feature, class) pair
     * @param cMax maximum number of positive features in a single training record
     */
    private void IIS(Dataframe dataframe, Map<List<Object>, Double> observedExpectations, double cMax) {
        int totalIterations = ((TrainingParameters) kb().getTrainingParameters()).getTotalIterations();
        Map<List<Object>, Double> lambdas = ((ModelParameters) kb().getModelParameters()).getLambdas();
        DatabaseConnector dbc = kb().getDbc();
        for (int iteration = 0; iteration < totalIterations; iteration++) {
            logger.debug("Iteration {}", iteration);
            Map<List<Object>, Double> modelExpectations = dbc.getBigMap("tmp_EpFj_model", DatabaseConnector.MapType.HASHMAP, DatabaseConnector.StorageHint.IN_MEMORY, false, true);
            try {
                computeModelExpectations(dataframe, modelExpectations);
                if (updateLambdas(modelExpectations, observedExpectations, cMax)) {
                    replaceInfiniteLambdas(lambdas);
                }
            } finally {
                dbc.dropBigMap("tmp_EpFj_model", modelExpectations);
            }
        }
    }

    /**
     * Accumulates the model expectation of every (feature, class) pair under the current lambdas.
     *
     * <p>For each record, the normalized probability of each class, divided by n, is added to every
     * pair formed by that class and a non-zero feature of the record.
     *
     * @param dataframe the training data
     * @param modelExpectations non-concurrent output map; it is guarded by its own monitor
     */
    private void computeModelExpectations(Dataframe dataframe, Map<List<Object>, Double> modelExpectations) {
        ModelParameters modelParameters = (ModelParameters) kb().getModelParameters();
        Set<Object> classes = modelParameters.getClasses();
        int n = modelParameters.getN();
        streamExecutor.forEach(StreamMethods.stream(dataframe.stream(), isParallelized()), record -> {
            AssociativeArray x = record.getX();
            AssociativeArray classProbabilities = new AssociativeArray();
            for (Object theClass : classes) {
                classProbabilities.put(theClass, calculateClassScore(x, theClass));
            }
            Descriptives.normalizeExp(classProbabilities);
            for (Map.Entry<Object, Object> classEntry : classProbabilities.entrySet()) {
                Object theClass = classEntry.getKey();
                double probabilityFraction = TypeInference.toDouble(classEntry.getValue()) / n;
                synchronized (modelExpectations) {
                    for (Map.Entry<Object, Object> featureEntry : x.entrySet()) {
                        Double value = TypeInference.toDouble(featureEntry.getValue());
                        if (value != null && value != 0.0) {
                            List<Object> featureClass = Arrays.asList(featureEntry.getKey(), theClass);
                            modelExpectations.merge(featureClass, probabilityFraction, Double::sum);
                        }
                    }
                }
            }
        });
    }

    /**
     * Applies one scaling step to every pair that has a model expectation.
     *
     * <p>Pairs whose observed and model expectations agree within 1e-8 are skipped. A zero observed
     * expectation yields {@code -inf}, and a zero model expectation yields {@code +inf}. Otherwise the
     * lambda is increased by {@code log(observed / model) / cMax}.
     *
     * @param modelExpectations model expectation per pair
     * @param observedExpectations observed expectation per pair
     * @param cMax the scaling constant
     * @return true if any lambda became infinite, including an update that overflowed
     */
    private boolean updateLambdas(Map<List<Object>, Double> modelExpectations, Map<List<Object>, Double> observedExpectations, double cMax) {
        Map<List<Object>, Double> lambdas = ((ModelParameters) kb().getModelParameters()).getLambdas();
        AtomicBoolean infinityFound = new AtomicBoolean(false);
        streamExecutor.forEach(StreamMethods.stream(modelExpectations.entrySet().stream(), isParallelized()), entry -> {
            List<Object> featureClass = entry.getKey();
            double observed = observedExpectations.get(featureClass);
            double model = entry.getValue();
            if (Math.abs(observed - model) <= 1.0E-8d) {
                return;
            }
            double newLambda;
            if (observed == 0.0) {
                newLambda = Double.NEGATIVE_INFINITY;
            } else if (model == 0.0) {
                newLambda = Double.POSITIVE_INFINITY;
            } else {
                newLambda = lambdas.get(featureClass) + Math.log(observed / model) / cMax;
            }
            if (Double.isInfinite(newLambda)) {
                // Flag every infinity (not only the two explicit cases) so that it gets clamped.
                infinityFound.set(true);
            }
            lambdas.put(featureClass, newLambda);
        });
        return infinityFound.get();
    }

    /**
     * Replaces -inf/+inf lambdas with the smallest/largest finite lambda, so that class scores stay
     * finite. Falls back to 0.0 when no finite lambda exists, instead of throwing.
     *
     * @param lambdas the weights to clamp in place
     */
    private void replaceInfiniteLambdas(Map<List<Object>, Double> lambdas) {
        Double minFinite = streamExecutor.min(StreamMethods.stream(lambdas.values().stream(), isParallelized()).filter(Double::isFinite), Double::compare).orElse(0.0);
        Double maxFinite = streamExecutor.max(StreamMethods.stream(lambdas.values().stream(), isParallelized()).filter(Double::isFinite), Double::compare).orElse(0.0);
        streamExecutor.forEach(StreamMethods.stream(lambdas.entrySet().stream(), isParallelized()), entry -> {
            double value = entry.getValue();
            if (Double.isInfinite(value)) {
                lambdas.put(entry.getKey(), value < 0.0 ? minFinite : maxFinite);
            }
        });
    }

    /**
     * Computes the unnormalized score of a class: the sum of the lambdas of the record's non-zero
     * features for that class. Pairs without a lambda contribute nothing.
     *
     * @param associativeArray the record's features
     * @param obj the class to score
     * @return the class score
     */
    private Double calculateClassScore(AssociativeArray associativeArray, Object obj) {
        Map<List<Object>, Double> lambdas = ((ModelParameters) kb().getModelParameters()).getLambdas();
        double score = 0.0;
        for (Map.Entry<Object, Object> entry : associativeArray.entrySet()) {
            Double value = TypeInference.toDouble(entry.getValue());
            if (value == null || value == 0.0) {
                continue;
            }
            Double lambda = lambdas.get(Arrays.asList(entry.getKey(), obj));
            if (lambda != null) {
                score += lambda;
            }
        }
        return score;
    }
}
