package com.datumbox.framework.core.machinelearning.regression;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.dataobjects.Dataframe;
import com.datumbox.framework.common.interfaces.Trainable;
import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.framework.common.utilities.MapMethods;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractRegressor;
import com.datumbox.framework.core.machinelearning.common.interfaces.StepwiseCompatible;
import java.util.HashSet;
import java.util.Map;
import org.apache.commons.math3.analysis.integration.BaseAbstractUnivariateIntegrator;

/* loaded from: input_file:com/datumbox/framework/core/machinelearning/regression/StepwiseRegression.class */
/**
 * Stepwise (backward-elimination) regression.
 * <p>
 * Training works on a copy of the data. Each round fits the configured {@link StepwiseCompatible}
 * regressor and reads its feature p-values. The round then drops the feature with the largest
 * p-value, as long as that p-value exceeds {@code aout}. Once no feature qualifies, the regressor is
 * fitted one last time on the surviving columns. Prediction, validation, closing and deletion are
 * delegated to that wrapped regressor.
 */
public class StepwiseRegression extends AbstractRegressor<ModelParameters, TrainingParameters, AbstractRegressor.ValidationMetrics> {

    /**
     * The wrapped regressor. It is not serialized; it is recreated on demand from the training
     * parameters by {@link #loadRegressor()} or by {@link #_fit(Dataframe)}.
     */
    private transient AbstractRegressor mlregressor;

    /**
     * Model parameters of the stepwise wrapper. This class declares no fields of its own; the
     * fitted model lives in the wrapped regressor.
     */
    public static class ModelParameters extends AbstractTrainer.AbstractModelParameters {
        private static final long serialVersionUID = 1;

        /**
         * @param databaseConnector connector passed through to the base model parameters
         */
        protected ModelParameters(DatabaseConnector databaseConnector) {
            super(databaseConnector);
        }
    }

    /**
     * Parameters that control the stepwise elimination and the wrapped regressor.
     */
    public static class TrainingParameters extends AbstractTrainer.AbstractTrainingParameters {
        private static final long serialVersionUID = 1;

        /** Maximum number of elimination rounds; {@code null} means no explicit limit. */
        private Integer maxIterations = null;

        /** A feature is removed only while its p-value is strictly greater than this threshold. */
        private Double aout = 0.05;

        /** Class of the wrapped regressor; it must implement {@link StepwiseCompatible}. */
        private Class<? extends AbstractRegressor> regressionClass;

        /** Training parameters passed to every fit of the wrapped regressor. */
        private AbstractTrainer.AbstractTrainingParameters regressionTrainingParameters;

        /**
         * @return the maximum number of elimination rounds, or {@code null} if unlimited
         */
        public Integer getMaxIterations() {
            return this.maxIterations;
        }

        /**
         * @param num the maximum number of elimination rounds; {@code null} for no explicit limit
         */
        public void setMaxIterations(Integer num) {
            this.maxIterations = num;
        }

        /**
         * @return the p-value threshold above which a feature is eliminated
         */
        public Double getAout() {
            return this.aout;
        }

        /**
         * @param d the p-value threshold above which a feature is eliminated
         */
        public void setAout(Double d) {
            this.aout = d;
        }

        /**
         * @return the class of the wrapped regressor
         */
        public Class<? extends AbstractRegressor> getRegressionClass() {
            return this.regressionClass;
        }

        /**
         * Sets the class of the wrapped regressor.
         *
         * @param cls a regressor class that implements {@link StepwiseCompatible}
         * @throws IllegalArgumentException if {@code cls} does not implement {@link StepwiseCompatible}
         */
        public void setRegressionClass(Class<? extends AbstractRegressor> cls) {
            if (!StepwiseCompatible.class.isAssignableFrom(cls)) {
                throw new IllegalArgumentException("The regression model is not Stepwise Compatible.");
            }
            this.regressionClass = cls;
        }

        /**
         * @return the training parameters passed to the wrapped regressor
         */
        public AbstractTrainer.AbstractTrainingParameters getRegressionTrainingParameters() {
            return this.regressionTrainingParameters;
        }

        /**
         * @param abstractTrainingParameters training parameters passed to the wrapped regressor
         */
        public void setRegressionTrainingParameters(AbstractTrainer.AbstractTrainingParameters abstractTrainingParameters) {
            this.regressionTrainingParameters = abstractTrainingParameters;
        }
    }

    /**
     * @param str           the database name of the model
     * @param configuration the framework configuration
     */
    public StepwiseRegression(String str, Configuration configuration) {
        super(str, configuration, ModelParameters.class, TrainingParameters.class, AbstractRegressor.ValidationMetrics.class, null);
        this.mlregressor = null;
    }

    /**
     * Deletes the wrapped regressor and then this model. This model's own deletion runs even if
     * deleting the wrapped regressor fails.
     */
    @Override
    public void delete() {
        try {
            loadRegressor();
            this.mlregressor.delete();
        } finally {
            this.mlregressor = null;
            super.delete();
        }
    }

    /**
     * Closes the wrapped regressor and then this model. This model is closed even if closing the
     * wrapped regressor fails.
     */
    @Override
    public void close() {
        try {
            loadRegressor();
            this.mlregressor.close();
        } finally {
            this.mlregressor = null;
            super.close();
        }
    }

    /**
     * Runs k-fold cross validation of the wrapped regressor.
     * <p>
     * NOTE(review): only the wrapped regressor is cross-validated, on the columns of
     * {@code dataframe}. The stepwise elimination is not re-run per fold.
     *
     * @param dataframe          the data to cross-validate on
     * @param trainingParameters stepwise parameters; only their regression parameters are used
     * @param i                  the number of folds
     * @return the validation metrics reported by the wrapped regressor
     * @throws RuntimeException if no regressor has been trained or loaded yet
     */
    @Override
    public AbstractRegressor.ValidationMetrics kFoldCrossValidation(Dataframe dataframe, TrainingParameters trainingParameters, int i) {
        if (this.mlregressor == null) {
            throw new RuntimeException("You need to train a Regressor before running k-fold cross validation.");
        }
        // The wrapped regressor must receive its own parameter type, not the stepwise wrapper's.
        return (AbstractRegressor.ValidationMetrics) this.mlregressor.kFoldCrossValidation(dataframe, trainingParameters.getRegressionTrainingParameters(), i);
    }

    /**
     * Validates the wrapped regressor on {@code dataframe}, loading it first if needed.
     *
     * @param dataframe the data to validate on
     * @return the validation metrics reported by the wrapped regressor
     */
    @Override
    public AbstractRegressor.ValidationMetrics validateModel(Dataframe dataframe) {
        loadRegressor();
        return (AbstractRegressor.ValidationMetrics) this.mlregressor.validate(dataframe);
    }

    /**
     * Predicts {@code dataframe} with the wrapped regressor, loading it first if needed.
     *
     * @param dataframe the data to predict
     */
    @Override
    public void _predictDataset(Dataframe dataframe) {
        loadRegressor();
        this.mlregressor.predict(dataframe);
    }

    /**
     * Performs backward elimination on a copy of {@code dataframe}. It then fits the wrapped
     * regressor on the columns that remain.
     *
     * @param dataframe the training data; it is not modified
     */
    @Override
    protected void _fit(Dataframe dataframe) {
        TrainingParameters trainingParameters = (TrainingParameters) kb().getTrainingParameters();
        Integer configuredMax = trainingParameters.getMaxIterations();
        int maxIterations = (configuredMax == null) ? Integer.MAX_VALUE : configuredMax;
        double aout = trainingParameters.getAout();

        // Work on a copy so that dropping columns never alters the caller's data.
        Dataframe reducedData = dataframe.copy2();
        try {
            for (int iteration = 0; iteration < maxIterations; iteration++) {
                Map<Object, Double> pvalues = runRegression(reducedData);

                // The constant term is never a candidate for elimination. Remove it before the
                // emptiness check, so that a map holding only the constant ends the loop
                // instead of crashing below.
                pvalues.remove(Dataframe.COLUMN_NAME_CONSTANT);
                if (pvalues.isEmpty()) {
                    break;
                }

                Map.Entry<Object, Double> worst = MapMethods.selectMaxKeyValue(pvalues);
                // Guard against a missing entry or p-value instead of throwing an NPE while unboxing.
                if (worst == null || worst.getValue() == null || worst.getValue() <= aout) {
                    break; // every remaining feature is significant enough to keep
                }

                // Keep a mutable set: it is handed to dropXColumns, whose handling of it is not
                // visible here.
                HashSet<Object> toDrop = new HashSet<>();
                toDrop.add(worst.getKey());
                reducedData.dropXColumns(toDrop);
                if (reducedData.xColumnSize() == 0) {
                    break;
                }
            }

            this.mlregressor = generateRegressor();
            this.mlregressor.fit(reducedData, trainingParameters.getRegressionTrainingParameters());
        } finally {
            // Release the working copy even if a fit fails.
            reducedData.delete();
        }
    }

    /**
     * Creates the wrapped regressor if it has not been created yet.
     */
    private void loadRegressor() {
        if (this.mlregressor == null) {
            this.mlregressor = generateRegressor();
        }
    }

    /**
     * Creates a new instance of the configured regression class. The instance uses this model's
     * database name and configuration.
     *
     * @return a new, untrained wrapped regressor
     */
    private AbstractRegressor generateRegressor() {
        return (AbstractRegressor) Trainable.newInstance(((TrainingParameters) kb().getTrainingParameters()).getRegressionClass(), this.dbName, kb().getConf());
    }

    /**
     * Fits a temporary regressor on {@code dataframe} and returns its feature p-values. The
     * temporary regressor is deleted afterwards, even on failure, and never replaces the wrapped
     * regressor field.
     *
     * @param dataframe the data to fit
     * @return the p-value of each feature, as reported by {@link StepwiseCompatible#getFeaturePvalues()}
     */
    private Map<Object, Double> runRegression(Dataframe dataframe) {
        TrainingParameters trainingParameters = (TrainingParameters) kb().getTrainingParameters();
        AbstractRegressor regressor = generateRegressor();
        try {
            regressor.fit(dataframe, trainingParameters.getRegressionTrainingParameters());
            // The p-values are read before the finally block deletes the regressor.
            return ((StepwiseCompatible) regressor).getFeaturePvalues();
        } finally {
            regressor.delete();
        }
    }
}
