We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 67b3052 commit 849ab78Copy full SHA for 849ab78
train.py
@@ -97,8 +97,8 @@ def _map_fn_train(img):
97
with tf.GradientTape() as tape:
98
fake_hr_patchs = G(lr_patchs)
99
mse_loss = tl.cost.mean_squared_error(fake_hr_patchs, hr_patchs, is_mean=True)
100
- grad = tape.gradient(mse_loss, G.weights)
101
- g_optimizer_init.apply_gradients(zip(grad, G.weights))
+ grad = tape.gradient(mse_loss, G.trainable_weights)
+ g_optimizer_init.apply_gradients(zip(grad, G.trainable_weights))
102
step += 1
103
epoch = step//n_step_epoch
104
print("Epoch: [{}/{}] step: [{}/{}] time: {}s, mse: {} ".format(
0 commit comments