Synthetic data with SDV and CTGAN
Synthetic data with SDV and CTGAN
import pandas as pd
import warnings
warnings.filterwarnings('ignore')
data = pd.read_csv("data/svm-hyperparameters-train-features.csv")
data.head()

| | Pclass | Sex | Age | SibSp | Parch | Fare |
|---|---|---|---|---|---|---|
| 0 | 3 | 1 | 22.0 | 1 | 0 | 7.2500 |
| 1 | 1 | 0 | 38.0 | 1 | 0 | 71.2833 |
| 2 | 3 | 0 | 26.0 | 0 | 0 | 7.9250 |
| 3 | 1 | 0 | 35.0 | 1 | 0 | 53.1000 |
| 4 | 3 | 1 | 35.0 | 0 | 0 | 8.0500 |
data.describe(include='all')

| | Pclass | Sex | Age | SibSp | Parch | Fare |
|---|---|---|---|---|---|---|
| count | 891.000000 | 891.000000 | 891.000000 | 891.000000 | 891.000000 | 891.000000 |
| mean | 2.308642 | 0.647587 | 29.758889 | 0.523008 | 0.381594 | 32.204208 |
| std | 0.836071 | 0.477990 | 13.002570 | 1.102743 | 0.806057 | 49.693429 |
| min | 1.000000 | 0.000000 | 0.420000 | 0.000000 | 0.000000 | 0.000000 |
| 25% | 2.000000 | 0.000000 | 22.000000 | 0.000000 | 0.000000 | 7.910400 |
| 50% | 3.000000 | 1.000000 | 30.000000 | 0.000000 | 0.000000 | 14.454200 |
| 75% | 3.000000 | 1.000000 | 35.000000 | 1.000000 | 0.000000 | 31.000000 |
| max | 3.000000 | 1.000000 | 80.000000 | 8.000000 | 6.000000 | 512.329200 |
from sdv.tabular import CTGAN
model = CTGAN()
model.fit(data)
new_data = model.sample(200)
new_data.head()

| | Pclass | Sex | Age | SibSp | Parch | Fare |
|---|---|---|---|---|---|---|
| 0 | 3 | 1 | 31.40 | 1 | 1 | 24.7142 |
| 1 | 3 | 1 | 63.44 | 1 | 1 | 0.6418 |
| 2 | 3 | 0 | 32.73 | 0 | 0 | 2.8117 |
| 3 | 1 | 0 | 27.53 | 0 | 0 | 40.4747 |
| 4 | 2 | 1 | 46.31 | 2 | 1 | 104.2955 |
new_data.describe(include='all')

| | Pclass | Sex | Age | SibSp | Parch | Fare |
|---|---|---|---|---|---|---|
| count | 200.000000 | 200.000000 | 200.000000 | 200.000000 | 200.000000 | 200.000000 |
| mean | 2.605000 | 0.535000 | 37.052600 | 0.740000 | 0.450000 | 32.947807 |
| std | 0.715215 | 0.500025 | 15.763543 | 1.212332 | 0.692748 | 56.156945 |
| min | 1.000000 | 0.000000 | 2.870000 | 0.000000 | 0.000000 | 0.000000 |
| 25% | 2.000000 | 0.000000 | 27.217500 | 0.000000 | 0.000000 | 12.590775 |
| 50% | 3.000000 | 1.000000 | 33.100000 | 0.000000 | 0.000000 | 18.435050 |
| 75% | 3.000000 | 1.000000 | 46.692500 | 1.000000 | 1.000000 | 24.083125 |
| max | 3.000000 | 1.000000 | 80.000000 | 6.000000 | 2.000000 | 380.969200 |
from sdv.evaluation import evaluate
evaluate(new_data, data)
0.5513349938501996
model = CTGAN(
epochs=500,
batch_size=100,
generator_dim=(256, 256, 256),
discriminator_dim=(256, 256, 256)
)
model.fit(data)
new_data = model.sample(200)
new_data.head()

| | Pclass | Sex | Age | SibSp | Parch | Fare |
|---|---|---|---|---|---|---|
| 0 | 3 | 1 | 30.69 | 0 | 0 | 13.5556 |
| 1 | 2 | 0 | 0.42 | 1 | 0 | 33.1753 |
| 2 | 3 | 1 | 30.86 | 0 | 0 | 34.4404 |
| 3 | 1 | 1 | 11.07 | 0 | 0 | 29.0068 |
| 4 | 2 | 0 | 54.40 | 0 | 0 | 36.6160 |
new_data.describe(include='all')

| | Pclass | Sex | Age | SibSp | Parch | Fare |
|---|---|---|---|---|---|---|
| count | 200.000000 | 200.000000 | 200.000000 | 200.000000 | 200.000000 | 200.000000 |
| mean | 2.350000 | 0.540000 | 36.034900 | 0.365000 | 0.395000 | 34.128625 |
| std | 0.699964 | 0.499648 | 16.254839 | 0.809352 | 0.762813 | 42.191675 |
| min | 1.000000 | 0.000000 | 0.420000 | 0.000000 | 0.000000 | 2.339800 |
| 25% | 2.000000 | 0.000000 | 28.737500 | 0.000000 | 0.000000 | 10.917300 |
| 50% | 2.000000 | 1.000000 | 31.645000 | 0.000000 | 0.000000 | 19.874400 |
| 75% | 3.000000 | 1.000000 | 49.312500 | 1.000000 | 0.000000 | 35.050750 |
| max | 3.000000 | 1.000000 | 80.000000 | 5.000000 | 2.000000 | 269.388000 |
evaluate(new_data, data)
0.6065965175235917