Skip to content

Commit 849ab78

Browse files
authored
Update train.py
1 parent 67b3052 commit 849ab78

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,8 @@ def _map_fn_train(img):
9797
with tf.GradientTape() as tape:
9898
fake_hr_patchs = G(lr_patchs)
9999
mse_loss = tl.cost.mean_squared_error(fake_hr_patchs, hr_patchs, is_mean=True)
100-
grad = tape.gradient(mse_loss, G.weights)
101-
g_optimizer_init.apply_gradients(zip(grad, G.weights))
100+
grad = tape.gradient(mse_loss, G.trainable_weights)
101+
g_optimizer_init.apply_gradients(zip(grad, G.trainable_weights))
102102
step += 1
103103
epoch = step//n_step_epoch
104104
print("Epoch: [{}/{}] step: [{}/{}] time: {}s, mse: {} ".format(

0 commit comments

Comments
 (0)