Skip to content

Commit 4106032

Browse files
committed
Merge branch 'master' of https://github.com/tensorlayer/srgan
2 parents 66f0a5c + 2fd80d0 commit 4106032

File tree

3 files changed

+21
-10
lines changed

3 files changed

+21
-10
lines changed

README.md

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@ We run this script under [TensorFlow](https://www.tensorflow.org) 2.0 and the [T
88
<!---
99
⚠️ This repo will be merged into example folder of [tensorlayer](https://github.com/zsdonghao/tensorlayer) soon.
1010
-->
11-
🚀🚀🚀🚀🚀🚀 **This repo will be MOVED to [this folder](https://github.com/tensorlayer/tensorlayer/tree/master/examples) in next month.**
11+
🚀🚀🚀🚀🚀🚀 **THIS PROJECT WILL BE CLOSED AND MOVED TO [THIS FOLDER](https://github.com/tensorlayer/tensorlayer/tree/master/examples) IN NEXT MONTH.**
12+
13+
🚀🚀🚀🚀🚀🚀 **THIS PROJECT WILL BE CLOSED AND MOVED TO [THIS FOLDER](https://github.com/tensorlayer/tensorlayer/tree/master/examples) IN NEXT MONTH.**
14+
15+
🚀🚀🚀🚀🚀🚀 **THIS PROJECT WILL BE CLOSED AND MOVED TO [THIS FOLDER](https://github.com/tensorlayer/tensorlayer/tree/master/examples) IN NEXT MONTH.**
1216

1317
<!--More cool Computer Vision applications such as pose estimation and style transfer can be found in this [organization](https://github.com/tensorlayer).**
1418
-->
@@ -59,13 +63,15 @@ config.TRAIN.img_path = "your_image_folder/"
5963
- Start training.
6064

6165
```bash
62-
python main.py
66+
python train.py
6367
```
6468

65-
- Start evaluation. ([pretrained model](https://github.com/tensorlayer/srgan/releases/tag/1.2.0) for DIV2K)
69+
- Start evaluation.
70+
71+
<!--([pretrained model](https://github.com/tensorlayer/srgan/releases/tag/1.2.0) for DIV2K)-->
6672

6773
```bash
68-
python main.py --mode=evaluate
74+
python train.py --mode=evaluate
6975
```
7076

7177

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
tensorlayer>=2.0.0
2+
tensorflow>=2.0.0

train.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
#! /usr/bin/python
22
# -*- coding: utf8 -*-
33

4-
import time, random
4+
import time
5+
import random
56
import numpy as np
67
import scipy, multiprocessing
78
import tensorflow as tf
@@ -65,7 +66,7 @@ def _map_fn_train(img):
6566
train_ds = train_ds.map(_map_fn_train, num_parallel_calls=multiprocessing.cpu_count())
6667
train_ds = train_ds.repeat(n_epoch_init + n_epoch)
6768
train_ds = train_ds.shuffle(shuffle_buffer_size)
68-
train_ds = train_ds.prefetch(buffer_size=4096)
69+
train_ds = train_ds.prefetch(buffer_size=2)
6970
train_ds = train_ds.batch(batch_size)
7071
# value = train_ds.make_one_shot_iterator().get_next()
7172

@@ -97,8 +98,8 @@ def _map_fn_train(img):
9798
with tf.GradientTape() as tape:
9899
fake_hr_patchs = G(lr_patchs)
99100
mse_loss = tl.cost.mean_squared_error(fake_hr_patchs, hr_patchs, is_mean=True)
100-
grad = tape.gradient(mse_loss, G.weights)
101-
g_optimizer_init.apply_gradients(zip(grad, G.weights))
101+
grad = tape.gradient(mse_loss, G.trainable_weights)
102+
g_optimizer_init.apply_gradients(zip(grad, G.trainable_weights))
102103
step += 1
103104
epoch = step//n_step_epoch
104105
print("Epoch: [{}/{}] step: [{}/{}] time: {}s, mse: {} ".format(
@@ -108,7 +109,7 @@ def _map_fn_train(img):
108109

109110
# adversarial learning (G, D)
110111
n_step_epoch = round(n_epoch // batch_size)
111-
for step, (lr_patchs, hr_patchs) in train_ds:
112+
for step, (lr_patchs, hr_patchs) in enumerate(train_ds):
112113
with tf.GradientTape(persistent=True) as tape:
113114
fake_patchs = G(lr_patchs)
114115
logits_fake = D(fake_patchs)
@@ -124,7 +125,7 @@ def _map_fn_train(img):
124125
g_loss = mse_loss + vgg_loss + g_gan_loss
125126
grad = tape.gradient(g_loss, G.trainable_weights)
126127
g_optimizer.apply_gradients(zip(grad, G.trainable_weights))
127-
grad = tape.gradient(d_loss, D.weights)
128+
grad = tape.gradient(d_loss, D.trainable_weights)
128129
d_optimizer.apply_gradients(zip(grad, D.trainable_weights))
129130
step += 1
130131
epoch = step//n_step_epoch
@@ -329,6 +330,8 @@ def evaluate():
329330
# G.load_weights(checkpoint_dir + '/g_{}.h5'.format(tl.global_flag['mode']))
330331
G.load_weights("g_srgan.npz")
331332
G.eval()
333+
334+
valid_lr_img = np.asarray(valid_lr_img, dtype=np.float32)
332335

333336
out = G(valid_lr_img).numpy()
334337

0 commit comments

Comments
 (0)