Synthetic data with SDV and CTGAN
import pandas as pd
import warnings
# Suppress all Python warnings for the rest of the session.
warnings.filterwarnings('ignore')

# Load the real feature table used as the reference dataset in this notebook.
data = pd.read_csv("data/svm-hyperparameters-train-features.csv")
# Preview the first rows; the table that follows is this preview.
data.head()
| Pclass | Sex | Age | SibSp | Parch | Fare |
---|
0 | 3 | 1 | 22.0 | 1 | 0 | 7.2500 |
---|
1 | 1 | 0 | 38.0 | 1 | 0 | 71.2833 |
---|
2 | 3 | 0 | 26.0 | 0 | 0 | 7.9250 |
---|
3 | 1 | 0 | 35.0 | 1 | 0 | 53.1000 |
---|
4 | 3 | 1 | 35.0 | 0 | 0 | 8.0500 |
---|
# Summary statistics for every column of the real data; these serve as the
# baseline to compare the synthetic samples against further down.
data.describe(include='all')
| Pclass | Sex | Age | SibSp | Parch | Fare |
---|
count | 891.000000 | 891.000000 | 891.000000 | 891.000000 | 891.000000 | 891.000000 |
---|
mean | 2.308642 | 0.647587 | 29.758889 | 0.523008 | 0.381594 | 32.204208 |
---|
std | 0.836071 | 0.477990 | 13.002570 | 1.102743 | 0.806057 | 49.693429 |
---|
min | 1.000000 | 0.000000 | 0.420000 | 0.000000 | 0.000000 | 0.000000 |
---|
25% | 2.000000 | 0.000000 | 22.000000 | 0.000000 | 0.000000 | 7.910400 |
---|
50% | 3.000000 | 1.000000 | 30.000000 | 0.000000 | 0.000000 | 14.454200 |
---|
75% | 3.000000 | 1.000000 | 35.000000 | 1.000000 | 0.000000 | 31.000000 |
---|
max | 3.000000 | 1.000000 | 80.000000 | 8.000000 | 6.000000 | 512.329200 |
---|
from sdv.tabular import CTGAN

# Baseline run: a CTGAN model with the library's default hyperparameters.
# The model has to be created and fitted on the real data before it can be
# sampled; without these two steps `model` is undefined at the sample() call.
model = CTGAN()
model.fit(data)

# Draw 200 synthetic rows and preview the first few (table below).
new_data = model.sample(200)
new_data.head()
| Pclass | Sex | Age | SibSp | Parch | Fare |
---|
0 | 3 | 1 | 31.40 | 1 | 1 | 24.7142 |
---|
1 | 3 | 1 | 63.44 | 1 | 1 | 0.6418 |
---|
2 | 3 | 0 | 32.73 | 0 | 0 | 2.8117 |
---|
3 | 1 | 0 | 27.53 | 0 | 0 | 40.4747 |
---|
4 | 2 | 1 | 46.31 | 2 | 1 | 104.2955 |
---|
# Summary statistics of the 200 synthetic rows, for side-by-side comparison
# with the statistics of the real data.
new_data.describe(include='all')
| Pclass | Sex | Age | SibSp | Parch | Fare |
---|
count | 200.000000 | 200.000000 | 200.000000 | 200.000000 | 200.000000 | 200.000000 |
---|
mean | 2.605000 | 0.535000 | 37.052600 | 0.740000 | 0.450000 | 32.947807 |
---|
std | 0.715215 | 0.500025 | 15.763543 | 1.212332 | 0.692748 | 56.156945 |
---|
min | 1.000000 | 0.000000 | 2.870000 | 0.000000 | 0.000000 | 0.000000 |
---|
25% | 2.000000 | 0.000000 | 27.217500 | 0.000000 | 0.000000 | 12.590775 |
---|
50% | 3.000000 | 1.000000 | 33.100000 | 0.000000 | 0.000000 | 18.435050 |
---|
75% | 3.000000 | 1.000000 | 46.692500 | 1.000000 | 1.000000 | 24.083125 |
---|
max | 3.000000 | 1.000000 | 80.000000 | 6.000000 | 2.000000 | 380.969200 |
---|
from sdv.evaluation import evaluate
# Score the synthetic sample against the real data; the number printed below
# is the value returned by this call.
# NOTE(review): presumably a similarity score where higher is better --
# confirm against the documentation of the installed SDV version.
evaluate(new_data, data)
0.5513349938501996
# Second run: replace the baseline model with one that uses explicit training
# settings (epoch count, batch size, and three 256-unit layers for both the
# generator and the discriminator).
model = CTGAN(
    epochs=500,
    batch_size=100,
    generator_dim=(256, 256, 256),
    discriminator_dim=(256, 256, 256),
)
# Constructing the model does not train it; fit on the real data before
# sampling, otherwise sample() is called on an unfitted model.
model.fit(data)

# Draw 200 synthetic rows from the re-trained model and preview them.
new_data = model.sample(200)
new_data.head()
| Pclass | Sex | Age | SibSp | Parch | Fare |
---|
0 | 3 | 1 | 30.69 | 0 | 0 | 13.5556 |
---|
1 | 2 | 0 | 0.42 | 1 | 0 | 33.1753 |
---|
2 | 3 | 1 | 30.86 | 0 | 0 | 34.4404 |
---|
3 | 1 | 1 | 11.07 | 0 | 0 | 29.0068 |
---|
4 | 2 | 0 | 54.40 | 0 | 0 | 36.6160 |
---|
# Summary statistics of the synthetic rows produced by the re-configured
# model, for comparison with the real data and with the first synthetic sample.
new_data.describe(include='all')
| Pclass | Sex | Age | SibSp | Parch | Fare |
---|
count | 200.000000 | 200.000000 | 200.000000 | 200.000000 | 200.000000 | 200.000000 |
---|
mean | 2.350000 | 0.540000 | 36.034900 | 0.365000 | 0.395000 | 34.128625 |
---|
std | 0.699964 | 0.499648 | 16.254839 | 0.809352 | 0.762813 | 42.191675 |
---|
min | 1.000000 | 0.000000 | 0.420000 | 0.000000 | 0.000000 | 2.339800 |
---|
25% | 2.000000 | 0.000000 | 28.737500 | 0.000000 | 0.000000 | 10.917300 |
---|
50% | 2.000000 | 1.000000 | 31.645000 | 0.000000 | 0.000000 | 19.874400 |
---|
75% | 3.000000 | 1.000000 | 49.312500 | 1.000000 | 0.000000 | 35.050750 |
---|
max | 3.000000 | 1.000000 | 80.000000 | 5.000000 | 2.000000 | 269.388000 |
---|
evaluate(new_data, data)
0.6065965175235917