Counterfactual Fairness in Java

Here we look at how to build a counterfactually fair model, specifically the "Fair Add" model detailed in Counterfactual Fairness.

This implementation relies mostly on the linear regression implementations in Apache Commons Math [1], namely the Ordinary Least Squares (OLS) regression [2]. We start by adding the relevant Maven dependency:

<dependency>
    <groupId>org.apache.commons</groupId>
    <artifactId>commons-math3</artifactId>
    <version>3.6.1</version>
</dependency>

Data will be passed as a RealMatrix [3]. This matrix has dimensions \(N\times f\), where \(N\) is the number of observations and \(f\) is the number of features.

We can instantiate the model using:

// RealMatrix data = ...
final CounterfactuallyFairModel model = new CounterfactuallyFairModel(data);

We will then need the following information: the column indices of the protected attributes, of the model variables and of the target.

Assuming that we have the same variables as in the counterfactual fairness example, let's say that the protected attributes occupy columns 5 to 14 of the data matrix, the model variables (LSAT and UGPA) have indices 1 and 0, and the target (ZFYA) has index 2. We then calculate the counterfactually fair model using:

model.calculate(new int[]{5, 6, 7, 8, 9, 10, 11, 12, 13, 14},
             new int[]{1, 0}, 2);
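
Because the protected columns in this example are contiguous (5 through 14), the index array could also be generated rather than typed out by hand. The following is a small sketch using only the JDK; the class name ColumnIndices and the method range are illustrative and not part of the original model.

```java
import java.util.Arrays;
import java.util.stream.IntStream;

final class ColumnIndices {
    // Hypothetical helper: returns {first, first + 1, ..., last}, both ends inclusive.
    static int[] range(int first, int last) {
        return IntStream.rangeClosed(first, last).toArray();
    }

    public static void main(String[] args) {
        final int[] protectedIndices = range(5, 14);
        System.out.println(Arrays.toString(protectedIndices));
        // The array could then be passed on as before, e.g.
        // model.calculate(protectedIndices, new int[]{1, 0}, 2);
    }
}
```

This only helps when the protected columns sit next to each other; for scattered columns an explicit array remains the clearer choice.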

The calculate method performs the following:

public void calculate(int[] protectedIndices, int[] variableIndices, int targetIndex) {
    // one column of residuals per model variable
    final RealMatrix residuals = new
        Array2DRowRealMatrix(this.data.getRowDimension(),
        variableIndices.length);

    // regress each model variable on the protected attributes and keep the residuals
    for (int i = 0; i < variableIndices.length; i++) {
        final int index = variableIndices[i];
        final RealVector varResidual =
            this.calculateEpsilon(protectedIndices, index);
        residuals.setColumn(i, varResidual.toArray());
    }

    // predict target from residuals
    // (the regression is local to this snippet; keep a reference to it if it is needed later)
    final OLSMultipleLinearRegression regression = new
        OLSMultipleLinearRegression();
    regression.newSampleData(this.data.getColumn(targetIndex),
        residuals.getData());
}

As in Counterfactual Fairness, we fit a regression model to predict each of the variables (LSAT and UGPA) from the protected variables. The resulting residuals, \(\epsilon_{LSAT}\) and \(\epsilon_{UGPA}\), are in turn used to fit another regression model that predicts the target variable ZFYA.
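
Once the second regression is fitted, a prediction for one observation takes the form \(\hat{y} = \beta_0 + \beta_1\epsilon_{LSAT} + \beta_2\epsilon_{UGPA}\). The sketch below applies such coefficients in plain Java. It assumes the coefficient array stores the intercept first, which is the order OLSMultipleLinearRegression.estimateRegressionParameters() uses when the intercept is enabled (the default). The class FairPrediction, the method predict and the numbers are made up for illustration and are not part of the original model.

```java
final class FairPrediction {
    // beta[0] is the intercept; beta[1..] are the coefficients of the residuals,
    // in the same column order as the residual matrix (here: LSAT, then UGPA).
    static double predict(double[] beta, double[] residuals) {
        if (beta.length != residuals.length + 1) {
            throw new IllegalArgumentException("expected one coefficient per residual plus an intercept");
        }
        double prediction = beta[0];
        for (int j = 0; j < residuals.length; j++) {
            prediction += beta[j + 1] * residuals[j];
        }
        return prediction;
    }

    public static void main(String[] args) {
        final double[] beta = {0.5, 2.0, -1.0};   // made-up coefficients
        final double[] epsilon = {0.25, 0.5};     // made-up residuals for one observation
        System.out.println(predict(beta, epsilon));
    }
}
```

Note that the prediction depends on the protected attributes only through the residuals, which is the point of the two-stage construction.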

The residuals are calculated using the calculateEpsilon method, which consists of:

public RealVector calculateEpsilon(int[] protectedIndices, int targetIndex) {
    // select every row of the data matrix
    int[] allRows = new int[this.data.getRowDimension()];
    Arrays.setAll(allRows, i -> i);
    // regressors: the protected attribute columns
    final RealMatrix _x = this.data.getSubMatrix(allRows,
        protectedIndices);
    // response: the variable whose residuals we want
    final RealVector _y = this.data.getSubMatrix(allRows,
        new int[]{targetIndex}).getColumnVector(0);

    final OLSMultipleLinearRegression regression = new
        OLSMultipleLinearRegression();
    regression.newSampleData(_y.toArray(), _x.getData());
    // residuals: observed values minus fitted values
    return new ArrayRealVector(regression.estimateResiduals());
}

This simply fits a regression model for the given variable using the protected attributes and returns a RealVector with the residuals \(\epsilon\).
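
To make the meaning of these residuals concrete, here is a dependency-free sketch for the simplest case: a single binary protected attribute and the closed-form OLS fit with an intercept. It is not the Commons Math code path used above; the class ResidualSketch, the method residuals and the data are made up for illustration. With one binary regressor, the fitted value is the group mean, so each residual is the deviation from the mean of the observation's own group.

```java
final class ResidualSketch {
    // Residuals of the simple OLS fit y ~ b0 + b1 * a (one regressor plus intercept).
    static double[] residuals(double[] a, double[] y) {
        final int n = y.length;
        double meanA = 0.0;
        double meanY = 0.0;
        for (int i = 0; i < n; i++) {
            meanA += a[i];
            meanY += y[i];
        }
        meanA /= n;
        meanY /= n;

        double sxy = 0.0;
        double sxx = 0.0;
        for (int i = 0; i < n; i++) {
            sxy += (a[i] - meanA) * (y[i] - meanY);
            sxx += (a[i] - meanA) * (a[i] - meanA);
        }
        if (sxx == 0.0) {
            throw new IllegalArgumentException("the regressor must not be constant");
        }
        final double slope = sxy / sxx;
        final double intercept = meanY - slope * meanA;

        final double[] epsilon = new double[n];
        for (int i = 0; i < n; i++) {
            epsilon[i] = y[i] - (intercept + slope * a[i]);
        }
        return epsilon;
    }

    public static void main(String[] args) {
        final double[] a = {0, 0, 1, 1};          // made-up binary protected attribute
        final double[] score = {30, 34, 38, 42};  // made-up variable values
        for (double e : residuals(a, score)) {
            System.out.println(e);
        }
    }
}
```

In this toy data the group means are 32 and 40, so the residuals no longer carry the difference between the two groups: they sum to zero and are uncorrelated with the protected attribute.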