Synthetic data with SDV and CopulaGAN

import pandas as pd
import warnings

# Suppress all Python warnings for the rest of the session so library
# warnings do not clutter the notebook output.
warnings.filterwarnings('ignore')

# Load the real feature table that the synthesizer will be trained on,
# then preview its first rows.
data = pd.read_csv("data/svm-hyperparameters-train-features.csv")
data.head()
Pclass Sex Age SibSp Parch Fare
0 3 1 22.0 1 0 7.2500
1 1 0 38.0 1 0 71.2833
2 3 0 26.0 0 0 7.9250
3 1 0 35.0 1 0 53.1000
4 3 1 35.0 0 0 8.0500
# Summary statistics for every column of the real data; this is the
# reference the synthetic samples are compared against.
data.describe(include='all')
Pclass Sex Age SibSp Parch Fare
count 891.000000 891.000000 891.000000 891.000000 891.000000 891.000000
mean 2.308642 0.647587 29.758889 0.523008 0.381594 32.204208
std 0.836071 0.477990 13.002570 1.102743 0.806057 49.693429
min 1.000000 0.000000 0.420000 0.000000 0.000000 0.000000
25% 2.000000 0.000000 22.000000 0.000000 0.000000 7.910400
50% 3.000000 1.000000 30.000000 0.000000 0.000000 14.454200
75% 3.000000 1.000000 35.000000 1.000000 0.000000 31.000000
max 3.000000 1.000000 80.000000 8.000000 6.000000 512.329200
from sdv.tabular import CopulaGAN
# Baseline: a CopulaGAN built with no arguments, i.e. the library's own
# defaults for every column.
model = CopulaGAN()
model.fit(data)
# Draw 200 synthetic rows from the fitted model and preview them.
new_data = model.sample(200)
new_data.head()
Pclass Sex Age SibSp Parch Fare
0 1 0 23.113076 1 0 29.937752
1 3 1 29.918377 1 0 5.958775
2 1 1 76.710166 0 0 202.889201
3 3 0 46.581959 1 1 89.833194
4 1 0 33.135297 1 2 118.551008
# Summary statistics of the baseline synthetic sample, for side-by-side
# comparison with the real-data summary.
new_data.describe(include='all')
Pclass Sex Age SibSp Parch Fare
count 200.000000 200.000000 200.000000 200.00000 200.000000 200.000000
mean 1.960000 0.540000 37.131076 1.10500 0.975000 69.046022
std 0.831781 0.499648 17.131095 1.03408 1.196467 70.354667
min 1.000000 0.000000 3.409883 0.00000 0.000000 3.135404
25% 1.000000 0.000000 27.213419 0.00000 0.000000 10.772323
50% 2.000000 1.000000 31.788050 1.00000 0.000000 42.555842
75% 3.000000 1.000000 49.398770 2.00000 2.000000 105.964913
max 3.000000 1.000000 78.841709 5.00000 5.000000 294.906651
from sdv.evaluation import evaluate

# Score the baseline synthetic sample against the real data. The synthetic
# frame is passed first, the real frame second.
# NOTE(review): presumably a higher score means closer agreement - confirm
# against the SDV evaluation docs for the installed version.
evaluate(new_data, data)
0.2793053940561327
# Second attempt: tell CopulaGAN explicitly how to transform each column
# instead of relying on its defaults.
model = CopulaGAN(
    field_transformers={
        'Pclass': 'categorical',
        'Sex': 'categorical',
        'Age': 'float',
        # SibSp is a count (0-8 in the training data), not a yes/no flag.
        # The 'boolean' transformer collapsed every sampled value to 0 or 1,
        # so it is modelled as an integer, the same way as Parch.
        'SibSp': 'integer',
        'Parch': 'integer',
        'Fare': 'float',
    },
    field_distributions={
        # Fare is non-negative and strongly right-skewed in the real data
        # (min 0, mean ~32, max ~512).
        # NOTE(review): presumably the truncated Gaussian keeps samples
        # within the observed range - confirm against the SDV docs.
        'Fare': 'truncated_gaussian',
    },
)
model.fit(data)
# Draw 200 synthetic rows from the re-configured model and preview them.
# Any previously recorded output for this cell must be regenerated by
# re-running it.
new_data = model.sample(200)
new_data.head()
Pclass Sex Age SibSp Parch Fare
0 2 1 55.113425 1 0 -9.755683e-08
1 2 1 42.257149 1 2 7.252636e+00
2 3 1 30.128989 1 0 7.708430e+00
3 2 1 30.838961 1 0 -1.184872e-07
4 3 1 30.370534 1 0 3.530599e+00
# Summary statistics of the sample drawn from the explicitly configured
# model, for comparison with the real-data summary.
new_data.describe(include='all')
Pclass Sex Age SibSp Parch Fare
count 200.000000 200.000000 200.000000 200.000000 200.000000 2.000000e+02
mean 2.265000 0.630000 33.075478 0.715000 0.915000 2.587770e+01
std 0.726231 0.484016 13.472204 0.452547 1.189474 3.599977e+01
min 1.000000 0.000000 -1.122344 0.000000 0.000000 -1.191555e-07
25% 2.000000 0.000000 26.064943 0.000000 0.000000 6.589631e+00
50% 2.000000 1.000000 30.706679 1.000000 0.000000 1.182216e+01
75% 3.000000 1.000000 41.977015 1.000000 2.000000 2.797455e+01
max 3.000000 1.000000 74.046067 1.000000 5.000000 2.014112e+02
# Score the explicitly configured model's sample against the real data
# (synthetic frame first, real frame second), for comparison with the
# baseline model's score computed earlier.
evaluate(new_data, data)
0.3533601962024023