package com.datumbox.framework.core.machinelearning.common.abstracts.algorithms;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.dataobjects.Dataframe;
import com.datumbox.framework.common.dataobjects.Record;
import com.datumbox.framework.common.persistentstorage.interfaces.BigMap;
import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.framework.common.utilities.MapMethods;
import com.datumbox.framework.common.utilities.PHPMethods;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.algorithms.AbstractDPMM.AbstractCluster;
import com.datumbox.framework.core.machinelearning.common.abstracts.algorithms.AbstractDPMM.AbstractModelParameters;
import com.datumbox.framework.core.machinelearning.common.abstracts.algorithms.AbstractDPMM.AbstractTrainingParameters;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClusterer;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractClusterer.AbstractValidationMetrics;
import com.datumbox.framework.core.machinelearning.common.interfaces.PredictParallelizable;
import com.datumbox.framework.core.machinelearning.common.validators.ClustererValidator;
import com.datumbox.framework.core.statistics.descriptivestatistics.Descriptives;
import com.datumbox.framework.core.statistics.sampling.SimpleRandomSampling;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.math3.random.EmpiricalDistribution;

/* loaded from: input_file:com/datumbox/framework/core/machinelearning/common/abstracts/algorithms/AbstractDPMM.class */
/**
 * Base class of the DPMM clusterers.
 *
 * <p>Training collects the gold-standard classes and a feature-to-index map from the training data.
 * It then assigns every record to a cluster by collapsed Gibbs sampling. The concentration
 * parameter {@code alpha} controls how likely a record is to open a brand-new cluster. Prediction
 * scores a record against every trained cluster and picks the best one.
 *
 * @param <CL> the cluster type
 * @param <MP> the model parameters type
 * @param <TP> the training parameters type
 * @param <VM> the validation metrics type
 */
public abstract class AbstractDPMM<CL extends AbstractCluster, MP extends AbstractModelParameters, TP extends AbstractTrainingParameters, VM extends AbstractClusterer.AbstractValidationMetrics> extends AbstractClusterer<CL, MP, TP, VM> implements PredictParallelizable {

    /** Name of the temporary map that buffers predictions in {@link #_predictDataset(Dataframe)}. */
    private static final String RESULTS_BUFFER_NAME = "tmp_resultsBuffer";

    /** Name of the temporary map that holds the clusters while Gibbs sampling runs. */
    private static final String TEMP_CLUSTER_MAP_NAME = "tmp_tempClusterMap";

    private boolean parallelized;

    /**
     * Base class for the clusters of a DPMM.
     *
     * <p>Subclasses keep the per-cluster state that {@link #add(Record)} and {@link #remove(Record)}
     * update. They also score records through {@link #posteriorLogPdf(Record)}.
     */
    public static abstract class AbstractCluster extends AbstractClusterer.AbstractCluster {
        /**
         * Feature-to-index mapping shared with the model parameters.
         *
         * <p>The field is transient, so it is not persisted with the cluster. AbstractDPMM re-attaches
         * it from the model parameters whenever it finds it {@code null}.
         */
        protected transient Map<Object, Integer> featureIds;

        /**
         * Creates an empty cluster.
         *
         * @param clusterId the id of the new cluster
         */
        public AbstractCluster(Integer clusterId) {
            super(clusterId);
        }

        /**
         * Copy-constructs a cluster under a new id.
         *
         * <p>The feature mapping is shared with the source cluster, not copied.
         *
         * @param clusterId the id of the new cluster
         * @param copy the cluster whose state is copied
         */
        public AbstractCluster(Integer clusterId, AbstractCluster copy) {
            super(clusterId, copy);
            this.featureIds = copy.featureIds;
        }

        /** @return the feature-to-index mapping, or {@code null} if it has not been attached yet */
        protected abstract Map<Object, Integer> getFeatureIds();

        /** @param featureIds the feature-to-index mapping to attach to this cluster */
        protected abstract void setFeatureIds(Map<Object, Integer> featureIds);

        /** Recomputes the cluster's parameters from its current contents. */
        protected abstract void updateClusterParameters();

        /**
         * @param r the record to score
         * @return the log of the cluster's posterior density evaluated at {@code r}
         */
        protected abstract double posteriorLogPdf(Record r);

        /** {@inheritDoc} */
        @Override
        protected abstract void add(Record r);

        /** {@inheritDoc} */
        @Override
        protected abstract void remove(Record r);

        /**
         * @param clusterId the id of the copy
         * @return a copy of this cluster carrying {@code clusterId}
         */
        protected abstract AbstractCluster copy2new(Integer clusterId);
    }

    /**
     * Model parameters shared by all DPMM clusterers.
     *
     * @param <CL> the cluster type
     */
    public static abstract class AbstractModelParameters<CL extends AbstractCluster> extends AbstractClusterer.AbstractModelParameters<CL> {
        /** Number of Gibbs sweeps performed during the last training run. */
        private int totalIterations;

        /** Index assigned to every feature seen during training, in order of first appearance. */
        @BigMap(mapType = DatabaseConnector.MapType.HASHMAP, storageHint = DatabaseConnector.StorageHint.IN_MEMORY, concurrent = false)
        private Map<Object, Integer> featureIds;

        /** @param dbc the connector used to back the parameters' storage */
        public AbstractModelParameters(DatabaseConnector dbc) {
            super(dbc);
        }

        /** @return the number of Gibbs sweeps performed during training */
        public int getTotalIterations() {
            return totalIterations;
        }

        /** @param totalIterations the number of Gibbs sweeps performed during training */
        protected void setTotalIterations(int totalIterations) {
            this.totalIterations = totalIterations;
        }

        /** @return the feature-to-index mapping */
        public Map<Object, Integer> getFeatureIds() {
            return featureIds;
        }

        /** @param featureIds the feature-to-index mapping */
        protected void setFeatureIds(Map<Object, Integer> featureIds) {
            this.featureIds = featureIds;
        }
    }

    /** Training parameters shared by all DPMM clusterers. */
    public static abstract class AbstractTrainingParameters extends AbstractTrainer.AbstractTrainingParameters {
        /** Default cap on the number of Gibbs sweeps. */
        private static final int DEFAULT_MAX_ITERATIONS = 1000;

        /** Controls the score given to opening a new cluster during sampling. */
        private double alpha;

        private int maxIterations = DEFAULT_MAX_ITERATIONS;

        private Initialization initializationMethod = Initialization.ONE_CLUSTER_PER_RECORD;

        /** Strategy used to assign records to clusters before the first Gibbs sweep. */
        public enum Initialization {
            /** Every record starts in a cluster of its own. */
            ONE_CLUSTER_PER_RECORD,
            /** Records are assigned uniformly at random to a small pool of new clusters. */
            RANDOM_ASSIGNMENT
        }

        /** @return the concentration parameter alpha */
        public double getAlpha() {
            return alpha;
        }

        /** @param alpha the concentration parameter */
        public void setAlpha(double alpha) {
            this.alpha = alpha;
        }

        /** @return the maximum number of Gibbs sweeps */
        public int getMaxIterations() {
            return maxIterations;
        }

        /** @param maxIterations the maximum number of Gibbs sweeps */
        public void setMaxIterations(int maxIterations) {
            this.maxIterations = maxIterations;
        }

        /** @return the initialisation strategy */
        public Initialization getInitializationMethod() {
            return initializationMethod;
        }

        /** @param initializationMethod the initialisation strategy */
        public void setInitializationMethod(Initialization initializationMethod) {
            this.initializationMethod = initializationMethod;
        }
    }

    /**
     * Creates the clusterer.
     *
     * <p>Prediction is parallelised by default. The clusterer is validated with a
     * {@link ClustererValidator}.
     *
     * @param dbName name forwarded to the parent modeler (presumably the storage name; confirm)
     * @param conf the framework configuration
     * @param mpClass the model parameters class
     * @param tpClass the training parameters class
     * @param vmClass the validation metrics class
     */
    public AbstractDPMM(String dbName, Configuration conf, Class<MP> mpClass, Class<TP> tpClass, Class<VM> vmClass) {
        super(dbName, conf, mpClass, tpClass, vmClass, new ClustererValidator());
        this.parallelized = true;
    }

    /** {@inheritDoc} */
    @Override
    public boolean isParallelized() {
        return parallelized;
    }

    /** {@inheritDoc} */
    @Override
    public void setParallelized(boolean parallelized) {
        this.parallelized = parallelized;
    }

    /**
     * Predicts every record of {@code newData} in parallel.
     *
     * <p>Results are buffered in a temporary disk-hinted map. The map is dropped even if prediction
     * fails.
     *
     * @param newData the records to predict; updated in place by {@code _predictDatasetParallel}
     */
    @Override
    public void _predictDataset(Dataframe newData) {
        DatabaseConnector dbc = kb().getDbc();
        Map<Integer, PredictParallelizable.Prediction> resultsBuffer = dbc.getBigMap(RESULTS_BUFFER_NAME, DatabaseConnector.MapType.HASHMAP, DatabaseConnector.StorageHint.IN_DISK, true, true);
        try {
            _predictDatasetParallel(newData, resultsBuffer, kb().getConf().getConcurrencyConfig());
        } finally {
            dbc.dropBigMap(RESULTS_BUFFER_NAME, resultsBuffer);
        }
    }

    /**
     * Scores {@code r} against every trained cluster and picks the best one.
     *
     * <p>Each cluster is scored by its {@code posteriorLogPdf}. The scores are then normalised with
     * {@link Descriptives#normalizeExp}.
     *
     * @param r the record to predict
     * @return the highest-scoring cluster id together with the normalised scores of all clusters
     */
    @Override
    public PredictParallelizable.Prediction _predictRecord(Record r) {
        Map<Integer, CL> clusterMap = ((AbstractModelParameters) kb().getModelParameters()).getClusterMap();
        AssociativeArray clusterScores = new AssociativeArray();
        for (Integer clusterId : clusterMap.keySet()) {
            CL cluster = getFromClusterMap(clusterId, clusterMap);
            clusterScores.put(clusterId, cluster.posteriorLogPdf(r));
        }
        Descriptives.normalizeExp(clusterScores);
        return new PredictParallelizable.Prediction(getSelectedClusterFromScores(clusterScores), clusterScores);
    }

    /**
     * Trains the model.
     *
     * <p>Training proceeds in four steps:
     * <ol>
     *   <li>Collect the gold-standard classes.</li>
     *   <li>Index every feature seen in the training data.</li>
     *   <li>Run collapsed Gibbs sampling.</li>
     *   <li>Clear the clusters.</li>
     * </ol>
     *
     * @param trainingData the records to cluster; their predicted-cluster field is overwritten
     */
    @Override
    protected void _fit(Dataframe trainingData) {
        AbstractModelParameters modelParameters = (AbstractModelParameters) kb().getModelParameters();
        Set<Object> goldStandardClasses = modelParameters.getGoldStandardClasses();
        Map<Object, Integer> featureIds = modelParameters.getFeatureIds();

        // Continue after any ids already present so an existing id is never handed out twice.
        int nextFeatureId = featureIds.size();
        for (Record r : trainingData) {
            Object theClass = r.getY();
            if (theClass != null) {
                goldStandardClasses.add(theClass);
            }
            for (Map.Entry<Object, Object> entry : r.getX().entrySet()) {
                if (featureIds.putIfAbsent(entry.getKey(), nextFeatureId) == null) {
                    nextFeatureId++;
                }
            }
        }

        modelParameters.setTotalIterations(collapsedGibbsSampling(trainingData));
        clearClusters();
    }

    /**
     * Looks up a cluster by id.
     *
     * <p>If the cluster's transient feature mapping is missing, it is re-attached from the model
     * parameters first.
     *
     * @param clusterId the id to look up; assumed to be present (a missing id causes an NPE)
     * @param clusterMap the map to read from
     * @return the cluster, with its feature mapping attached
     */
    private CL getFromClusterMap(int clusterId, Map<Integer, CL> clusterMap) {
        CL cluster = clusterMap.get(clusterId);
        if (cluster.getFeatureIds() == null) {
            cluster.setFeatureIds(((AbstractModelParameters) kb().getModelParameters()).getFeatureIds());
        }
        return cluster;
    }

    /**
     * Clusters {@code dataset} by collapsed Gibbs sampling.
     *
     * <p>The resulting clusters are appended to the model's cluster map under fresh ids. The
     * temporary cluster map is dropped even if sampling fails.
     *
     * @param dataset the records to cluster; each record's predicted-cluster field is overwritten
     * @return the number of Gibbs sweeps performed
     */
    private int collapsedGibbsSampling(Dataframe dataset) {
        AbstractModelParameters modelParameters = (AbstractModelParameters) kb().getModelParameters();
        AbstractTrainingParameters trainingParameters = (AbstractTrainingParameters) kb().getTrainingParameters();
        double alpha = trainingParameters.getAlpha();

        Map<Integer, CL> tempClusterMap = kb().getDbc().getBigMap(TEMP_CLUSTER_MAP_NAME, DatabaseConnector.MapType.HASHMAP, DatabaseConnector.StorageHint.IN_CACHE, false, true);
        try {
            tempClusterMap.putAll(modelParameters.getClusterMap());

            // Ids of newly created clusters continue after those already in the map.
            int nextClusterId = tempClusterMap.size();
            if (trainingParameters.getInitializationMethod() == AbstractTrainingParameters.Initialization.ONE_CLUSTER_PER_RECORD) {
                nextClusterId = assignOneClusterPerRecord(dataset, tempClusterMap, nextClusterId);
            } else {
                nextClusterId = assignRandomClusters(dataset, tempClusterMap, nextClusterId, alpha);
            }

            // NOTE(review): n is the number of clusters after initialisation, whereas the standard
            // formula uses the number of records. Every score below shares the denominator
            // (alpha + n - 1), so it should cancel in normalizeExp. Confirm before relying on it.
            int n = tempClusterMap.size();
            int maxIterations = trainingParameters.getMaxIterations();
            boolean noChangeMade = false;
            int iteration = 0;
            while (iteration < maxIterations && !noChangeMade) {
                logger.debug("Iteration {}", iteration);
                noChangeMade = true;
                for (Map.Entry<Integer, Record> entry : dataset.entries()) {
                    Integer rowId = entry.getKey();
                    Record r = entry.getValue();
                    Integer currentClusterId = (Integer) r.getYPredicted();

                    // Take the record out of its cluster, dropping the cluster once it is empty.
                    // Modified clusters are always put back, because the map may be storage-backed.
                    CL currentCluster = getFromClusterMap(currentClusterId, tempClusterMap);
                    currentCluster.remove(r);
                    if (currentCluster.size() == 0) {
                        tempClusterMap.remove(currentClusterId);
                    } else {
                        tempClusterMap.put(currentClusterId, currentCluster);
                    }

                    // Log-scores for every existing cluster, plus one for opening a new cluster.
                    AssociativeArray clusterScores = clusterProbabilities(r, n, tempClusterMap);
                    CL newCluster = createNewCluster(nextClusterId);
                    clusterScores.put(nextClusterId, newCluster.posteriorLogPdf(r) + Math.log(alpha / (alpha + n - 1.0)));
                    Descriptives.normalizeExp(clusterScores);

                    Integer sampledClusterId = (Integer) SimpleRandomSampling.weightedSampling(clusterScores, 1, true).iterator().next();
                    if (Objects.equals(sampledClusterId, nextClusterId)) {
                        // NOTE(review): this also counts as a change when the record was alone in
                        // its old cluster, which can delay convergence.
                        Record relabelled = new Record(r.getX(), r.getY(), nextClusterId, r.getYPredictedProbabilities());
                        dataset._unsafe_set(rowId, relabelled);
                        newCluster.add(relabelled);
                        tempClusterMap.put(nextClusterId, newCluster);
                        noChangeMade = false;
                        nextClusterId++;
                    } else {
                        if (!Objects.equals(currentClusterId, sampledClusterId)) {
                            r = new Record(r.getX(), r.getY(), sampledClusterId, r.getYPredictedProbabilities());
                            dataset._unsafe_set(rowId, r);
                            noChangeMade = false;
                        }
                        CL targetCluster = getFromClusterMap(sampledClusterId, tempClusterMap);
                        targetCluster.add(r);
                        tempClusterMap.put(sampledClusterId, targetCluster);
                    }
                }
                iteration++;
            }

            // Copy the surviving clusters into the model under consecutive fresh ids.
            // NOTE(review): with RANDOM_ASSIGNMENT, clusters that never received a record are
            // copied as well.
            Map<Integer, CL> clusterMap = modelParameters.getClusterMap();
            int finalClusterId = clusterMap.size();
            for (CL cluster : tempClusterMap.values()) {
                // copy2new is declared to return AbstractCluster; subclasses return their own type.
                @SuppressWarnings("unchecked")
                CL copy = (CL) cluster.copy2new(finalClusterId);
                clusterMap.put(finalClusterId, copy);
                finalClusterId++;
            }
            return iteration;
        } finally {
            kb().getDbc().dropBigMap(TEMP_CLUSTER_MAP_NAME, tempClusterMap);
        }
    }

    /**
     * Places every record in a freshly created cluster of its own.
     *
     * <p>Each record is also relabelled in {@code dataset}.
     *
     * @param dataset the records to assign
     * @param tempClusterMap receives the new clusters
     * @param nextClusterId the first id to hand out
     * @return the next unused cluster id
     */
    private int assignOneClusterPerRecord(Dataframe dataset, Map<Integer, CL> tempClusterMap, int nextClusterId) {
        for (Map.Entry<Integer, Record> entry : dataset.entries()) {
            Record r = entry.getValue();
            CL cluster = createNewCluster(nextClusterId);
            cluster.add(r);
            tempClusterMap.put(nextClusterId, cluster);
            dataset._unsafe_set(entry.getKey(), new Record(r.getX(), r.getY(), nextClusterId, r.getYPredictedProbabilities()));
            nextClusterId++;
        }
        return nextClusterId;
    }

    /**
     * Creates a pool of new clusters and assigns every record to a uniformly random cluster id.
     *
     * <p>The pool holds {@code max(alpha, 1) * ln(dataset size)} clusters, and at least one. The id
     * is drawn from {@code [0, last created id]}.
     *
     * @param dataset the records to assign; each record is relabelled in place
     * @param tempClusterMap receives the new clusters
     * @param nextClusterId the first id to hand out
     * @param alpha the concentration parameter
     * @return the next unused cluster id
     */
    private int assignRandomClusters(Dataframe dataset, Map<Integer, CL> tempClusterMap, int nextClusterId, double alpha) {
        int clustersToCreate = (int) (Math.max(alpha, 1.0) * Math.log(dataset.size()));
        if (clustersToCreate <= 0) {
            clustersToCreate = 1;
        }
        for (int k = 0; k < clustersToCreate; k++) {
            tempClusterMap.put(nextClusterId, createNewCluster(nextClusterId));
            nextClusterId++;
        }

        int totalClusters = nextClusterId;
        for (Map.Entry<Integer, Record> entry : dataset.entries()) {
            Record original = entry.getValue();
            Integer clusterId = PHPMethods.mt_rand(0, totalClusters - 1);
            Record r = new Record(original.getX(), original.getY(), clusterId, original.getYPredictedProbabilities());
            dataset._unsafe_set(entry.getKey(), r);
            CL cluster = getFromClusterMap(clusterId, tempClusterMap);
            cluster.add(r);
            tempClusterMap.put(clusterId, cluster);
        }
        return nextClusterId;
    }

    /**
     * Computes, for every cluster in {@code clusterMap}, the un-normalised log-score of assigning
     * {@code r} to it.
     *
     * <p>The score is {@code posteriorLogPdf(r) + log(size / (alpha + n - 1))}.
     *
     * @param r the record being reassigned
     * @param n the count used in the mixing-weight denominator
     * @param clusterMap the candidate clusters
     * @return the log-scores keyed by cluster id
     */
    private AssociativeArray clusterProbabilities(Record r, int n, Map<Integer, CL> clusterMap) {
        Map<Object, Object> logScores = new ConcurrentHashMap<>();
        double alpha = ((AbstractTrainingParameters) kb().getTrainingParameters()).getAlpha();
        for (Integer clusterId : clusterMap.keySet()) {
            CL cluster = getFromClusterMap(clusterId, clusterMap);
            double marginalLogLikelihood = cluster.posteriorLogPdf(r);
            double mixingWeight = cluster.size() / (alpha + n - 1.0);
            logScores.put(clusterId, marginalLogLikelihood + Math.log(mixingWeight));
        }
        return new AssociativeArray(logScores);
    }

    /**
     * @param clusterScores scores keyed by cluster id
     * @return the key with the largest score, as chosen by {@link MapMethods#selectMaxKeyValue}
     */
    private Object getSelectedClusterFromScores(AssociativeArray clusterScores) {
        return MapMethods.selectMaxKeyValue(clusterScores).getKey();
    }

    /**
     * @param clusterId the id of the new cluster
     * @return a new, empty cluster of the concrete type
     */
    protected abstract CL createNewCluster(Integer clusterId);
}
