package com.datumbox.framework.core.machinelearning.recommendersystem;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.dataobjects.Dataframe;
import com.datumbox.framework.common.dataobjects.FlatDataList;
import com.datumbox.framework.common.dataobjects.Record;
import com.datumbox.framework.common.dataobjects.TransposeDataList;
import com.datumbox.framework.common.dataobjects.TypeInference;
import com.datumbox.framework.common.persistentstorage.interfaces.BigMap;
import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.framework.common.utilities.MapMethods;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractModeler;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractRecommender;
import com.datumbox.framework.core.machinelearning.common.validators.CollaborativeFilteringValidator;
import com.datumbox.framework.core.mathematics.distances.Distance;
import com.datumbox.framework.core.statistics.parametrics.relatedsamples.PearsonCorrelation;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/**
 * Similarity-based collaborative filtering recommender.
 *
 * <p>Training computes a similarity score for every ordered pair of training-record labels
 * ({@code Record.getY()}) from the records' feature vectors ({@code Record.getX()}). It stores
 * the score in the model's similarity map under the key {@code [y1, y2]}.
 *
 * <p>Prediction looks up each key of a record's {@code getX()} as the first element of a
 * similarity pair. Prediction-time feature keys are therefore expected to be drawn from the
 * training labels. Every reachable second element gets a score equal to the similarity-weighted
 * average of the record's values.
 */
public class CollaborativeFiltering extends AbstractRecommender<ModelParameters, TrainingParameters, ValidationMetrics> {

    /**
     * Learned state: the pairwise similarity map produced by training.
     */
    public static class ModelParameters extends AbstractTrainer.AbstractModelParameters {
        private static final long serialVersionUID = 1;

        /**
         * Similarity per ordered label pair {@code [a, b]}. Training writes both {@code [a, b]}
         * and {@code [b, a]} with the same value.
         * Backing storage is configured through {@link BigMap}; presumably the framework injects
         * the instance (it is never assigned in this class other than via the setter).
         */
        @BigMap(mapType = DatabaseConnector.MapType.HASHMAP, storageHint = DatabaseConnector.StorageHint.IN_CACHE, concurrent = false)
        private Map<List<Object>, Double> similarities;

        /**
         * @param databaseConnector connector handed to the base model-parameters class
         */
        protected ModelParameters(DatabaseConnector databaseConnector) {
            super(databaseConnector);
        }

        /**
         * @return the pairwise similarity map, keyed by two-element label lists
         */
        public Map<List<Object>, Double> getSimilarities() {
            return this.similarities;
        }

        /**
         * @param map the pairwise similarity map to use
         */
        protected void setSimilarities(Map<List<Object>, Double> map) {
            this.similarities = map;
        }
    }

    /**
     * Training configuration: which similarity measure to use. Defaults to {@code EUCLIDIAN}.
     */
    public static class TrainingParameters extends AbstractTrainer.AbstractTrainingParameters {
        private static final long serialVersionUID = 1;
        private SimilarityMeasure similarityMethod = SimilarityMeasure.EUCLIDIAN;

        /**
         * Supported similarity measures between two feature vectors.
         * <ul>
         * <li>{@code EUCLIDIAN}: {@code 1 / (1 + euclideanDistance)}</li>
         * <li>{@code MANHATTAN}: {@code 1 / (1 + manhattanDistance)}</li>
         * <li>{@code PEARSONS_CORRELATION}: {@code (r + 1) / 2}, mapping r from [-1, 1] to [0, 1]</li>
         * </ul>
         */
        public enum SimilarityMeasure {
            EUCLIDIAN,
            MANHATTAN,
            PEARSONS_CORRELATION
        }

        /**
         * @return the configured similarity measure
         */
        public SimilarityMeasure getSimilarityMethod() {
            return this.similarityMethod;
        }

        /**
         * @param similarityMeasure the similarity measure to use during training
         */
        public void setSimilarityMethod(SimilarityMeasure similarityMeasure) {
            this.similarityMethod = similarityMeasure;
        }
    }

    /**
     * Validation output: the root-mean-square error between known and predicted values.
     */
    public static class ValidationMetrics extends AbstractModeler.AbstractValidationMetrics {
        private static final long serialVersionUID = 1;
        // Field name kept as-is: this class is Serializable and renaming would break stored objects.
        private double RMSE = 0.0d;

        /**
         * @return the root-mean-square error (NaN if nothing could be scored)
         */
        public double getRMSE() {
            return this.RMSE;
        }

        /**
         * @param d the root-mean-square error
         */
        public void setRMSE(double d) {
            this.RMSE = d;
        }
    }

    /**
     * @param str           name passed to the base recommender (presumably the storage/model name)
     * @param configuration framework configuration
     */
    public CollaborativeFiltering(String str, Configuration configuration) {
        super(str, configuration, ModelParameters.class, TrainingParameters.class, ValidationMetrics.class, new CollaborativeFilteringValidator());
    }

    /**
     * Scores every record, excluding keys the record already contains.
     *
     * @param dataframe records to score in place
     */
    @Override
    public void _predictDataset(Dataframe dataframe) {
        _predictDataset(dataframe, false);
    }

    /**
     * Replaces each record in {@code dataframe} with a copy carrying its predictions.
     *
     * <p>For every candidate key c reachable from the record's keys k via a similarity pair
     * {@code [k, c]}, the score is {@code sum(sim(k, c) * x[k]) / sum(sim(k, c))}.
     *
     * <p>The highest-scoring candidate becomes {@code yPredicted}, and all scores (sorted
     * descending) become {@code yPredictedProbabilities}. When no candidate can be scored,
     * {@code yPredicted} is {@code null} and the probability map is empty.
     *
     * @param dataframe         records to score in place
     * @param includeRatedItems if {@code true}, keys already present in the record are scored too
     *                          (used by validation); otherwise they are skipped
     */
    private void _predictDataset(Dataframe dataframe, boolean includeRatedItems) {
        Map<List<Object>, Double> similarities = ((ModelParameters) kb().getModelParameters()).getSimilarities();
        for (Map.Entry<Integer, Record> recordEntry : dataframe.entries()) {
            Integer rId = recordEntry.getKey();
            Record r = recordEntry.getValue();
            AssociativeArray ratings = r.getX();

            // Per candidate: running sum of similarity * value, and running sum of similarities.
            Map<Object, Double> weightedSums = new HashMap<Object, Double>();
            Map<Object, Double> similaritySums = new HashMap<Object, Double>();

            for (Map.Entry<Object, Object> rating : ratings.entrySet()) {
                Object ratedKey = rating.getKey();
                Double ratingValue = TypeInference.toDouble(rating.getValue());
                // The similarity map is keyed by ordered pairs, so finding all pairs that start
                // with ratedKey requires a scan.
                for (Map.Entry<List<Object>, Double> simEntry : similarities.entrySet()) {
                    List<Object> pair = simEntry.getKey();
                    if (!pair.get(0).equals(ratedKey)) {
                        continue;
                    }
                    Object candidate = pair.get(1);
                    if (!includeRatedItems && ratings.containsKey(candidate)) {
                        continue;
                    }
                    double similarity = simEntry.getValue();
                    Double previousWeighted = weightedSums.get(candidate);
                    Double previousSimilarity = similaritySums.get(candidate);
                    weightedSums.put(candidate, (previousWeighted == null ? 0.0 : previousWeighted) + similarity * ratingValue);
                    similaritySums.put(candidate, (previousSimilarity == null ? 0.0 : previousSimilarity) + similarity);
                }
            }

            Map<Object, Object> predictionScores = new HashMap<Object, Object>();
            for (Map.Entry<Object, Double> weighted : weightedSums.entrySet()) {
                Object candidate = weighted.getKey();
                double score = weighted.getValue() / similaritySums.get(candidate);
                // A zero similarity total (0/0) or a NaN similarity makes the score undefined;
                // NaN has no meaningful place in a descending ranking, so leave it out.
                if (Double.isNaN(score) || Double.isInfinite(score)) {
                    continue;
                }
                predictionScores.put(candidate, score);
            }

            // Guard: calling iterator().next() on an empty key set would throw NoSuchElementException.
            Object topPrediction = null;
            if (!predictionScores.isEmpty()) {
                predictionScores = MapMethods.sortNumberMapByValueDescending(predictionScores);
                topPrediction = predictionScores.keySet().iterator().next();
            }
            dataframe._unsafe_set(rId, new Record(ratings, r.getY(), topPrediction, new AssociativeArray(predictionScores)));
        }
    }

    /**
     * Computes the similarity of every ordered pair of training labels and stores it in the
     * model's similarity map. Each unordered pair is computed once and written in both orders;
     * pairs of a record with itself are included.
     *
     * @param dataframe training records; {@code getY()} is the label, {@code getX()} the vector
     */
    @Override
    protected void _fit(Dataframe dataframe) {
        Map<List<Object>, Double> similarities = ((ModelParameters) kb().getModelParameters()).getSimilarities();
        for (Iterator<Record> outer = dataframe.iterator(); outer.hasNext(); ) {
            Record r1 = outer.next();
            Object y1 = r1.getY();
            for (Iterator<Record> inner = dataframe.iterator(); inner.hasNext(); ) {
                Record r2 = inner.next();
                Object y2 = r2.getY();
                List<Object> pair = Arrays.asList(y1, y2);
                if (similarities.containsKey(pair)) {
                    continue; // already stored, possibly as the mirror of an earlier pair
                }
                double similarity = calculateSimilarity(r1, r2);
                similarities.put(pair, similarity);
                similarities.put(Arrays.asList(y2, y1), similarity);
            }
        }
    }

    /**
     * Predicts every key of every record (including keys already present) and reports the RMSE
     * between the records' own values and the predictions for those keys.
     *
     * <p>Keys that received no prediction (for example, keys unknown to the trained model) are
     * excluded from the error instead of failing.
     *
     * @param dataframe records to validate; modified in place by the prediction step
     * @return metrics holding the RMSE, or NaN when no key could be scored
     */
    @Override
    public ValidationMetrics validateModel(Dataframe dataframe) {
        _predictDataset(dataframe, true);
        ValidationMetrics validationMetrics = (ValidationMetrics) kb().getEmptyValidationMetricsObject();

        double sumSquaredErrors = 0.0;
        int count = 0;
        for (Iterator<Record> it = dataframe.iterator(); it.hasNext(); ) {
            Record r = it.next();
            AssociativeArray predictions = r.getYPredictedProbabilities();
            for (Map.Entry<Object, Object> rating : r.getX().entrySet()) {
                Double predicted = TypeInference.toDouble(predictions.get(rating.getKey()));
                if (predicted == null) {
                    continue;
                }
                double error = TypeInference.toDouble(rating.getValue()) - predicted;
                sumSquaredErrors += error * error;
                count++;
            }
        }
        // With count == 0 this is 0.0 / 0 = NaN, signalling that the RMSE is undefined.
        validationMetrics.setRMSE(Math.sqrt(sumSquaredErrors / count));
        return validationMetrics;
    }

    /**
     * Similarity of two records' feature vectors under the configured measure.
     *
     * @param record  first record
     * @param record2 second record
     * @return the similarity score (see {@link TrainingParameters.SimilarityMeasure})
     * @throws IllegalArgumentException if the configured measure is not supported (including null)
     */
    private double calculateSimilarity(Record record, Record record2) {
        TrainingParameters.SimilarityMeasure similarityMethod = ((TrainingParameters) kb().getTrainingParameters()).getSimilarityMethod();
        if (similarityMethod == TrainingParameters.SimilarityMeasure.EUCLIDIAN) {
            return 1.0d / (1.0d + Distance.euclidean(record.getX(), record2.getX()));
        }
        if (similarityMethod == TrainingParameters.SimilarityMeasure.MANHATTAN) {
            return 1.0d / (1.0d + Distance.manhattan(record.getX(), record2.getX()));
        }
        if (similarityMethod == TrainingParameters.SimilarityMeasure.PEARSONS_CORRELATION) {
            // Align both vectors over the union of their keys so the two series match position by position.
            HashSet<Object> allKeys = new HashSet<Object>(record.getX().keySet());
            allKeys.addAll(record2.getX().keySet());
            FlatDataList values1 = new FlatDataList();
            FlatDataList values2 = new FlatDataList();
            for (Object key : allKeys) {
                // NOTE(review): a key present in only one record adds a null entry here; confirm
                // that PearsonCorrelation tolerates nulls or decide on a fill value.
                values1.add(TypeInference.toDouble(record.getX().get(key)));
                values2.add(TypeInference.toDouble(record2.getX().get(key)));
            }
            TransposeDataList transposeDataList = new TransposeDataList();
            transposeDataList.put(1, values1);
            transposeDataList.put(2, values2);
            // Map the correlation from [-1, 1] to [0, 1].
            return (PearsonCorrelation.calculateCorrelation(transposeDataList) + 1.0d) / 2.0d;
        }
        throw new IllegalArgumentException("Unsupported Distance method.");
    }
}
