package com.datumbox.framework.core.machinelearning.regression;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.dataobjects.Dataframe;
import com.datumbox.framework.common.dataobjects.MatrixDataframe;
import com.datumbox.framework.common.dataobjects.Record;
import com.datumbox.framework.common.persistentstorage.interfaces.BigMap;
import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.framework.common.utilities.PHPMethods;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.algorithms.AbstractLinearRegression;
import com.datumbox.framework.core.machinelearning.common.interfaces.StepwiseCompatible;
import com.datumbox.framework.core.statistics.distributions.ContinuousDistributions;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;

/* loaded from: input_file:com/datumbox/framework/core/machinelearning/regression/MatrixLinearRegression.class */
/**
 * Ordinary least squares linear regression solved in closed form through the normal equations
 * {@code thita = (X'X)^-1 X'y}, where the inverse of {@code X'X} is obtained from an LU
 * decomposition.
 *
 * <p>Besides the coefficients, training estimates a two-sided t-test p-value for every coefficient
 * (including the constant term). These are exposed through {@link #getFeaturePvalues()} for
 * {@link StepwiseCompatible} callers.
 */
public class MatrixLinearRegression extends AbstractLinearRegression<MatrixLinearRegression.ModelParameters, MatrixLinearRegression.TrainingParameters, MatrixLinearRegression.ValidationMetrics> implements StepwiseCompatible {

    /**
     * Model parameters of {@link MatrixLinearRegression}: the inherited coefficients plus the
     * column index assigned to each feature and the per-coefficient p-values.
     */
    public static class ModelParameters extends AbstractLinearRegression.AbstractModelParameters {
        private static final long serialVersionUID = 1;

        /** Column index of every feature (and of the constant term) in the design matrix. */
        @BigMap(mapType = DatabaseConnector.MapType.HASHMAP, storageHint = DatabaseConnector.StorageHint.IN_MEMORY, concurrent = false)
        private Map<Object, Integer> featureIds;

        /** P-value of every coefficient, keyed by feature; filled in by {@code _fit}. */
        private Map<Object, Double> featurePvalues;

        /**
         * @param databaseConnector connector forwarded to the parent model parameters
         */
        protected ModelParameters(DatabaseConnector databaseConnector) {
            super(databaseConnector);
        }

        /**
         * @return the feature to design-matrix column index mapping
         */
        public Map<Object, Integer> getFeatureIds() {
            return featureIds;
        }

        /**
         * @param featureIds the feature to design-matrix column index mapping
         */
        protected void setFeatureIds(Map<Object, Integer> featureIds) {
            this.featureIds = featureIds;
        }

        /**
         * @return the p-value of every coefficient, keyed by feature
         */
        public Map<Object, Double> getFeaturePvalues() {
            return featurePvalues;
        }

        /**
         * @param featurePvalues the p-value of every coefficient, keyed by feature
         */
        protected void setFeaturePvalues(Map<Object, Double> featurePvalues) {
            this.featurePvalues = featurePvalues;
        }
    }

    /** Training parameters of {@link MatrixLinearRegression}; the algorithm has none of its own. */
    public static class TrainingParameters extends AbstractTrainer.AbstractTrainingParameters {
        private static final long serialVersionUID = 1;
    }

    /** Validation metrics of {@link MatrixLinearRegression}; adds nothing to the parent metrics. */
    public static class ValidationMetrics extends AbstractLinearRegression.AbstractValidationMetrics {
        private static final long serialVersionUID = 1;
    }

    /**
     * Creates the model.
     *
     * @param dbName name forwarded to the parent trainer (presumably the storage name of the model)
     * @param conf   framework configuration forwarded to the parent trainer
     */
    public MatrixLinearRegression(String dbName, Configuration conf) {
        super(dbName, conf, ModelParameters.class, TrainingParameters.class, ValidationMetrics.class);
    }

    /**
     * Predicts Y for every record of {@code newData} as {@code X * thita} and stores the
     * prediction back into the dataframe.
     *
     * <p>NOTE(review): the decompiler widened this override from {@code protected} to
     * {@code public}; the public access is kept so existing callers keep compiling.
     *
     * @param newData records to score; each record is replaced by a copy carrying the prediction
     */
    @Override
    public void _predictDataset(Dataframe newData) {
        ModelParameters modelParameters = (ModelParameters) kb().getModelParameters();
        Map<Object, Double> thitas = modelParameters.getThitas();
        Map<Object, Integer> featureIds = modelParameters.getFeatureIds();

        // One slot per feature plus the constant term, laid out in design-matrix column order.
        RealVector coefficients = new ArrayRealVector(modelParameters.getD() + 1);
        for (Map.Entry<Object, Double> entry : thitas.entrySet()) {
            coefficients.setEntry(featureIds.get(entry.getKey()), entry.getValue());
        }

        // Filled by parseDataset with the design-matrix row of every record id.
        Map<Integer, Integer> recordIdsReference = new HashMap<>();
        RealVector predictions = MatrixDataframe.parseDataset(newData, recordIdsReference, featureIds).getX().operate(coefficients);

        for (Map.Entry<Integer, Record> entry : newData.entries()) {
            Integer recordId = entry.getKey();
            Record record = entry.getValue();
            Double yPredicted = predictions.getEntry(recordIdsReference.get(recordId));
            newData._unsafe_set(recordId, new Record(record.getX(), record.getY(), yPredicted, record.getYPredictedProbabilities()));
        }
    }

    /**
     * Fits the coefficients with the normal equations and estimates their p-values.
     *
     * <p>The constant term lives in column 0 of the design matrix. The model parameters' feature
     * id map is filled while the design matrix is built and then used to store each coefficient
     * under its feature.
     *
     * @param trainingData training records
     */
    @Override
    protected void _fit(Dataframe trainingData) {
        ModelParameters modelParameters = (ModelParameters) kb().getModelParameters();
        int n = modelParameters.getN();
        int d = modelParameters.getD();
        Map<Object, Double> thitas = modelParameters.getThitas();
        Map<Object, Integer> featureIds = modelParameters.getFeatureIds();

        // true: prepend a constant column so the intercept is estimated as coefficient 0.
        MatrixDataframe matrixDataset = MatrixDataframe.newInstance(trainingData, true, null, featureIds);
        RealVector y = matrixDataset.getY();
        RealMatrix x = matrixDataset.getX();

        // (X'X)^-1 is kept: it gives both the coefficients and their standard errors.
        RealMatrix xt = x.transpose();
        RealMatrix xtxInverse = new LUDecomposition(xt.multiply(x)).getSolver().getInverse();

        // thita = (X'X)^-1 X'y
        RealVector coefficients = xtxInverse.multiply(xt).operate(y);

        thitas.put(Dataframe.COLUMN_NAME_CONSTANT, coefficients.getEntry(0));
        for (Map.Entry<Object, Integer> entry : featureIds.entrySet()) {
            thitas.put(entry.getKey(), coefficients.getEntry(entry.getValue()));
        }

        // Sum of squared residuals of the fitted model on the training data.
        double sse = 0.0;
        for (double residual : x.operate(coefficients).subtract(y).toArray()) {
            sse += residual * residual;
        }

        modelParameters.setFeaturePvalues(estimatePvalues(coefficients, xtxInverse, sse, n, d, featureIds));
    }

    /**
     * Computes a two-sided t-test p-value for each of the {@code d + 1} coefficients.
     *
     * <p>The t statistic of coefficient i is {@code b_i / sqrt(MSE * [(X'X)^-1]_ii)} with
     * {@code MSE = SSE / (n - (d + 1))}. The test is two-sided so that strongly negative
     * coefficients are reported as significant too; a one-sided upper-tail test would give them
     * p-values close to 1.
     *
     * @param coefficients fitted coefficients, indexed by design-matrix column
     * @param xtxInverse   the matrix {@code (X'X)^-1}
     * @param sse          sum of squared residuals
     * @param n            value of {@code getN()} (presumably the number of training records)
     * @param d            value of {@code getD()}; there are {@code d + 1} coefficients
     * @param featureIds   feature to design-matrix column mapping
     * @return p-value per feature. A non-positive variance estimate yields 0.0. All values are
     *         {@link Double#NaN} when there are no residual degrees of freedom.
     */
    private static Map<Object, Double> estimatePvalues(RealVector coefficients, RealMatrix xtxInverse, double sse, int n, int d, Map<Object, Integer> featureIds) {
        int degreesOfFreedom = n - (d + 1);
        Map<Integer, Object> idsToFeatures = PHPMethods.array_flip(featureIds);
        Map<Object, Double> pvalues = new HashMap<>();

        if (degreesOfFreedom <= 0) {
            // Saturated fit: the error variance cannot be estimated, so neither can the p-values.
            for (int i = 0; i <= d; i++) {
                pvalues.put(idsToFeatures.get(i), Double.NaN);
            }
            return pvalues;
        }

        double mse = sse / degreesOfFreedom;
        for (int i = 0; i <= d; i++) {
            double variance = mse * xtxInverse.getEntry(i, i);
            Object feature = idsToFeatures.get(i);
            if (variance <= 0.0) {
                pvalues.put(feature, 0.0);
            } else {
                double tStatistic = coefficients.getEntry(i) / Math.sqrt(variance);
                double upperTail = 1.0 - ContinuousDistributions.studentsCdf(Math.abs(tStatistic), degreesOfFreedom);
                pvalues.put(feature, 2.0 * upperTail);
            }
        }
        return pvalues;
    }

    /**
     * @return the coefficient p-values estimated during the last fit, keyed by feature
     */
    @Override
    public Map<Object, Double> getFeaturePvalues() {
        return ((ModelParameters) kb().getModelParameters()).getFeaturePvalues();
    }
}
