Skip to content

Commit 0359bb4

Browse files
fealholeix28csala
authored
V0.3.0 (#112)
* relative import; warning filter; data sampler file * merge cond sampler and sampler to one data sampler * change code to 2 space indent * define util classes and rename variables. * refactoring transformer and fix bugs. * rename file * rename func * add hyper parameters to args. * add doc strings. * fix line length * fix bug * fix indent * change indent to 4 * we should allow breaking lines before binary operators. * Bump version: 0.2.2.dev1 → 0.3.0.dev0 * Code refactoring * Removes attr and load/save, and fix lint * Fix typo * Fix conda version * Add TVAE (#111) * Adds tvae * Correctly adds tvae * Restructure files * Simplify tvae test * Fix lint/add verbose to ctgan * Fix lint * Move epochs from fit to __init__ * Fix epochs relocation * General refactoring * Fix readme * Adds save/laod to base class * Fixes save/load, adds test case * Fix lint * Fix lint * Improved testing * Added FutureWarning * Updated warning/fix lint * Fixes tvae bug/fix lint * Empty commit * Fix lint * Update readme * Updates readme * Updates readme Co-authored-by: Lei Xu <leix@mit.edu> Co-authored-by: Carles Sala <carles@pythiac.com>
1 parent 333aa9d commit 0359bb4

20 files changed

Lines changed: 1173 additions & 969 deletions

README.md

Lines changed: 38 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
<p align="left">
2-
<img width=15% src="https://dai.lids.mit.edu/wp-content/uploads/2018/06/Logo_DAI_highres.png" alt=“sdv-dev” />
3-
<i>An open source project from Data to AI Lab at MIT.</i>
2+
<a href="https://dai.lids.mit.edu">
3+
<img width=15% src="https://dai.lids.mit.edu/wp-content/uploads/2018/06/Logo_DAI_highres.png" alt="DAI-Lab" />
4+
</a>
5+
<i>An Open Source Project from the <a href="https://dai.lids.mit.edu">Data to AI Lab, at MIT</a></i>
46
</p>
57

68
[![Development Status](https://img.shields.io/badge/Development%20Status-2%20--%20Pre--Alpha-yellow)](https://pypi.org/search/?c=Development+Status+%3A%3A+2+-+Pre-Alpha)
@@ -9,29 +11,22 @@
911
[![Downloads](https://pepy.tech/badge/ctgan)](https://pepy.tech/project/ctgan)
1012
[![Coverage Status](https://codecov.io/gh/sdv-dev/CTGAN/branch/master/graph/badge.svg)](https://codecov.io/gh/sdv-dev/CTGAN)
1113

12-
# CTGAN
1314

14-
Implementation of our NeurIPS paper [Modeling Tabular data using Conditional GAN](https://arxiv.org/abs/1907.00503).
15-
16-
CTGAN is a GAN-based data synthesizer that can generate synthetic tabular data with high fidelity.
15+
<img align="center" width=30% src="docs/images/ctgan.png">
1716

17+
* Website: https://sdv.dev
18+
* Documentation: https://sdv.dev/SDV
19+
* Repository: https://github.com/sdv-dev/CTGAN
1820
* License: [MIT](https://github.com/sdv-dev/CTGAN/blob/master/LICENSE)
1921
* Development Status: [Pre-Alpha](https://pypi.org/search/?c=Development+Status+%3A%3A+2+-+Pre-Alpha)
20-
* Homepage: https://github.com/sdv-dev/CTGAN
2122

2223
## Overview
2324

24-
Based on previous work ([TGAN](https://github.com/sdv-dev/TGAN)) on synthetic data generation,
25-
we develop a new model called CTGAN. Several major differences make CTGAN outperform TGAN.
26-
27-
- **Preprocessing**: CTGAN uses more sophisticated Variational Gaussian Mixture Model to detect
28-
modes of continuous columns.
29-
- **Network structure**: TGAN uses LSTM to generate synthetic data column by column. CTGAN uses
30-
Fully-connected networks which is more efficient.
31-
- **Features to prevent mode collapse**: We design a conditional generator and resample the
32-
training data to prevent model collapse on discrete columns. We use WGANGP and PacGAN to
33-
stabilize the training of GAN.
25+
CTGAN is a collection of Deep Learning based Synthetic Data Generators for single table data, which are able to learn from real data and generate synthetic clones with high fidelity.
3426

27+
Currently, this library implements the **CTGAN** and **TVAE** models proposed in the [Modeling Tabular data using Conditional GAN](https://arxiv.org/abs/1907.00503) paper. For more information about these models, please check out the respective user guides:
28+
* [CTGAN User Guide](https://sdv.dev/SDV/user_guides/single_table/ctgan.html).
29+
* [TVAE User Guide](https://sdv.dev/SDV/user_guides/single_table/tvae.html).
3530

3631
# Install
3732

@@ -49,9 +44,6 @@ pip install ctgan
4944

5045
This will pull and install the latest stable release from [PyPI](https://pypi.org/).
5146

52-
If you want to install from source or contribute to the project please read the
53-
[Contributing Guide](CONTRIBUTING.rst).
54-
5547
## Install with conda
5648

5749
**CTGAN** can also be installed using [conda](https://docs.conda.io/en/latest/):
@@ -63,72 +55,25 @@ conda install -c sdv-dev -c pytorch -c conda-forge ctgan
6355
This will pull and install the latest stable release from [Anaconda](https://anaconda.org/).
6456

6557

66-
# Data Format
67-
68-
**CTGAN** expects the input data to be a table given as either a `numpy.ndarray` or a
69-
`pandas.DataFrame` object with two types of columns:
70-
71-
* **Continuous Columns**: Columns that contain numerical values and which can take any value.
72-
* **Discrete columns**: Columns that only contain a finite number of possible values, wether
73-
these are string values or not.
74-
75-
This is an example of a table with 4 columns:
76-
77-
* A continuous column with float values
78-
* A continuous column with integer values
79-
* A discrete column with string values
80-
* A discrete column with integer values
81-
82-
| | A | B | C | D |
83-
|---|------|-----|-----|---|
84-
| 0 | 0.1 | 100 | 'a' | 1 |
85-
| 1 | -1.3 | 28 | 'b' | 2 |
86-
| 2 | 0.3 | 14 | 'a' | 2 |
87-
| 3 | 1.4 | 87 | 'a' | 3 |
88-
| 4 | -0.1 | 69 | 'b' | 2 |
58+
# Usage Example
8959

60+
> :warning: **WARNING**: If you're just getting started with synthetic data, we recommend using the SDV library which provides user-friendly APIs for interacting with CTGAN. To learn more about using CTGAN through SDV, check out the user guide [here](https://sdv.dev/SDV/user_guides/single_table/ctgan.html).
9061
91-
**NOTE**: CTGAN does not distinguish between float and integer columns, which means that it will
92-
sample float values in all cases. If integer values are required, the outputted float values
93-
must be rounded to integers in a later step, outside of CTGAN.
62+
To get started with CTGAN, you should prepare your data as either a `numpy.ndarray` or a `pandas.DataFrame` object with two types of columns:
9463

95-
# Python Quickstart
64+
* **Continuous Columns**: can contain any numerical value.
65+
* **Discrete Columns**: contain a finite number values, whether these are string values or not.
9666

97-
In this short tutorial we will guide you through a series of steps that will help you
98-
getting started with **CTGAN**.
67+
In this example we load the [Adult Census Dataset](https://archive.ics.uci.edu/ml/datasets/adult) which is a built-in demo dataset. We then model it using the **CTGANSynthesizer** and generate a synthetic copy of it.
9968

100-
## 1. Model the data
101-
102-
### Step 1: Prepare your data
103-
104-
Before being able to use CTGAN you will need to prepare your data as specified above.
105-
106-
For this example, we will be loading some data using the `ctgan.load_demo` function.
10769

10870
```python3
71+
from ctgan import CTGANSynthesizer
10972
from ctgan import load_demo
11073

11174
data = load_demo()
112-
```
113-
114-
This will download a copy of the [Adult Census Dataset](https://archive.ics.uci.edu/ml/datasets/adult) as a dataframe:
115-
116-
| age | workclass | fnlwgt | ... | hours-per-week | native-country | income |
117-
|-------|------------------|----------|-----|------------------|------------------|----------|
118-
| 39 | State-gov | 77516 | ... | 40 | United-States | <=50K |
119-
| 50 | Self-emp-not-inc | 83311 | ... | 13 | United-States | <=50K |
120-
| 38 | Private | 215646 | ... | 40 | United-States | <=50K |
121-
| 53 | Private | 234721 | ... | 40 | United-States | <=50K |
122-
| 28 | Private | 338409 | ... | 40 | Cuba | <=50K |
123-
| ... | ... | ... | ... | ... | ... | ... |
12475

125-
126-
Aside from the table itself, you will need to create a list with the names of the discrete
127-
variables.
128-
129-
For this example:
130-
131-
```python3
76+
# Names of the columns that are discrete
13277
discrete_columns = [
13378
'workclass',
13479
'education',
@@ -140,93 +85,23 @@ discrete_columns = [
14085
'native-country',
14186
'income'
14287
]
143-
```
144-
145-
### Step 2: Fit CTGAN to your data
146-
147-
Once you have the data ready, you need to import and create an instance of the `CTGANSynthesizer`
148-
class.
14988

150-
```python3
151-
from ctgan import CTGANSynthesizer
152-
153-
ctgan = CTGANSynthesizer()
154-
```
155-
156-
And then call its `fit` method passing your data and the list of discrete columns
157-
158-
```python
89+
ctgan = CTGANSynthesizer(epochs=10)
15990
ctgan.fit(data, discrete_columns)
160-
```
16191

162-
**NOTE**: This process is likely to take a long time to run.
163-
164-
If you want to make the process shorter, or longer, you can control the number of training epochs
165-
that the model will be performing by adding it to the `fit` call:
166-
167-
```python3
168-
ctgan.fit(data, discrete_columns, epochs=5)
169-
```
170-
171-
## 2. Generate synthetic data
172-
173-
Once the process has finished, all you need to do is call the `sample` method of your
174-
`CTGANSynthesizer` instance indicating the number of rows that you want to generate.
175-
176-
```python3
92+
# Synthetic copy
17793
samples = ctgan.sample(1000)
17894
```
17995

180-
The output will be a table with the exact same format as the input and filled with the synthetic
181-
data generated by the model.
182-
183-
| age | workclass | fnlwgt | ... | hours-per-week | native-country | income |
184-
|---------|--------------|-----------|-----|------------------|------------------|----------|
185-
| 26.3191 | Private | 124079 | ... | 40.1557 | United-States | <=50K |
186-
| 39.8558 | Private | 133996 | ... | 40.2507 | United-States | <=50K |
187-
| 38.2477 | Self-emp-inc | 135955 | ... | 40.1124 | Ecuador | <=50K |
188-
| 29.6468 | Private | 3331.86 | ... | 27.012 | United-States | <=50K |
189-
| 20.9853 | Private | 120637 | ... | 40.0238 | United-States | <=50K |
190-
| ... | ... | ... | ... | ... | ... | ... |
191-
192-
## 3. Generate synthetic data conditioning on one column
193-
194-
In the CTGAN model, we have a conditional vector. By setting the conditional vector, we increase
195-
the probability of getting one value in one discrete column.
196-
197-
For example, the following code **increase the probability** of workclass = " Private".
198-
199-
```python3
200-
samples = ctgan.sample(1000, 'workclass', ' Private')
201-
```
202-
203-
**Note that this code does not guarante workclass=" Private"**
20496

205-
## 4. Save and load the synthesizer
20697

207-
To save a trained ctgan synthesizer, you can call the `save` method passing a path to the file
208-
in which the model will be saved:
209-
210-
```python3
211-
ctgan.save('ctgan.pkl')
212-
```
213-
214-
Later on, you can restore the saved synthetsizer by passing the path to the `load`
215-
model of the `CTGANSynthetizer` method:
98+
# Join our community
21699

217-
```python3
218-
ctgan = CTGANSynthesizer.load('ctgan.pkl')
219-
```
220100

221-
# Join our community
101+
1. Please have a look at the [Contributing Guide](https://sdv.dev/SDV/developer_guides/contributing.html) to see how you can contribute to the project.
102+
2. If you have any doubts, feature requests or detect an error, please [open an issue on github](https://github.com/sdv-dev/CTGAN/issues) or [join our Slack Workspace](https://sdv-space.slack.com/join/shared_invite/zt-gdsfcb5w-0QQpFMVoyB2Yd6SRiMplcw#/).
103+
3. Also, do not forget to check the [project documentation site](https://sdv.dev/SDV/)!
222104

223-
1. If you would like to try more dataset examples, please have a look at the [examples folder](
224-
https://github.com/sdv-dev/CTGAN/tree/master/examples) of the repository. Please contact us
225-
if you have a usage example that you would want to share with the community.
226-
2. If you want to contribute to the project code, please head to the [Contributing Guide](
227-
CONTRIBUTING.rst) for more details about how to do it.
228-
3. If you have any doubts, feature requests or detect an error, please [open an issue on github](
229-
https://github.com/sdv-dev/CTGAN/issues)
230105

231106
# Citing TGAN
232107

@@ -260,3 +135,15 @@ A package to easily deploy **CTGAN** onto a remote server. This package is devel
260135

261136
More details can be found in the corresponding repository: https://github.com/oregonpillow/ctgan-server-cli
262137

138+
139+
# The Synthetic Data Vault
140+
141+
<p>
142+
<a href="https://sdv.dev">
143+
<img width=30% src="https://github.com/sdv-dev/SDV/blob/master/docs/images/SDV-Logo-Color-Tagline.png?raw=true">
144+
</a>
145+
<p><i>This repository is part of <a href="https://sdv.dev">The Synthetic Data Vault Project</a></i></p>
146+
</p>
147+
148+
* Website: https://sdv.dev
149+
* Documentation: https://sdv.dev/SDV

conda/meta.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{% set name = 'ctgan' %}
2-
{% set version = '0.2.3.dev0' %}
2+
{% set version = '0.3.0.dev0' %}
33

44
package:
55
name: "{{ name|lower }}"

ctgan/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44

55
__author__ = 'MIT Data To AI Lab'
66
__email__ = 'dailabmit@gmail.com'
7-
__version__ = '0.2.3.dev0'
7+
__version__ = '0.3.0.dev0'
88

99
from ctgan.demo import load_demo
10-
from ctgan.synthesizer import CTGANSynthesizer
10+
from ctgan.synthesizers.ctgan import CTGANSynthesizer
11+
from ctgan.synthesizers.tvae import TVAESynthesizer
1112

1213
__all__ = (
1314
'CTGANSynthesizer',
15+
'TVAESynthesizer',
1416
'load_demo'
1517
)

ctgan/__main__.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import argparse
22

33
from ctgan.data import read_csv, read_tsv, write_tsv
4-
from ctgan.synthesizer import CTGANSynthesizer
4+
from ctgan.synthesizers.ctgan import CTGANSynthesizer
55

66

77
def _parse_args():
@@ -15,11 +15,31 @@ def _parse_args():
1515

1616
parser.add_argument('-m', '--metadata', help='Path to the metadata')
1717
parser.add_argument('-d', '--discrete',
18-
help='Comma separated list of discrete columns, no whitespaces')
19-
18+
help='Comma separated list of discrete columns without whitespaces.')
2019
parser.add_argument('-n', '--num-samples', type=int,
2120
help='Number of rows to sample. Defaults to the training data size')
2221

22+
parser.add_argument('--generator_lr', type=float, default=2e-4,
23+
help='Learning rate for the generator.')
24+
parser.add_argument('--discriminator_lr', type=float, default=2e-4,
25+
help='Learning rate for the discriminator.')
26+
27+
parser.add_argument('--generator_decay', type=float, default=1e-6,
28+
help='Weight decay for the generator.')
29+
parser.add_argument('--discriminator_decay', type=float, default=0,
30+
help='Weight decay for the discriminator.')
31+
32+
parser.add_argument('--embedding_dim', type=int, default=128,
33+
help='Dimension of input z to the generator.')
34+
parser.add_argument('--generator_dim', type=str, default='256,256',
35+
help='Dimension of each generator layer. '
36+
'Comma separated integers with no whitespaces.')
37+
parser.add_argument('--discriminator_dim', type=str, default='256,256',
38+
help='Dimension of each discriminator layer. '
39+
'Comma separated integers with no whitespaces.')
40+
41+
parser.add_argument('--batch_size', type=int, default=500,
42+
help='Batch size. Must be an even number.')
2343
parser.add_argument('--save', default=None, type=str,
2444
help='A filename to save the trained synthesizer.')
2545
parser.add_argument('--load', default=None, type=str,
@@ -38,7 +58,6 @@ def _parse_args():
3858

3959
def main():
4060
args = _parse_args()
41-
4261
if args.tsv:
4362
data, discrete_columns = read_tsv(args.data, args.metadata)
4463
else:
@@ -47,8 +66,15 @@ def main():
4766
if args.load:
4867
model = CTGANSynthesizer.load(args.load)
4968
else:
50-
model = CTGANSynthesizer()
51-
model.fit(data, discrete_columns, args.epochs)
69+
generator_dims = [int(x) for x in args.generator_dims.split(',')]
70+
discriminator_dims = [int(x) for x in args.discriminator_dims.split(',')]
71+
model = CTGANSynthesizer(
72+
embedding_dim=args.embedding_dim, generator_dims=generator_dims,
73+
discriminator_dims=discriminator_dims, generator_lr=args.generator_lr,
74+
generator_decay=args.generator_decay, discriminator_lr=args.discriminator_lr,
75+
discriminator_decay=args.discriminator_decay, batch_size=args.batch_size,
76+
epochs=args.epochs)
77+
model.fit(data, discrete_columns)
5278

5379
if args.save is not None:
5480
model.save(args.save)

0 commit comments

Comments
 (0)