package com.datumbox.framework.core.machinelearning.featureselection.scorebased;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.concurrency.ForkJoinStream;
import com.datumbox.framework.common.concurrency.StreamMethods;
import com.datumbox.framework.common.dataobjects.Dataframe;
import com.datumbox.framework.common.dataobjects.Record;
import com.datumbox.framework.common.dataobjects.TypeInference;
import com.datumbox.framework.common.persistentstorage.interfaces.BigMap;
import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractScoreBasedFeatureSelector;
import com.datumbox.framework.core.machinelearning.common.interfaces.Parallelizable;
import java.util.Iterator;
import java.util.Map;
import java.util.function.BiFunction;

/* loaded from: input_file:com/datumbox/framework/core/machinelearning/featureselection/scorebased/TFIDF.class */
/**
 * Score-based feature selector that ranks features by their maximum TF-IDF
 * score over the training records.
 *
 * <p>During {@link #_fit(Dataframe)} the selector:
 * <ol>
 *   <li>counts, for every feature, the number of records in which the feature
 *       has a strictly positive numeric value;</li>
 *   <li>turns each count into {@code log10(n / count)}, where {@code n} is the
 *       value returned by the model parameters' {@code getN()};</li>
 *   <li>for every record multiplies each positive feature value (or
 *       {@code 1.0} when binarization is enabled) by that factor and keeps the
 *       largest product seen per feature;</li>
 *   <li>optionally hands the resulting score map to
 *       {@code selectHighScoreFeatures} when a feature limit is configured.</li>
 * </ol>
 *
 * <p>{@link #filterFeatures(Dataframe)} then drops every column of a dataframe
 * that has no entry in the stored score map.
 */
public class TFIDF extends AbstractScoreBasedFeatureSelector<TFIDF.ModelParameters, TFIDF.TrainingParameters> implements Parallelizable {
    /** Name of the temporary big map holding the per-feature IDF factors. */
    private static final String TMP_IDF_MAP = "tmp_idf";

    /** Name of the temporary big map collecting the columns to be removed. */
    private static final String TMP_REMOVED_COLUMNS_MAP = "tmp_removedColumns";

    /** Whether the stream-based steps are asked to run in parallel. */
    private boolean parallelized;

    /** Executor used to run the stream-based steps of the algorithm. */
    protected final ForkJoinStream streamExecutor;

    /**
     * Model parameters learned by the selector: the maximum TF-IDF score
     * observed for every retained feature.
     */
    public static class ModelParameters extends AbstractTrainer.AbstractModelParameters {
        private static final long serialVersionUID = 1L;

        /** Feature -> highest TF-IDF score seen for it in the training data. */
        @BigMap(mapType = DatabaseConnector.MapType.HASHMAP, storageHint = DatabaseConnector.StorageHint.IN_MEMORY, concurrent = false)
        private Map<Object, Double> maxTFIDFfeatureScores;

        /**
         * Creates the model parameters on top of the given connector.
         *
         * @param databaseConnector the connector forwarded to the superclass
         */
        protected ModelParameters(DatabaseConnector databaseConnector) {
            super(databaseConnector);
        }

        /**
         * Returns the map with the maximum TF-IDF score of every feature.
         *
         * @return the feature-to-score map (the live instance, not a copy)
         */
        public Map<Object, Double> getMaxTFIDFfeatureScores() {
            return this.maxTFIDFfeatureScores;
        }

        /**
         * Replaces the map with the maximum TF-IDF score of every feature.
         *
         * @param map the new feature-to-score map
         */
        protected void setMaxTFIDFfeatureScores(Map<Object, Double> map) {
            this.maxTFIDFfeatureScores = map;
        }
    }

    /**
     * Training parameters of the selector.
     */
    public static class TrainingParameters extends AbstractTrainer.AbstractTrainingParameters {
        private static final long serialVersionUID = 1L;

        /** When {@code true}, every positive feature value is treated as {@code 1.0}. */
        private boolean binarized = false;

        /** Upper bound on the number of features to keep; {@code null} means no limit. */
        private Integer maxFeatures = null;

        /**
         * Tells whether positive feature values are replaced by {@code 1.0}
         * before scoring.
         *
         * @return {@code true} if binarization is enabled
         */
        public boolean isBinarized() {
            return this.binarized;
        }

        /**
         * Enables or disables binarization of the feature values.
         *
         * @param z {@code true} to treat every positive value as {@code 1.0}
         */
        public void setBinarized(boolean z) {
            this.binarized = z;
        }

        /**
         * Returns the configured feature limit.
         *
         * @return the maximum number of features, or {@code null} if unlimited
         */
        public Integer getMaxFeatures() {
            return this.maxFeatures;
        }

        /**
         * Sets the feature limit.
         *
         * @param num the maximum number of features, or {@code null} for no limit
         */
        public void setMaxFeatures(Integer num) {
            this.maxFeatures = num;
        }
    }

    /**
     * Creates the selector; parallel execution is enabled by default.
     *
     * @param str           the name forwarded to the superclass constructor
     * @param configuration the configuration forwarded to the superclass constructor
     */
    public TFIDF(String str, Configuration configuration) {
        super(str, configuration, ModelParameters.class, TrainingParameters.class);
        this.parallelized = true;
        this.streamExecutor = new ForkJoinStream(kb().getConf().getConcurrencyConfig());
    }

    /** {@inheritDoc} */
    @Override
    public boolean isParallelized() {
        return this.parallelized;
    }

    /** {@inheritDoc} */
    @Override
    public void setParallelized(boolean z) {
        this.parallelized = z;
    }

    /**
     * Estimates the maximum TF-IDF score of every feature found in the
     * training data and stores it in the model parameters.
     *
     * @param dataframe the training data
     */
    @Override
    protected void _fit(Dataframe dataframe) {
        ModelParameters modelParameters = (ModelParameters) kb().getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) kb().getTrainingParameters();
        final boolean binarized = trainingParameters.isBinarized();
        // Kept as a double so that the division below is a floating point one.
        // NOTE(review): getN() is presumably the number of training records - confirm.
        final double n = modelParameters.getN().intValue();
        final Map<Object, Double> maxScores = modelParameters.getMaxTFIDFfeatureScores();

        DatabaseConnector dbc = kb().getDbc();
        final Map<Object, Double> idfScores = dbc.getBigMap(TMP_IDF_MAP, DatabaseConnector.MapType.HASHMAP, DatabaseConnector.StorageHint.IN_MEMORY, true, true);
        try {
            // Pass 1: for every feature count the records where its value is positive.
            Iterator<Record> records = dataframe.iterator();
            while (records.hasNext()) {
                for (Map.Entry<Object, Object> entry : records.next().getX().entrySet()) {
                    Double value = TypeInference.toDouble(entry.getValue());
                    if (value != null && value > 0.0) {
                        Object feature = entry.getKey();
                        idfScores.put(feature, idfScores.getOrDefault(feature, 0.0) + 1.0);
                    }
                }
            }

            // Replace every count with log10(n / count). Only existing keys are
            // overwritten, so the key set of the map does not change while streaming.
            this.streamExecutor.forEach(StreamMethods.stream(idfScores.entrySet().stream(), isParallelized()), entry -> {
                idfScores.put(entry.getKey(), Math.log10(n / entry.getValue()));
            });

            // True when no score is stored for the feature yet or the stored one is lower.
            final BiFunction<Object, Double, Boolean> isGreaterThanMax = (feature, newScore) -> {
                Double currentMax = maxScores.get(feature);
                return currentMax == null || currentMax < newScore;
            };

            // Pass 2: compute the score of every positive feature value and keep the maximum.
            this.streamExecutor.forEach(StreamMethods.stream(dataframe.stream(), isParallelized()), record -> {
                for (Map.Entry<Object, Object> entry : record.getX().entrySet()) {
                    Double value = TypeInference.toDouble(entry.getValue());
                    if (value == null || value <= 0.0) {
                        continue;
                    }
                    Object feature = entry.getKey();
                    double tf = binarized ? 1.0 : value;
                    // Every feature with a positive value was counted in pass 1,
                    // so a factor is expected to exist for it.
                    double tfidf = tf * idfScores.get(feature);
                    // Cheap unsynchronized pre-check first; the decision is
                    // re-validated under the lock before the map is written.
                    if (tfidf > 0.0 && isGreaterThanMax.apply(feature, tfidf)) {
                        synchronized (maxScores) {
                            if (isGreaterThanMax.apply(feature, tfidf)) {
                                maxScores.put(feature, tfidf);
                            }
                        }
                    }
                }
            });
        } finally {
            // Release the temporary map even when one of the passes fails.
            dbc.dropBigMap(TMP_IDF_MAP, idfScores);
        }

        Integer maxFeatures = trainingParameters.getMaxFeatures();
        if (maxFeatures != null && maxFeatures < maxScores.size()) {
            // NOTE(review): presumably trims the map to the best maxFeatures entries - confirm in the superclass.
            AbstractScoreBasedFeatureSelector.selectHighScoreFeatures(maxScores, maxFeatures);
        }
    }

    /**
     * Drops from the dataframe every column that has no entry in the stored
     * map of maximum TF-IDF scores.
     *
     * @param dataframe the data to filter; it is modified in place
     */
    @Override
    protected void filterFeatures(Dataframe dataframe) {
        DatabaseConnector dbc = kb().getDbc();
        Map<Object, Double> maxScores = ((ModelParameters) kb().getModelParameters()).getMaxTFIDFfeatureScores();
        // Used as a set: only the keys are consumed, the values are placeholders.
        Map<Object, Boolean> removedColumns = dbc.getBigMap(TMP_REMOVED_COLUMNS_MAP, DatabaseConnector.MapType.HASHMAP, DatabaseConnector.StorageHint.IN_MEMORY, false, true);
        try {
            for (Object feature : dataframe.getXDataTypes().keySet()) {
                if (!maxScores.containsKey(feature)) {
                    removedColumns.put(feature, true);
                }
            }
            dataframe.dropXColumns(removedColumns.keySet());
        } finally {
            // Release the temporary map even when dropping the columns fails.
            dbc.dropBigMap(TMP_REMOVED_COLUMNS_MAP, removedColumns);
        }
    }
}
