Synthetic data with SDV and CTGAN

Synthetic data with SDV and CTGAN

import pandas as pd
import warnings

# Suppress every warning for the rest of the session so library notices do
# not clutter the cell outputs.
# NOTE(review): this blanket filter also hides any warning raised by pandas
# or SDV later in the notebook — consider narrowing it to specific categories.
warnings.filterwarnings('ignore')

# Load the real feature table (the one the synthesizer is later fitted on)
# from a path relative to the working directory, then preview the first rows.
data = pd.read_csv("data/svm-hyperparameters-train-features.csv")
data.head()
Pclass Sex Age SibSp Parch Fare
0 3 1 22.0 1 0 7.2500
1 1 0 38.0 1 0 71.2833
2 3 0 26.0 0 0 7.9250
3 1 0 35.0 1 0 53.1000
4 3 1 35.0 0 0 8.0500
# Summary statistics of the real data; these serve as the reference when
# comparing against the synthetic samples generated below.
# include='all' asks pandas to summarise every column regardless of dtype.
data.describe(include='all')
Pclass Sex Age SibSp Parch Fare
count 891.000000 891.000000 891.000000 891.000000 891.000000 891.000000
mean 2.308642 0.647587 29.758889 0.523008 0.381594 32.204208
std 0.836071 0.477990 13.002570 1.102743 0.806057 49.693429
min 1.000000 0.000000 0.420000 0.000000 0.000000 0.000000
25% 2.000000 0.000000 22.000000 0.000000 0.000000 7.910400
50% 3.000000 1.000000 30.000000 0.000000 0.000000 14.454200
75% 3.000000 1.000000 35.000000 1.000000 0.000000 31.000000
max 3.000000 1.000000 80.000000 8.000000 6.000000 512.329200
# NOTE(review): `sdv.tabular` looks like the pre-1.0 SDV API — confirm the
# installed SDV version still provides this module.
from sdv.tabular import CTGAN
# Baseline run: CTGAN constructed with no arguments (library defaults) and
# fitted on the real table.
model = CTGAN()
model.fit(data)
# Draw 200 synthetic rows from the fitted model and preview the first rows.
# No random seed is set here, so the sample presumably differs between runs.
new_data = model.sample(200)
new_data.head()
Pclass Sex Age SibSp Parch Fare
0 3 1 31.40 1 1 24.7142
1 3 1 63.44 1 1 0.6418
2 3 0 32.73 0 0 2.8117
3 1 0 27.53 0 0 40.4747
4 2 1 46.31 2 1 104.2955
# Summary statistics of the baseline synthetic sample, for a side-by-side
# comparison with the statistics of the real data.
new_data.describe(include='all')
Pclass Sex Age SibSp Parch Fare
count 200.000000 200.000000 200.000000 200.000000 200.000000 200.000000
mean 2.605000 0.535000 37.052600 0.740000 0.450000 32.947807
std 0.715215 0.500025 15.763543 1.212332 0.692748 56.156945
min 1.000000 0.000000 2.870000 0.000000 0.000000 0.000000
25% 2.000000 0.000000 27.217500 0.000000 0.000000 12.590775
50% 3.000000 1.000000 33.100000 0.000000 0.000000 18.435050
75% 3.000000 1.000000 46.692500 1.000000 1.000000 24.083125
max 3.000000 1.000000 80.000000 6.000000 2.000000 380.969200
from sdv.evaluation import evaluate

# Score the baseline synthetic sample against the real table. The synthetic
# frame is passed first and the real frame second.
# NOTE(review): the result is a single float; presumably higher means the
# synthetic data is closer to the real data — confirm against the SDV docs
# for the installed version.
evaluate(new_data, data)
0.5513349938501996
# Second run with explicit hyperparameters instead of the library defaults.
# NOTE(review): generator_dim / discriminator_dim presumably define one hidden
# layer per tuple entry (here three entries of 256) — confirm against the SDV
# CTGAN documentation.
# `model` is rebound here, so the baseline model is no longer reachable.
model = CTGAN(
    epochs=500,
    batch_size=100,
    generator_dim=(256, 256, 256),
    discriminator_dim=(256, 256, 256)
)
model.fit(data)
# Draw a fresh 200-row sample; `new_data` is rebound as well, replacing the
# baseline sample.
new_data = model.sample(200)
new_data.head()
Pclass Sex Age SibSp Parch Fare
0 3 1 30.69 0 0 13.5556
1 2 0 0.42 1 0 33.1753
2 3 1 30.86 0 0 34.4404
3 1 1 11.07 0 0 29.0068
4 2 0 54.40 0 0 36.6160
# Summary statistics of the sample drawn from the model fitted with explicit
# hyperparameters, for comparison with the real data and the baseline sample.
new_data.describe(include='all')
Pclass Sex Age SibSp Parch Fare
count 200.000000 200.000000 200.000000 200.000000 200.000000 200.000000
mean 2.350000 0.540000 36.034900 0.365000 0.395000 34.128625
std 0.699964 0.499648 16.254839 0.809352 0.762813 42.191675
min 1.000000 0.000000 0.420000 0.000000 0.000000 2.339800
25% 2.000000 0.000000 28.737500 0.000000 0.000000 10.917300
50% 2.000000 1.000000 31.645000 0.000000 0.000000 19.874400
75% 3.000000 1.000000 49.312500 1.000000 0.000000 35.050750
max 3.000000 1.000000 80.000000 5.000000 2.000000 269.388000
# Score the second model's sample against the real table, using the same call
# as for the baseline so the two scores can be compared.
# NOTE(review): both scores come from a single 200-row sample each and no seed
# is set, so a small difference between them may be sampling noise.
evaluate(new_data, data)
0.6065965175235917