@@ -1,7 +1,6 @@
 #! /usr/bin/python
 # -*- coding: utf8 -*-
 
-
 import os, time, pickle, random
 from datetime import datetime
 import numpy as np
@@ -14,9 +13,6 @@
 from utils import *
 from config import config, log_config
 
-
-
-
 ###====================== HYPER-PARAMETERS ===========================###
 ## Adam
 batch_size = config.TRAIN.batch_size
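 # batch_size is read from config.py; the remaining Adam hyper-parameters
 # (learning rate, beta1, epoch counts) are presumably read the same way in
 # the lines elided from this hunk.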
@@ -87,8 +83,6 @@ def train():
 
     net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2, reuse=False)
     _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2, reuse=True)
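     # (x + 1) / 2 rescales the images from the generator's (presumably tanh)
     # output range [-1, 1] to the [0, 1] range the VGG19 feature extractor is
     # fed here; reuse=True shares the VGG weights between the target and
     # prediction branches.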
-    # print(vgg_predict_emb.outputs)
-    # # exit()
 
     ## test inference
     net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)
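     # reuse=True shares the trained generator's weights with this test graph;
     # is_train=False switches layers such as batch norm to inference mode.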
@@ -98,31 +92,11 @@ def train():
     d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2')
     d_loss = d_loss1 + d_loss2
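     # d_loss2 pushes fake logits toward label 0; together with the real-image
     # term d_loss1 (defined above this hunk) it forms the standard GAN
     # discriminator loss.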
 
-    # g_gan_loss = 1e-1 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake), name='g')
-    # mse_loss = normalize_mean_squared_error(net_g.outputs, t_target_image)  # simiao
-    # vgg_loss = 5e-1 * normalize_mean_squared_error(vgg_predict_emb.outputs, vgg_target_emb.outputs)
-
-    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake), name='g')  # paper 1e-3
-    mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image, is_mean=True)  # paper
-    vgg_loss = 2e-6 * tl.cost.mean_squared_error(vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)  # simiao
-
-    ## simiao
-    # g_gan_loss = tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake), name='g')
-    # mse_loss = normalize_mean_squared_error(net_g.outputs, t_target_image)
-    # vgg_loss = 0.00025 * tl.cost.mean_squared_error(vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)
-
-    ## history
-    # resize-conv MSE + 1e-2*g_gan_loss: 1020 green broken, but can recover / 1030 always green
-    # resize-conv MSE + 1e-3*g_gan_loss: more stable than 1e-2, 1043 bubble
-    # resize-conv MSE + 1e-3*g_gan_loss + 1e-6*VGG: far fewer bubbles than mse+gan, d loss ≈ 0.5 (G not powerful?)
-    # subpixel-conv MSE + 1e-3*g_gan_loss + 1e-6*VGG (no pretrain), small checkerboard. VGG loss ≈ MSE / 2
-    # train with higher VGG loss?
-    # subpixel-conv MSE + 1e-3*g_gan_loss + 2e-6*VGG (no pretrain), small checkerboard. VGG loss ≈ MSE
-    # subpixel-conv MSE + 1e-4*g_gan_loss + 2e-6*VGG (no pretrain), small checkerboard. 50 epochs: d loss very small ≈ 0.02054373
-    # subpixel-conv MSE + 1e-3*g_gan_loss + 2e-6*VGG, 100 epoch pretrain, barely any checkerboard!
+    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake), name='g')
+    mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image, is_mean=True)
+    vgg_loss = 2e-6 * tl.cost.mean_squared_error(vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)
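     # weighting: 1e-3 on the adversarial term follows the SRGAN paper; per the
     # history notes above, 2e-6 on the VGG term brings vgg_loss to roughly the
     # same magnitude as mse_loss.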
 
     g_loss = mse_loss + vgg_loss + g_gan_loss
-    # g_loss = mse_loss + g_gan_loss
 
     g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
     d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)
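     # collect trainable variables by name scope so the generator and the
     # discriminator can be updated by separate optimizers.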
@@ -305,6 +279,7 @@ def evaluate():
 
     size = valid_lr_img.shape
     t_image = tf.placeholder('float32', [None, size[0], size[1], size[2]], name='input_image')
+    # t_image = tf.placeholder('float32', [1, None, None, 3], name='input_image')
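     # the commented alternative would fix the batch size at 1 but accept any
     # input resolution, instead of binding the graph to this image's shape.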
 
     net_g = SRGAN_g(t_image, is_train=False, reuse=False)
 