Synthetic data with SDV and CTGAN

import pandas as pd
import warnings

# Suppress every warning for the rest of the session so that library
# warnings do not clutter the cell output. NOTE: this also hides
# deprecation notices, which can matter when library versions change.
warnings.simplefilter('ignore')

# Load the real feature table that the synthesizers are trained on and
# preview its first rows.
csv_path = "data/svm-hyperparameters-train-features.csv"
data = pd.read_csv(csv_path)
data.head()
Pclass Sex Age SibSp Parch Fare
0 3 1 22.0 1 0 7.2500
1 1 0 38.0 1 0 71.2833
2 3 0 26.0 0 0 7.9250
3 1 0 35.0 1 0 53.1000
4 3 1 35.0 0 0 8.0500
# Summary statistics for every column of the real data; this is the
# baseline that the synthetic samples are compared against later.
data.describe(include='all')
Pclass Sex Age SibSp Parch Fare
count 891.000000 891.000000 891.000000 891.000000 891.000000 891.000000
mean 2.308642 0.647587 29.758889 0.523008 0.381594 32.204208
std 0.836071 0.477990 13.002570 1.102743 0.806057 49.693429
min 1.000000 0.000000 0.420000 0.000000 0.000000 0.000000
25% 2.000000 0.000000 22.000000 0.000000 0.000000 7.910400
50% 3.000000 1.000000 30.000000 0.000000 0.000000 14.454200
75% 3.000000 1.000000 35.000000 1.000000 0.000000 31.000000
max 3.000000 1.000000 80.000000 8.000000 6.000000 512.329200
# Baseline run: fit a CTGAN synthesizer with its default settings on the
# real table, then draw 200 synthetic rows and preview them.
# NOTE(review): no field types or constraints are passed, so every column
# is presumably modelled with whatever types SDV infers -- confirm whether
# Pclass and Sex should be declared categorical.
from sdv.tabular import CTGAN
model = CTGAN()
model.fit(data)
new_data = model.sample(200)
new_data.head()
Pclass Sex Age SibSp Parch Fare
0 2 0 2.842574 6 2 28.927918
1 2 0 47.380061 1 1 122.939126
2 2 0 42.536188 0 0 36.907182
3 1 0 28.853204 2 2 37.291651
4 0 0 35.857498 3 1 77.988031
# Summarise the synthetic sample. Compare with the real-data summary to
# spot values outside the real range (the recorded output shows a Pclass
# of 0 and negative Age and Fare, none of which occur in the real data).
new_data.describe(include='all')
Pclass Sex Age SibSp Parch Fare
count 200.000000 200.000000 200.000000 200.000000 200.000000 200.000000
mean 1.310000 0.255000 27.552343 1.170000 1.250000 104.137308
std 0.858829 0.436955 14.735829 1.730919 1.399031 137.597823
min 0.000000 0.000000 -6.115006 0.000000 0.000000 -85.821346
25% 1.000000 0.000000 19.907786 0.000000 0.000000 27.131476
50% 1.000000 0.000000 29.180020 0.000000 1.000000 56.792348
75% 2.000000 1.000000 36.793752 2.000000 2.000000 102.704251
max 3.000000 1.000000 64.004638 8.000000 6.000000 747.924987
# Score the synthetic sample against the real table. The synthetic data is
# passed first and the real data second; the call returns a single float.
# NOTE(review): presumably a normalised score where higher means the
# synthetic data is closer to the real data -- confirm against the SDV docs.
from sdv.evaluation import evaluate

evaluate(new_data, data)
0.19950789269381727
# Second attempt: retrain CTGAN with explicit hyperparameters (500 epochs,
# batches of 100 rows, and the same three-entry 256-wide layer
# specification for both the generator and the discriminator), then draw
# a fresh sample of 200 synthetic rows and preview it.
layer_sizes = (256, 256, 256)
model = CTGAN(
    discriminator_dim=layer_sizes,
    generator_dim=layer_sizes,
    batch_size=100,
    epochs=500,
)
model.fit(data)
new_data = model.sample(200)
new_data.head()
Pclass Sex Age SibSp Parch Fare
0 1 0 29.551254 0 0 257.363295
1 1 1 49.643384 1 -1 53.014717
2 0 0 -14.547607 0 2 0.694069
3 3 0 7.065562 5 0 19.361019
4 3 0 28.684305 0 2 -256.489558
# Summarise the sample from the tuned model. The recorded output still
# contains out-of-range values (negative Age, SibSp, Parch and Fare, and
# maxima above those of the real data), so the larger network alone does
# not keep the samples within the real value ranges.
new_data.describe(include='all')
Pclass Sex Age SibSp Parch Fare
count 200.000000 200.000000 200.000000 200.000000 200.000000 200.000000
mean 1.330000 0.215000 28.184757 1.200000 1.015000 153.406430
std 0.880327 0.411853 24.932518 2.682532 1.737779 278.082600
min 0.000000 0.000000 -17.001736 -4.000000 -1.000000 -433.103355
25% 1.000000 0.000000 12.779717 0.000000 0.000000 13.491768
50% 1.000000 0.000000 28.262440 0.000000 0.000000 45.898936
75% 2.000000 0.000000 34.990247 1.000000 1.000000 215.423687
max 3.000000 1.000000 93.451393 13.000000 8.000000 1104.455566
# Re-score the tuned model's sample against the real data; per the
# recorded outputs the score barely moves relative to the default run
# (about 0.209 versus about 0.200).
evaluate(new_data, data)
0.20902599406872358