Synthetic data with SDV and CopulaGAN

# Preview the first rows of the table that is later passed to model.fit().
data.head()
Pclass Sex Age SibSp Parch Fare
0 3 1 22.0 1 0 7.2500
1 1 0 38.0 1 0 71.2833
2 3 0 26.0 0 0 7.9250
3 1 0 35.0 1 0 53.1000
4 3 1 35.0 0 0 8.0500
##    Pclass  Sex   Age  SibSp  Parch     Fare
## 0       3    1  22.0      1      0   7.2500
## 1       1    0  38.0      1      0  71.2833
## 2       3    0  26.0      0      0   7.9250
## 3       1    0  35.0      1      0  53.1000
## 4       3    1  35.0      0      0   8.0500
# Per-column summary statistics of the input table (baseline for comparing samples).
data.describe(include='all')
Pclass Sex Age SibSp Parch Fare
count 891.000000 891.000000 891.000000 891.000000 891.000000 891.000000
mean 2.308642 0.647587 29.758889 0.523008 0.381594 32.204208
std 0.836071 0.477990 13.002570 1.102743 0.806057 49.693429
min 1.000000 0.000000 0.420000 0.000000 0.000000 0.000000
25% 2.000000 0.000000 22.000000 0.000000 0.000000 7.910400
50% 3.000000 1.000000 30.000000 0.000000 0.000000 14.454200
75% 3.000000 1.000000 35.000000 1.000000 0.000000 31.000000
max 3.000000 1.000000 80.000000 8.000000 6.000000 512.329200
##            Pclass         Sex         Age       SibSp       Parch        Fare
## count  891.000000  891.000000  891.000000  891.000000  891.000000  891.000000
## mean     2.308642    0.647587   29.758889    0.523008    0.381594   32.204208
## std      0.836071    0.477990   13.002570    1.102743    0.806057   49.693429
## min      1.000000    0.000000    0.420000    0.000000    0.000000    0.000000
## 25%      2.000000    0.000000   22.000000    0.000000    0.000000    7.910400
## 50%      3.000000    1.000000   30.000000    0.000000    0.000000   14.454200
## 75%      3.000000    1.000000   35.000000    1.000000    0.000000   31.000000
## max      3.000000    1.000000   80.000000    8.000000    6.000000  512.329200
# Baseline: fit a CopulaGAN constructed with no arguments on `data`,
# then draw 200 synthetic rows from the fitted model.
from sdv.tabular import CopulaGAN
model = CopulaGAN()
model.fit(data)
new_data = model.sample(200)
new_data.head()  # preview the first sampled rows
Pclass Sex Age SibSp Parch Fare
0 1 0 23.113076 1 0 29.937752
1 3 1 29.918377 1 0 5.958775
2 1 1 76.710166 0 0 202.889201
3 3 0 46.581959 1 1 89.833194
4 1 0 33.135297 1 2 118.551008
##    Pclass  Sex        Age  SibSp  Parch        Fare
## 0       1    0  11.797730      3      2   15.963036
## 1       2    0  28.135436      2      2   12.701402
## 2       3    0  26.367627      0      1   10.742400
## 3       3    1  24.116738      0      1  144.611161
## 4       2    0  31.368429      3      1   13.659402
# Summary statistics of the 200 sampled rows, for comparison with data.describe().
new_data.describe(include='all')
Pclass Sex Age SibSp Parch Fare
count 200.000000 200.000000 200.000000 200.00000 200.000000 200.000000
mean 1.960000 0.540000 37.131076 1.10500 0.975000 69.046022
std 0.831781 0.499648 17.131095 1.03408 1.196467 70.354667
min 1.000000 0.000000 3.409883 0.00000 0.000000 3.135404
25% 1.000000 0.000000 27.213419 0.00000 0.000000 10.772323
50% 2.000000 1.000000 31.788050 1.00000 0.000000 42.555842
75% 3.000000 1.000000 49.398770 2.00000 2.000000 105.964913
max 3.000000 1.000000 78.841709 5.00000 5.000000 294.906651
##            Pclass         Sex         Age       SibSp       Parch        Fare
## count  200.000000  200.000000  200.000000  200.000000  200.000000  200.000000
## mean     1.990000    0.440000   31.717758    0.945000    1.240000   76.998168
## std      0.795654    0.497633   14.745161    0.925341    1.487773   82.340578
## min      1.000000    0.000000   -0.299758    0.000000    0.000000    4.534767
## 25%      1.000000    0.000000   23.051398    0.000000    0.000000   14.265699
## 50%      2.000000    0.000000   30.367405    1.000000    1.000000   36.297745
## 75%      3.000000    1.000000   40.796266    1.000000    2.000000  115.727641
## max      3.000000    1.000000   77.002768    5.000000    6.000000  350.698375
# Score the synthetic sample against the table the model was fitted on.
# Argument order: synthetic sample first, then the original data.
from sdv.evaluation import evaluate

evaluate(new_data, data)
0.2793053940561327
## 0.4487854281948458
# Second model: declare how each column should be transformed instead of
# relying on the defaults, and model Fare with a truncated Gaussian.
#
# SibSp is a count (0-8 in the training table), so it is declared as
# 'integer', like Parch. Declaring it 'boolean' collapses every sampled
# value to 0/1 and discards the larger family sizes.
# NOTE: any outputs recorded below were produced before this fix (SibSp
# capped at 1); re-run the cell to refresh them.
model = CopulaGAN(
    field_transformers={
        'Pclass': 'categorical',
        'Sex': 'categorical',
        'Age': 'float',
        'SibSp': 'integer',
        'Parch': 'integer',
        'Fare': 'float',
    },
    field_distributions={
        'Fare': 'truncated_gaussian',
    },
)
model.fit(data)
new_data = model.sample(200)  # same sample size as the baseline run
new_data.head()
Pclass Sex Age SibSp Parch Fare
0 2 1 55.113425 1 0 -9.755683e-08
1 2 1 42.257149 1 2 7.252636e+00
2 3 1 30.128989 1 0 7.708430e+00
3 2 1 30.838961 1 0 -1.184872e-07
4 3 1 30.370534 1 0 3.530599e+00
# Summary statistics of the sample drawn from the configured model.
new_data.describe(include='all')
Pclass Sex Age SibSp Parch Fare
count 200.000000 200.000000 200.000000 200.000000 200.000000 2.000000e+02
mean 2.265000 0.630000 33.075478 0.715000 0.915000 2.587770e+01
std 0.726231 0.484016 13.472204 0.452547 1.189474 3.599977e+01
min 1.000000 0.000000 -1.122344 0.000000 0.000000 -1.191555e-07
25% 2.000000 0.000000 26.064943 0.000000 0.000000 6.589631e+00
50% 2.000000 1.000000 30.706679 1.000000 0.000000 1.182216e+01
75% 3.000000 1.000000 41.977015 1.000000 2.000000 2.797455e+01
max 3.000000 1.000000 74.046067 1.000000 5.000000 2.014112e+02
# Re-score: `new_data` now holds the sample from the configured model.
evaluate(new_data, data)
0.3533601962024023