package com.datumbox.framework.core.machinelearning.featureselection.categorical;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.concurrency.ForkJoinStream;
import com.datumbox.framework.common.concurrency.StreamMethods;
import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractCategoricalFeatureSelector;
import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractScoreBasedFeatureSelector;
import com.datumbox.framework.core.machinelearning.common.interfaces.Parallelizable;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:com/datumbox/framework/core/machinelearning/featureselection/categorical/MutualInformation.class */
/**
 * Categorical feature selector that scores each feature by its mutual information with the class.
 *
 * <p>For every feature, a 2x2 contingency table (feature present/absent x record in/not in class)
 * is built against each class in turn. The feature's score is the largest mutual information, in
 * bits, found over all classes. When {@code maxFeatures} is configured and smaller than the number
 * of scored features, the score map is trimmed via
 * {@link AbstractScoreBasedFeatureSelector#selectHighScoreFeatures}.
 */
public class MutualInformation extends AbstractCategoricalFeatureSelector<MutualInformation.ModelParameters, MutualInformation.TrainingParameters> implements Parallelizable {

    /** Natural log of 2, used to convert natural-log information into bits. */
    private static final double LN_2 = Math.log(2.0d);

    /** Whether feature scoring is dispatched through a parallel stream; defaults to {@code true}. */
    private boolean parallelized;

    /** Executor used to run the per-feature scoring stream. */
    protected final ForkJoinStream streamExecutor;

    /** Model parameters of the mutual-information selector. */
    public static class ModelParameters extends AbstractCategoricalFeatureSelector.AbstractModelParameters {
        private static final long serialVersionUID = 1;

        /**
         * @param databaseConnector connector passed through to the base model parameters
         */
        protected ModelParameters(DatabaseConnector databaseConnector) {
            super(databaseConnector);
        }
    }

    /** Training parameters of the mutual-information selector. */
    public static class TrainingParameters extends AbstractCategoricalFeatureSelector.AbstractTrainingParameters {
        private static final long serialVersionUID = 1;
    }

    /**
     * Creates the selector and its stream executor.
     *
     * @param str name passed through to the superclass (presumably the storage/knowledge-base name)
     * @param configuration framework configuration; its concurrency settings configure the executor
     */
    public MutualInformation(String str, Configuration configuration) {
        super(str, configuration, ModelParameters.class, TrainingParameters.class);
        this.parallelized = true;
        this.streamExecutor = new ForkJoinStream(kb().getConf().getConcurrencyConfig());
    }

    /** {@inheritDoc} */
    @Override
    public boolean isParallelized() {
        return this.parallelized;
    }

    /** {@inheritDoc} */
    @Override
    public void setParallelized(boolean z) {
        this.parallelized = z;
    }

    /**
     * Computes a mutual-information score for every feature and optionally keeps only the top ones.
     *
     * <p>Each feature's score is the maximum, over all classes, of the mutual information (in bits)
     * between feature presence and class membership. Scores are written into
     * {@code modelParameters.getFeatureScores()}. If {@code classCounts} is empty, every feature is
     * scored {@link Double#NEGATIVE_INFINITY}.
     *
     * @param classCounts number of records per class
     * @param featureClassCounts number of records per {@code [feature, class]} pair; a missing pair
     *     is treated as a count of zero
     * @param featureCounts number of records containing each feature
     */
    @Override
    protected void estimateFeatureScores(Map<Object, Integer> classCounts, Map<List<Object>, Integer> featureClassCounts, Map<Object, Double> featureCounts) {
        this.logger.debug("estimateFeatureScores()");
        ModelParameters modelParameters = (ModelParameters) kb().getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) kb().getTrainingParameters();

        Map<Object, Double> featureScores = modelParameters.getFeatureScores();
        // N: total record count used as the contingency-table grand total (presumably the
        // training-set size; TODO confirm getN() semantics).
        final double n = modelParameters.getN().intValue();

        this.streamExecutor.forEach(StreamMethods.stream(featureCounts.entrySet().stream(), isParallelized()), featureCount -> {
            Object feature = featureCount.getKey();
            double n1x = featureCount.getValue().doubleValue(); // N1.: records containing the feature

            double bestScore = Double.NEGATIVE_INFINITY; // larger score = more informative feature
            for (Map.Entry<Object, Integer> classCount : classCounts.entrySet()) {
                double nx1 = classCount.getValue().intValue(); // N.1: records in this class
                Integer joint = featureClassCounts.get(Arrays.asList(feature, classCount.getKey()));
                double n11 = joint != null ? joint.doubleValue() : 0.0d; // N11: feature AND class

                double score = mutualInformationBits(n, n1x, nx1, n11);
                if (score > bestScore) {
                    bestScore = score;
                }
            }

            // NOTE(review): when parallelized these puts may run concurrently. Each key is written
            // by exactly one task, but confirm the storage-backed map tolerates concurrent puts.
            featureScores.put(feature, Double.valueOf(bestScore));
        });

        Integer maxFeatures = trainingParameters.getMaxFeatures();
        if (maxFeatures != null && maxFeatures.intValue() < featureScores.size()) {
            AbstractScoreBasedFeatureSelector.selectHighScoreFeatures(featureScores, maxFeatures);
        }
    }

    /**
     * Mutual information, in bits, between feature presence and class membership.
     *
     * <p>The value is computed from the 2x2 contingency table implied by the marginals and the joint
     * count. Empty cells contribute nothing, since 0*log(0) is taken as 0; this also avoids NaN.
     *
     * @param n total number of records (N)
     * @param n1x records containing the feature (N1.)
     * @param nx1 records belonging to the class (N.1)
     * @param n11 records containing the feature and belonging to the class (N11)
     * @return I(feature; class) in bits
     */
    private static double mutualInformationBits(double n, double n1x, double nx1, double n11) {
        double n0x = n - n1x;   // N0.: records without the feature
        double nx0 = n - nx1;   // N.0: records outside the class
        double n01 = nx1 - n11; // feature absent, in class
        double n00 = n0x - n01; // feature absent, not in class
        double n10 = n1x - n11; // feature present, not in class

        // Accumulate in the fixed order N11, N01, N10, N00 so results are bit-for-bit stable.
        double score = miTerm(n, n11, n1x, nx1);
        score += miTerm(n, n01, n0x, nx1);
        score += miTerm(n, n10, n1x, nx0);
        score += miTerm(n, n00, n0x, nx0);
        return score;
    }

    /**
     * One cell's contribution to the mutual information.
     *
     * @param n total number of records
     * @param cell the cell count
     * @param rowTotal the cell's feature-marginal total
     * @param colTotal the cell's class-marginal total
     * @return (cell/N) * log2(N * cell / (rowTotal * colTotal)), or 0 when {@code cell} is not
     *     positive (including NaN)
     */
    private static double miTerm(double n, double cell, double rowTotal, double colTotal) {
        return cell > 0.0d ? ((cell / n) * Math.log((n / rowTotal) * (cell / colTotal))) / LN_2 : 0.0d;
    }
}
