package nl.liacs.subdisc;

import java.util.Arrays;
import java.util.BitSet;
import java.util.HashSet;
import java.util.Iterator;
import weka.classifiers.functions.Logistic;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;

/* loaded from: input_file:nl/liacs/subdisc/PropensityScore.class */
/**
 * Computes, for every record, an estimate of the probability that the record belongs to the
 * target (the set bits of a {@link BitSet}), given the explanatory knowledge supplied by a
 * {@link GlobalKnowledge} and a subgroup-specific {@link LocalKnowledge}.
 *
 * <p>Two estimation methods are supported, selected by name: {@link #BAYES_RULE} and
 * {@link #LOGISTIC_REGRESSION}. All work is done in the constructor.
 */
public class PropensityScore {
    /** Method name selecting the Bayes-rule estimator. */
    public static final String BAYES_RULE = "BayesRule";
    /** Method name selecting the Weka logistic-regression estimator. */
    public static final String LOGISTIC_REGRESSION = "LogisticRegression";

    private final String itsMethod;
    private final BitSet itsTarget;
    private final LocalKnowledge itsLocalKn;
    private final GlobalKnowledge itsGlobalKn;
    private final Subgroup itsSubgroup;
    private final double[] itsPropensityScore;
    private final double itsPropensityScoreSum;

    /**
     * Computes the propensity scores with the requested method.
     *
     * @param subgroup        subgroup whose local knowledge is used
     * @param bitSet          target indicator; bit {@code i} set means record {@code i} is a target
     * @param localKnowledge  subgroup-dependent knowledge source
     * @param globalKnowledge subgroup-independent knowledge source
     * @param str             method name, {@link #BAYES_RULE} or {@link #LOGISTIC_REGRESSION};
     *                        any other value leaves all scores at 0.0 and the sum at NaN
     */
    public PropensityScore(Subgroup subgroup, BitSet bitSet, LocalKnowledge localKnowledge, GlobalKnowledge globalKnowledge, String str) {
        this.itsMethod = str;
        this.itsTarget = bitSet;
        this.itsLocalKn = localKnowledge;
        this.itsGlobalKn = globalKnowledge;
        this.itsSubgroup = subgroup;
        // NOTE(review): BitSet.size() is the allocated capacity in bits (rounded up to a multiple
        // of 64), not necessarily the number of records; confirm callers size the BitSet so that
        // this matches the data set.
        this.itsPropensityScore = new double[this.itsTarget.size()];

        boolean computed = true;
        if (BAYES_RULE.equals(this.itsMethod)) {
            calculateBayesRule();
        } else if (LOGISTIC_REGRESSION.equals(this.itsMethod)) {
            calculateLogisticRegression();
        } else {
            computed = false;
        }
        // NaN signals that no known method was requested
        this.itsPropensityScoreSum = computed ? sum(this.itsPropensityScore) : Double.NaN;
    }

    /** Returns the plain left-to-right sum of {@code values}. */
    private static double sum(double[] values) {
        double total = 0.0d;
        for (double value : values) {
            total += value;
        }
        return total;
    }

    /**
     * Fills the scores using Bayes' rule:
     * score = P(T) * prod P(X|T) / (P(T) * prod P(X|T) + P(not T) * prod P(X|not T)),
     * where the products run over the union of global and local Bayes-rule statistics.
     * If there are no statistics, every score is the prior P(T).
     */
    private void calculateBayesRule() {
        HashSet<StatisticsBayesRule> statistics = new HashSet<StatisticsBayesRule>();
        statistics.addAll(this.itsGlobalKn.getStatisticsBayesRule());
        statistics.addAll(this.itsLocalKn.getStatisticsBayesRule(this.itsSubgroup));
        Log.logCommandLine("Size overlapping subgroups: " + statistics.size());

        int size = this.itsTarget.size();
        // cast before dividing: int / int would truncate the prior to 0
        double priorT = (double) this.itsTarget.cardinality() / size;
        double priorNotT = 1.0d - priorT;

        if (statistics.isEmpty()) {
            Log.logCommandLine("Knowledge variables are empty!!!");
            Arrays.fill(this.itsPropensityScore, priorT);
            return;
        }

        double[] numeratorT = new double[size];
        Arrays.fill(numeratorT, priorT);
        double[] numeratorNotT = new double[size];
        Arrays.fill(numeratorNotT, priorNotT);
        for (StatisticsBayesRule statistic : statistics) {
            double[] pXGivenT = statistic.getProbabilitiesDataPXGivenT();
            // NOTE(review): this reads P(X|T) again, so both products are identical and the
            // score reduces to the prior P(T). If StatisticsBayesRule exposes P(X|not T), it
            // should be used here — confirm against StatisticsBayesRule.
            double[] pXGivenNotT = statistic.getProbabilitiesDataPXGivenT();
            for (int i = 0; i < size; i++) {
                numeratorT[i] *= pXGivenT[i];
                numeratorNotT[i] *= pXGivenNotT[i];
            }
        }
        for (int i = 0; i < size; i++) {
            this.itsPropensityScore[i] = numeratorT[i] / (numeratorT[i] + numeratorNotT[i]);
        }
    }

    /**
     * Fills the scores with the predicted probability of class "1" from an unregularised Weka
     * logistic regression. The model is trained on one binary attribute per distinct knowledge
     * BitSet, with the target as class. If training or prediction fails, every score is set to
     * NaN so that the failure is not mistaken for real (zero) scores.
     */
    private void calculateLogisticRegression() {
        HashSet<BitSet> explanatory = new HashSet<BitSet>();
        explanatory.addAll(this.itsGlobalKn.getBitSets());
        explanatory.addAll(this.itsLocalKn.getBitSets(this.itsSubgroup));

        int nrExplanatory = explanatory.size();
        FastVector attributes = new FastVector(nrExplanatory + 1);
        // attribute names are "1.0", "2.0", ... (only used as labels)
        for (int i = 1; i <= nrExplanatory; i++) {
            attributes.addElement(new Attribute(Double.toString(i)));
        }
        FastVector classValues = new FastVector(2);
        classValues.addElement("0");
        classValues.addElement("1");
        attributes.addElement(new Attribute("target", classValues));

        int nrRecords = this.itsTarget.size();
        int nrAttributes = attributes.size();
        Instances instances = new Instances("explanatoryVariables", attributes, nrRecords);
        for (int record = 0; record < nrRecords; record++) {
            Instance instance = new Instance(nrAttributes);
            instance.setDataset(instances);
            int attribute = 0;
            // iterating the same set instance keeps the attribute order consistent per record
            for (BitSet variable : explanatory) {
                instance.setValue(attribute++, variable.get(record) ? 1 : 0);
            }
            instance.setValue(attribute, this.itsTarget.get(record) ? 1 : 0);
            instances.add(instance);
        }
        instances.setClassIndex(instances.numAttributes() - 1);

        try {
            logisticClassification(instances, this.itsPropensityScore);
            Log.logCommandLine("propensity score filled");
        } catch (Exception e) {
            // Weka's buildClassifier/distributionForInstance are declared to throw Exception
            Log.logCommandLine("logistic classification failed");
            e.printStackTrace();
            // do not leave zeros / partial results that look like valid scores
            Arrays.fill(this.itsPropensityScore, Double.NaN);
        }
    }

    /**
     * Trains a ridge-free {@link Logistic} model on {@code instances} and writes, for each of the
     * first {@code dArr.length} instances, the predicted probability of class index 1.
     *
     * @param instances training data with the class index already set
     * @param dArr      output array, filled in place
     * @throws Exception propagated from Weka model building or prediction
     */
    private static void logisticClassification(Instances instances, double[] dArr) throws Exception {
        Logistic logistic = new Logistic();
        logistic.setRidge(0.0d);
        logistic.buildClassifier(instances);
        Log.logCommandLine("Logistic Regression model created");
        for (int i = 0; i < dArr.length; i++) {
            dArr[i] = logistic.distributionForInstance(instances.instance(i))[1];
        }
    }

    /**
     * Returns the per-record propensity scores.
     *
     * @return the internal score array (not a copy; modifying it changes this object's state)
     * @deprecated exposes mutable internal state; prefer {@link #getPropensityScoreSum()}
     */
    @Deprecated
    public double[] getPropensityScore() {
        return this.itsPropensityScore;
    }

    /**
     * Returns the sum of all propensity scores.
     *
     * @return the sum, or NaN when the method name was not recognised or logistic regression failed
     */
    public double getPropensityScoreSum() {
        return this.itsPropensityScoreSum;
    }
}
