# Synthetic data with SDV and CTGAN

import pandas as pd
import warnings

# Suppress every warning for the rest of the session.
# NOTE(review): this is a blanket filter -- it also hides deprecation notices
# from pandas/SDV; consider narrowing it to specific categories or modules.
warnings.simplefilter('ignore')

# The previewed columns (Pclass, Sex, Age, SibSp, Parch, Fare) look like
# Titanic passenger features despite the "svm-hyperparameters" file name.
# NOTE(review): confirm this path points at the intended dataset.
TRAIN_FEATURES_CSV = "data/svm-hyperparameters-train-features.csv"
data = pd.read_csv(TRAIN_FEATURES_CSV)

# Preview the first five rows of the real data.
data.head(n=5)

   Pclass  Sex   Age  SibSp  Parch     Fare
0       3    1  22.0      1      0   7.2500
1       1    0  38.0      1      0  71.2833
2       3    0  26.0      0      0   7.9250
3       1    0  35.0      1      0  53.1000
4       3    1  35.0      0      0   8.0500
# Per-column summary of the real data (count, mean, std, min, quartiles, max),
# used as the reference point for the synthetic samples generated below.
real_summary = data.describe(include='all')
real_summary

           Pclass         Sex         Age       SibSp       Parch        Fare
count  891.000000  891.000000  891.000000  891.000000  891.000000  891.000000
mean     2.308642    0.647587   29.758889    0.523008    0.381594   32.204208
std      0.836071    0.477990   13.002570    1.102743    0.806057   49.693429
min      1.000000    0.000000    0.420000    0.000000    0.000000    0.000000
25%      2.000000    0.000000   22.000000    0.000000    0.000000    7.910400
50%      3.000000    1.000000   30.000000    0.000000    0.000000   14.454200
75%      3.000000    1.000000   35.000000    1.000000    0.000000   31.000000
max      3.000000    1.000000   80.000000    8.000000    6.000000  512.329200
from sdv.tabular import CTGAN

# Baseline: a CTGAN synthesizer with the library's default hyperparameters,
# trained on the full real dataset.
# NOTE(review): `sdv.tabular` looks like the legacy (pre-1.0) SDV API; newer
# releases expose `sdv.single_table.CTGANSynthesizer` -- confirm the version.
# NOTE(review): no random seed is set here, so retraining/resampling yields
# different rows and a different evaluation score on every run.
# NOTE(review): Pclass and Sex are integer-coded categories; without explicit
# field types they are presumably modeled as numeric -- consider declaring them.
model = CTGAN()
model.fit(data)

# Draw 200 synthetic rows and preview the first five.
new_data = model.sample(num_rows=200)
new_data.head(n=5)

   Pclass  Sex    Age  SibSp  Parch      Fare
0       3    1  31.40      1      1   24.7142
1       3    1  63.44      1      1    0.6418
2       3    0  32.73      0      0    2.8117
3       1    0  27.53      0      0   40.4747
4       2    1  46.31      2      1  104.2955
# Same per-column summary for the baseline synthetic sample, for side-by-side
# comparison with the real-data summary.
baseline_summary = new_data.describe(include='all')
baseline_summary

           Pclass         Sex         Age       SibSp       Parch        Fare
count  200.000000  200.000000  200.000000  200.000000  200.000000  200.000000
mean     2.605000    0.535000   37.052600    0.740000    0.450000   32.947807
std      0.715215    0.500025   15.763543    1.212332    0.692748   56.156945
min      1.000000    0.000000    2.870000    0.000000    0.000000    0.000000
25%      2.000000    0.000000   27.217500    0.000000    0.000000   12.590775
50%      3.000000    1.000000   33.100000    0.000000    0.000000   18.435050
75%      3.000000    1.000000   46.692500    1.000000    1.000000   24.083125
max      3.000000    1.000000   80.000000    6.000000    2.000000  380.969200
from sdv.evaluation import evaluate

# Score the baseline synthetic sample against the real data.
# sdv.evaluation.evaluate takes the synthetic frame first and the real one
# second; keywords make that order explicit.
# NOTE(review): per SDV docs the aggregated result is a normalized score where
# higher means closer to the real data -- confirm for the installed version.
evaluate(synthetic_data=new_data, real_data=data)
0.5513349938501996
# Second run: the same CTGAN class with explicit hyperparameters -- 500 training
# epochs, batches of 100 rows, and layer sizes (256, 256, 256) for both the
# generator and the discriminator.
# NOTE(review): CTGAN expects batch_size to be a multiple of its `pac` setting
# (10 by default in the ctgan package); 100 satisfies that -- re-check if changed.
tuned_params = {
    'epochs': 500,
    'batch_size': 100,
    'generator_dim': (256, 256, 256),
    'discriminator_dim': (256, 256, 256),
}
model = CTGAN(**tuned_params)
model.fit(data)

# Draw 200 rows (the same count as the baseline) and preview the first five.
new_data = model.sample(num_rows=200)
new_data.head(n=5)

   Pclass  Sex    Age  SibSp  Parch     Fare
0       3    1  30.69      0      0  13.5556
1       2    0   0.42      1      0  33.1753
2       3    1  30.86      0      0  34.4404
3       1    1  11.07      0      0  29.0068
4       2    0  54.40      0      0  36.6160
# Per-column summary of the tuned model's synthetic sample, to compare with both
# the real data and the baseline sample.
tuned_summary = new_data.describe(include='all')
tuned_summary

           Pclass         Sex         Age       SibSp       Parch        Fare
count  200.000000  200.000000  200.000000  200.000000  200.000000  200.000000
mean     2.350000    0.540000   36.034900    0.365000    0.395000   34.128625
std      0.699964    0.499648   16.254839    0.809352    0.762813   42.191675
min      1.000000    0.000000    0.420000    0.000000    0.000000    2.339800
25%      2.000000    0.000000   28.737500    0.000000    0.000000   10.917300
50%      2.000000    1.000000   31.645000    0.000000    0.000000   19.874400
75%      3.000000    1.000000   49.312500    1.000000    0.000000   35.050750
max      3.000000    1.000000   80.000000    5.000000    2.000000  269.388000
# Score the tuned model's sample with the same (synthetic, real) argument order.
# NOTE(review): each configuration is judged on a single unseeded 200-row
# sample, so part of the score gap versus the baseline may be sampling noise;
# consider fixed seeds and averaging over several samples.
evaluate(synthetic_data=new_data, real_data=data)
0.6065965175235917