Synthetic data with SDV and CopulaGAN

import pandas as pd
import warnings

# Silence every warning category for the rest of the session so that library
# warnings do not clutter the cell outputs.
warnings.simplefilter('ignore')

# Load the training features into a DataFrame; `data` is the real dataset that
# the synthetic samples are later fitted on and compared against.
data = pd.read_csv("data/svm-hyperparameters-train-features.csv")

# Preview the first five rows.
data.head(5)

PclassSexAgeSibSpParchFare
03122.0107.2500
11038.01071.2833
23026.0007.9250
31035.01053.1000
43135.0008.0500
# Summary statistics of the real data for every column; used as the reference
# when eyeballing how closely the synthetic samples follow the original.
data.describe(include='all')

PclassSexAgeSibSpParchFare
count891.000000891.000000891.000000891.000000891.000000891.000000
mean2.3086420.64758729.7588890.5230080.38159432.204208
std0.8360710.47799013.0025701.1027430.80605749.693429
min1.0000000.0000000.4200000.0000000.0000000.000000
25%2.0000000.00000022.0000000.0000000.0000007.910400
50%3.0000001.00000030.0000000.0000000.00000014.454200
75%3.0000001.00000035.0000001.0000000.00000031.000000
max3.0000001.00000080.0000008.0000006.000000512.329200
from sdv.tabular import CopulaGAN
# Baseline: CopulaGAN with all-default settings, fitted on the real data.
model = CopulaGAN()
model.fit(data)

# Draw 200 synthetic rows from the fitted model and preview the first five.
new_data = model.sample(num_rows=200)
new_data.head(5)

PclassSexAgeSibSpParchFare
03032.020018.2784
13120.9050109.6910
23040.8001179.5139
33033.131017.3447
41120.62009.6040
# Summary statistics of the synthetic sample, for comparison with the
# statistics of the real data.
new_data.describe(include='all')

PclassSexAgeSibSpParchFare
count200.000000200.000000200.000000200.000000200.000000200.000000
mean2.4900000.45000031.4047000.4050000.11500046.756109
std0.8205280.49874212.2977210.9193480.46134552.828023
min1.0000000.0000000.4500000.0000000.0000003.034900
25%2.0000000.00000024.7125000.0000000.00000011.768475
50%3.0000000.00000030.2950000.0000000.00000027.859150
75%3.0000001.00000034.3075001.0000000.00000058.361350
max3.0000001.00000073.3400007.0000003.000000381.859600
from sdv.evaluation import evaluate

# Score the synthetic sample against the real data. Arguments are passed by
# keyword because the synthetic table comes first, which is easy to mix up.
evaluate(synthetic_data=new_data, real_data=data)
0.539229052965023
# Second attempt: tell CopulaGAN how to transform each column instead of
# relying on automatic detection, and model `Fare` with a truncated Gaussian.
#
# `SibSp` is a count (0-8 in the real data), so it must use the 'integer'
# transformer. It was previously declared 'boolean', which collapses every
# sampled value to 0/1.
# NOTE: the recorded output following this cell was produced with the old
# 'boolean' setting (SibSp max of 1); re-run the cell to refresh it.
model = CopulaGAN(
    field_transformers={
        'Pclass': 'categorical',
        'Sex': 'categorical',
        'Age': 'float',
        'SibSp': 'integer',
        'Parch': 'integer',
        'Fare': 'float'
    },
    field_distributions={
        'Fare': 'truncated_gaussian'
    }
)
model.fit(data)

# Draw 200 synthetic rows from the re-configured model and preview them.
new_data = model.sample(200)
new_data.head()

PclassSexAgeSibSpParchFare
0101.071013.8318
13015.580046.6937
21126.530053.8841
31029.58005.2646
41030.13005.1415
# Summary statistics of the sample drawn from the explicitly configured model,
# for comparison with the statistics of the real data.
new_data.describe(include='all')

PclassSexAgeSibSpParchFare
count200.000000200.000000200.000000200.000000200.000000200.000000
mean2.2100000.41000022.9356000.5500000.14500049.155439
std0.9328730.49306813.9551820.4987420.44153142.733997
min1.0000000.0000000.4300000.0000000.0000004.749400
25%1.0000000.00000013.4275000.0000000.00000013.340950
50%3.0000000.00000025.8750001.0000000.00000034.373700
75%3.0000001.00000030.2250001.0000000.00000075.206275
max3.0000001.00000071.3400001.0000003.000000220.761200
# Score the sample from the explicitly configured model against the real data
# (synthetic table first, real table second).
evaluate(synthetic_data=new_data, real_data=data)
0.4983713994365315