Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bugs for conditional sampling #236

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion ctgan/data_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def is_discrete_column(column_info):
self._discrete_column_category_prob[current_id, :span_info.dim] = category_prob
self._discrete_column_cond_st[current_id] = current_cond_st
self._discrete_column_n_category[current_id] = span_info.dim
self._discrete_column_matrix_st[current_id] = st

current_cond_st += span_info.dim
current_id += 1
st = ed
Expand Down Expand Up @@ -150,7 +152,7 @@ def dim_cond_vec(self):
def generate_cond_from_condition_column_info(self, condition_info, batch):
    """Generate the condition vector for one fixed (column, value) pair.

    Args:
        condition_info (dict):
            Must provide ``'discrete_column_id'`` (used to index
            ``self._discrete_column_cond_st``) and ``'value_id'`` (added to
            that start offset to select the column to activate).
        batch (int):
            Number of rows in the returned array.

    Returns:
        numpy.ndarray:
            ``float32`` array of shape ``(batch, self._n_categories)`` that is
            zero everywhere except for a single column, which is 1 in every row.
    """
    vec = np.zeros((batch, self._n_categories), dtype='float32')

    # The offset must come from the condition-vector start table. The previous
    # lookup into ``_discrete_column_matrix_st`` was overwritten right away and
    # is therefore dropped: only the ``_cond_st`` value was ever used.
    column_start = self._discrete_column_cond_st[condition_info['discrete_column_id']]
    column = column_start + condition_info['value_id']

    vec[:, column] = 1
    return vec
8 changes: 6 additions & 2 deletions ctgan/synthesizers/ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,8 @@ def sample(self, n, condition_column=None, condition_value=None):
Returns:
numpy.ndarray or pandas.DataFrame
"""
self._generator.eval()

if condition_column is not None and condition_value is not None:
condition_info = self._transformer.convert_column_name_value_to_id(
condition_column, condition_value)
Expand Down Expand Up @@ -467,8 +469,10 @@ def sample(self, n, condition_column=None, condition_value=None):
c1 = torch.from_numpy(c1).to(self._device)
fakez = torch.cat([fakez, c1], dim=1)

fake = self._generator(fakez)
fakeact = self._apply_activate(fake)
with torch.no_grad():
fake = self._generator(fakez)
fakeact = self._apply_activate(fake)

data.append(fakeact.detach().cpu().numpy())

data = np.concatenate(data, axis=0)
Expand Down
43 changes: 35 additions & 8 deletions tests/integration/synthesizer/test_ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,19 +80,19 @@ def test_log_frequency():

discrete_columns = ['discrete']

ctgan = CTGANSynthesizer(epochs=100)
ctgan = CTGANSynthesizer(epochs=1)
ctgan.fit(data, discrete_columns)

sampled = ctgan.sample(10000)
counts = sampled['discrete'].value_counts()
assert counts['a'] < 6500
assert ctgan._data_sampler._discrete_column_category_prob[0][0] < 0.95
assert ctgan._data_sampler._discrete_column_category_prob[0][1] > 0.025
assert ctgan._data_sampler._discrete_column_category_prob[0][2] > 0.025

ctgan = CTGANSynthesizer(log_frequency=False, epochs=100)
ctgan = CTGANSynthesizer(log_frequency=False, epochs=1)
ctgan.fit(data, discrete_columns)

sampled = ctgan.sample(10000)
counts = sampled['discrete'].value_counts()
assert counts['a'] > 9000
assert ctgan._data_sampler._discrete_column_category_prob[0][0] == 0.95
assert ctgan._data_sampler._discrete_column_category_prob[0][1] == 0.025
assert ctgan._data_sampler._discrete_column_category_prob[0][2] == 0.025


def test_categorical_nan():
Expand Down Expand Up @@ -134,6 +134,33 @@ def test_synthesizer_sample():
assert isinstance(samples, pd.DataFrame)


def test_synthesizer_sampling():
    """Test the CTGANSynthesizer sampling, with and without a condition.

    The training data has one continuous column and one discrete column in
    which 'a' covers 950 of 1000 rows and 'b' and 'c' cover 25 rows each.
    Unconditional samples are expected to keep 'a' dominant; samples
    conditioned on a value are expected to consist mostly of that value.
    """
    data = pd.DataFrame({
        'continuous': np.random.random(1000),
        'discrete': np.repeat(['a', 'b', 'c'], [950, 25, 25])
    })

    discrete_columns = ['discrete']

    ctgan = CTGANSynthesizer(epochs=100)
    ctgan.fit(data, discrete_columns)

    # Count once per sample. ``.get(label, 0)`` treats a category that was
    # never sampled as zero occurrences instead of raising ``KeyError``, so
    # the upper-bound checks on the rare categories can actually pass.
    counts = ctgan.sample(1000)['discrete'].value_counts()
    assert counts.get('a', 0) > 800
    assert counts.get('b', 0) < 100
    assert counts.get('c', 0) < 100

    # Conditioning on each category must make that category dominate.
    for value in ('a', 'b', 'c'):
        samples = ctgan.sample(1000, condition_column='discrete', condition_value=value)
        counts = samples['discrete'].value_counts()
        assert counts.get(value, 0) > 750


def test_save_load():
"""Test the CTGANSynthesizer load/save methods."""
data = pd.DataFrame({
Expand Down