
Commit ca53aaa

Author: Francisco Santos (committed)

PATEGAN base implementation

Remove duplicate test files after renaming
Use BaseModel variables
1 parent d888bcf commit ca53aaa

File tree

4 files changed, +259 -0 lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -9,3 +9,4 @@ pmlb==1.0.*
 tqdm<5.0
 typeguard==2.13.*
 pytest==6.2.*
+tensorflow_probability==0.12.*

src/ydata_synthetic/synthesizers/regular/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,7 @@
 from ydata_synthetic.synthesizers.regular.dragan.model import DRAGAN
 from ydata_synthetic.synthesizers.regular.cramergan.model import CRAMERGAN
 from ydata_synthetic.synthesizers.regular.cwgangp.model import CWGANGP
+from ydata_synthetic.synthesizers.regular.pategan.model import PATEGAN
 
 __all__ = [
     "VanilllaGAN",
@@ -14,4 +15,5 @@
     "DRAGAN",
     "CRAMERGAN",
     "CWGANGP"
+    "PATEGAN"
 ]

src/ydata_synthetic/synthesizers/regular/pategan/__init__.py

Whitespace-only changes.

src/ydata_synthetic/synthesizers/regular/pategan/model.py

Lines changed: 256 additions & 0 deletions

@@ -0,0 +1,256 @@
"PATEGAN implementation supporting Differential Privacy budget specification."
# pylint: disable = W0622, E0401
from math import log
from typing import List, NamedTuple, Optional

import tqdm
from tensorflow import (GradientTape, clip_by_value, concat, constant,
                        expand_dims, ones_like, tensor_scatter_nd_update,
                        transpose, zeros, zeros_like)
from tensorflow.data import Dataset
from tensorflow.dtypes import cast, float64, int64
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Input, ReLU
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.optimizers import Adam
from tensorflow.math import abs, exp, pow, reduce_sum, square
from tensorflow.random import uniform
from tensorflow_probability import distributions

from ydata_synthetic.synthesizers import TrainParameters
from ydata_synthetic.synthesizers.gan import BaseModel
from ydata_synthetic.utils.gumbel_softmax import ActivationInterface


# pylint: disable=R0902
class PATEGAN(BaseModel):
    "A basic PATEGAN synthesizer implementation with configurable differential privacy budget."

    __MODEL__='PATEGAN'

    def __init__(self, model_parameters, n_teachers: int, target_delta: float, target_epsilon: float):
        super().__init__(model_parameters)
        self.n_teachers = n_teachers
        self.target_epsilon = target_epsilon
        self.target_delta = target_delta

    # pylint: disable=W0201
    def define_gan(self, processor_info: Optional[NamedTuple] = None):
        def discriminator():
            return Discriminator(self.batch_size).build_model((self.data_dim,), self.layers_dim)

        self.generator = Generator(self.batch_size). \
            build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim,
                        processor_info=processor_info)
        self.s_discriminator = discriminator()
        self.t_discriminators = [discriminator() for i in range(self.n_teachers)]

        generator_optimizer = Adam(learning_rate=self.g_lr)
        discriminator_optimizer = Adam(learning_rate=self.d_lr)

        loss_fn = BinaryCrossentropy(from_logits=True)
        self.generator.compile(loss=loss_fn, optimizer=generator_optimizer)
        self.s_discriminator.compile(loss=loss_fn, optimizer=discriminator_optimizer)
        for teacher in self.t_discriminators:
            teacher.compile(loss=loss_fn, optimizer=discriminator_optimizer)

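    # How the pieces above fit together (summary of this file's wiring): define_gan
    # builds one generator, one "student" discriminator and n_teachers "teacher"
    # discriminators. Only the teachers ever see real records (one disjoint
    # partition each, produced by get_data_loader below); the student is trained
    # exclusively on the teachers' noisy majority votes computed in _pate_voting.
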
    # pylint: disable = C0103
    @staticmethod
    def _moments_acc(n_teachers, votes, lap_scale, l_list):
        q = (2 + lap_scale * abs(2 * votes - n_teachers))/(4 * exp(lap_scale * abs(2 * votes - n_teachers)))

        update = []
        for l in l_list:
            clip = 2 * square(lap_scale) * l * (l + 1)
            t = (1 - q) * pow((1 - q) / (1 - exp(2 * lap_scale) * q), l) + q * exp(2 * lap_scale * l)
            update.append(reduce_sum(clip_by_value(t, clip_value_min=-clip, clip_value_max=clip)))
        return cast(update, dtype=float64)

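    # Reference for the quantities computed in _moments_acc (lambda = lap_scale,
    # n = n_teachers, v = the clean vote counts):
    #   q   = (2 + lambda*|2v - n|) / (4 * exp(lambda*|2v - n|))
    #   t_l = (1 - q) * ((1 - q) / (1 - exp(2*lambda)*q))**l + q * exp(2*lambda*l)
    # Each call adds the clipped t_l terms to the running alpha vector; train()
    # later turns the accumulated moments into a privacy cost via
    #   epsilon = min_l (alpha_l - log(delta)) / l
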
    def get_data_loader(self, data) -> List[Dataset]:
        "Obtain a List of TF Datasets corresponding to partitions for each teacher in n_teachers."
        loader = []
        SHUFFLE_BUFFER_SIZE = 100

        for teacher_id in range(self.n_teachers):
            start_id = int(teacher_id * len(data) / self.n_teachers)
            end_id = int((teacher_id + 1) * len(data) / self.n_teachers if \
                teacher_id != (self.n_teachers - 1) else len(data))
            loader.append(Dataset.from_tensor_slices(data[start_id:end_id:])\
                .batch(self.batch_size).shuffle(SHUFFLE_BUFFER_SIZE))
        return loader

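    # Partitioning illustration (hypothetical numbers): with len(data) == 1000 and
    # n_teachers == 4, the slices above are rows [0, 250), [250, 500), [500, 750)
    # and [750, 1000), so each teacher trains on a disjoint chunk of the data.
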
    # pylint:disable=R0913
    def train(self, data, class_ratios, train_arguments: TrainParameters, num_cols: List[str], cat_cols: List[str]):
        """
        Args:
            data: A pandas DataFrame or a Numpy array with the data to be synthesized
            class_ratios: Probability of each class, used to sample the conditioning category fed to the generator and discriminators
            train_arguments: GAN training arguments.
            num_cols: List of columns of the data object to be handled as numerical
            cat_cols: List of columns of the data object to be handled as categorical
        """
        super().train(data, num_cols, cat_cols)

        data = self.processor.transform(data)
        self.data_dim = data.shape[1]
        self.define_gan(self.processor.col_transform_info)

        self.class_ratios = class_ratios

        alpha = cast([0.0 for _ in range(train_arguments.num_moments)], float64)
        l_list = 1 + cast(range(train_arguments.num_moments), float64)

        # print("initial alpha", l_list.shape)

        cross_entropy = BinaryCrossentropy(from_logits=True)

        generator_optimizer = Adam(learning_rate=train_arguments.lr)
        disc_opt_stu = Adam(learning_rate=train_arguments.lr)
        disc_opt_t = [Adam(learning_rate=train_arguments.lr) for i in range(self.n_teachers)]

        train_loader = self.get_data_loader(data)

        steps = 0
        epsilon = 0

        category_samples = distributions.Categorical(probs=self.class_ratios, dtype=float64)

        while epsilon < self.target_epsilon:
            # train the teacher discriminators
            for t_2 in range(train_arguments.num_teacher_iters):
                for i in range(self.n_teachers):
                    inputs, categories = None, None
                    for b, data_ in enumerate(train_loader[i]):
                        inputs, categories = data_, b
                        # The loop breaks after the first iteration, so categories is always 0
                        # and inputs only holds the first batch of this teacher's loader.
                        break

                    with GradientTape() as disc_tape:
                        # train with real
                        dis_data = concat([inputs, zeros((self.batch_size, 1), dtype=float64)], 1)  # Why append a column of zeros instead of categories?
                        # print("1st batch data", dis_data.shape)
                        real_output = self.t_discriminators[i](dis_data, training=True)
                        # print(real_output.shape, tf.ones.shape)

                        # train with fake
                        z = uniform([self.batch_size, self.noise_dim], dtype=float64)
                        # print("uniformly distributed noise", z.shape)

                        sample = expand_dims(category_samples.sample(self.batch_size), axis=1)
                        # print("category", sample.shape)

                        fake = self.generator(concat([z, sample], 1))
                        # print('fake', fake.shape)

                        fake_output = self.t_discriminators[i](concat([fake, sample], 1), training=True)
                        # print('fake_output_dis', fake_output.shape)

                        # print("watch", disc_tape.watch(self.teacher_disc[i].trainable_variables)
                        real_loss_disc = cross_entropy(ones_like(real_output), real_output)
                        fake_loss_disc = cross_entropy(zeros_like(fake_output), fake_output)

                        disc_loss = real_loss_disc + fake_loss_disc
                        # print(disc_loss, real_loss_disc, fake_loss_disc)

                    disc_grad = disc_tape.gradient(disc_loss, self.t_discriminators[i].trainable_variables)
                    # print(gradients_of_discriminator)

                    disc_opt_t[i].apply_gradients(zip(disc_grad, self.t_discriminators[i].trainable_variables))

            # train the student discriminator
            for t_3 in range(train_arguments.num_student_iters):
                z = uniform([self.batch_size, self.noise_dim], dtype=float64)

                sample = expand_dims(category_samples.sample(self.batch_size), axis=1)
                # print("category_stu", sample.shape)

                with GradientTape() as stu_tape:
                    fake = self.generator(concat([z, sample], 1))
                    # print('fake_stu', fake.shape)

                    predictions, clean_votes = self._pate_voting(
                        concat([fake, sample], 1), self.t_discriminators, train_arguments.lap_scale)
                    # print("noisy_labels", predictions.shape, "clean_votes", clean_votes.shape)
                    outputs = self.s_discriminator(concat([fake, sample], 1))

                    # update the moments
                    alpha = alpha + self._moments_acc(self.n_teachers, clean_votes, train_arguments.lap_scale, l_list)
                    # print("final_alpha", alpha)

                    stu_loss = cross_entropy(predictions, outputs)

                gradients_of_stu = stu_tape.gradient(stu_loss, self.s_discriminator.trainable_variables)
                # print(gradients_of_stu)

                disc_opt_stu.apply_gradients(zip(gradients_of_stu, self.s_discriminator.trainable_variables))

            # train the generator
            z = uniform([self.batch_size, self.noise_dim], dtype=float64)

            sample_g = expand_dims(category_samples.sample(self.batch_size), axis=1)

            with GradientTape() as gen_tape:
                fake = self.generator(concat([z, sample_g], 1))
                output = self.s_discriminator(concat([fake, sample_g], 1))

                loss_gen = cross_entropy(ones_like(output), output)

            gradients_of_generator = gen_tape.gradient(loss_gen, self.generator.trainable_variables)
            generator_optimizer.apply_gradients(zip(gradients_of_generator, self.generator.trainable_variables))

            # Calculate the current privacy cost
            epsilon = min((alpha - log(train_arguments.delta)) / l_list)
            if steps % train_arguments.sample_interval == 0:
                print("Step : ", steps, "Loss SD : ", stu_loss, "Loss G : ", loss_gen, "Epsilon : ", epsilon)

            steps += 1
        # self.generator.summary()

    def _pate_voting(self, data, netTD, lap_scale):
        # TODO: Validate the logic against the original article
        ## Collect each teacher's (1/0) vote in netTD for every record in data and store it in results
        results = zeros([len(netTD), self.batch_size], dtype=int64)
        # print(results)
        for i in range(len(netTD)):
            output = netTD[i](data, training=True)
            pred = transpose(cast((output > 0.5), int64))
            # print(pred)
            results = tensor_scatter_nd_update(results, constant([[i]]), pred)
            # print(results)

        # Store the per-record sum of the teachers' votes (values between 0 and len(netTD))
        clean_votes = expand_dims(cast(reduce_sum(results, 0), dtype=float64), 1)
        # print("clean_votes",clean_votes)
        noise_sample = distributions.Laplace(loc=0, scale=1/lap_scale).sample(clean_votes.shape)
        # print("noise_sample", noise_sample)
        noisy_results = clean_votes + cast(noise_sample, float64)
        noisy_labels = cast((noisy_results > len(netTD)/2), float64)

        return noisy_labels, clean_votes

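    # Voting illustration (hypothetical numbers): with 5 teachers, a record that
    # collects 4 positive votes has clean_votes == 4.0; its noisy label is 1
    # whenever 4.0 + Laplace(0, 1/lap_scale) stays above the majority threshold
    # len(netTD)/2 == 2.5.
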

class Discriminator(Model):
    def __init__(self, batch_size):
        self.batch_size = batch_size

    def build_model(self, input_shape, dim):
        input = Input(shape=input_shape, batch_size=self.batch_size)
        x = Dense(dim * 4)(input)
        x = ReLU()(x)
        x = Dense(dim * 2)(x)
        x = Dense(1)(x)
        return Model(inputs=input, outputs=x)


class Generator(Model):
    def __init__(self, batch_size):
        self.batch_size = batch_size

    def build_model(self, input_shape, dim, data_dim, processor_info: Optional[NamedTuple] = None):
        input = Input(shape=input_shape, batch_size=self.batch_size)
        x = Dense(dim)(input)
        x = ReLU()(x)
        x = Dense(dim * 2)(x)
        x = Dense(data_dim)(x)
        if processor_info:
            x = ActivationInterface(processor_info, 'ActivationInterface')(x)
        return Model(inputs=input, outputs=x)
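
A minimal usage sketch of the class this commit introduces, based only on the constructor and train signatures above. The ModelParameters/TrainParameters field names and every concrete value (dataset, columns, teacher count, privacy budget, lap_scale, num_moments, iteration counts) are illustrative assumptions, not something this commit defines.

# Hypothetical example: parameter names are inferred from how model.py reads them.
import pandas as pd
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
from ydata_synthetic.synthesizers.regular import PATEGAN

data = pd.read_csv("my_dataset.csv")        # placeholder dataset
num_cols = ["age", "income"]                # placeholder numerical columns
cat_cols = ["gender"]                       # placeholder categorical columns

gan_args = ModelParameters(batch_size=128, lr=1e-4, noise_dim=32, layers_dim=128)
train_args = TrainParameters(lr=1e-4, num_moments=100, lap_scale=1e-4,
                             num_teacher_iters=5, num_student_iters=5,
                             sample_interval=50, delta=1e-5)

# Training stops once the accumulated privacy cost reaches target_epsilon.
synth = PATEGAN(model_parameters=gan_args, n_teachers=10,
                target_delta=1e-5, target_epsilon=1.0)
synth.train(data, class_ratios=[0.5, 0.5], train_arguments=train_args,
            num_cols=num_cols, cat_cols=cat_cols)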
