
Commit a1d2c2b

[release] 1.2.0 for TL 1.7.5
1 parent 085d477 commit a1d2c2b


42 files changed (+10435, -3314 lines)

README.md

Lines changed: 2 additions & 0 deletions
@@ -2,7 +2,9 @@
 
 We run this script under [TensorFlow](https://www.tensorflow.org) 1.2 and the self-contained [TensorLayer](http://tensorlayer.readthedocs.io/en/latest/). If you got error, you may need to update TensorLayer.
 
+<!---
 ⚠️ This repo will be merged into [tensorlayer](https://github.com/zsdonghao/tensorlayer) soon.
+-->
 
 ### SRGAN Architecture

main.py

Lines changed: 54 additions & 62 deletions
Large diffs are not rendered by default.

model.py

Lines changed: 70 additions & 115 deletions
Large diffs are not rendered by default.

tensorlayer/__init__.py

Lines changed: 2 additions & 6 deletions
@@ -1,9 +1,6 @@
-"""
-Deep learning and Reinforcement learning library for Researchers and Engineers
-"""
+"""Deep learning and Reinforcement learning library for Researchers and Engineers"""
 from __future__ import absolute_import
 
-
 try:
     install_instr = "Please make sure you install a recent enough version of TensorFlow."
     import tensorflow
@@ -15,7 +12,6 @@
 from . import files
 from . import iterate
 from . import layers
-from . import ops
 from . import utils
 from . import visualize
 from . import prepro
@@ -27,7 +23,7 @@
 act = activation
 vis = visualize
 
-__version__ = "1.7.3"
+__version__ = "1.7.4"
 
 global_flag = {}
 global_dict = {}
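Since this diff removes the top-level `ops` import and bumps `__version__`, a minimal post-upgrade sanity check might look like the following (a sketch, assuming the released package is installed):

```python
# Illustrative post-upgrade check, not part of this commit.
import tensorlayer as tl

print(tl.__version__)  # expected: "1.7.4" for this release

# `from . import ops` was removed, so `tl.ops` should no longer be
# importable by default; code that used it needs to migrate.
assert not hasattr(tl, 'ops')
```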

tensorlayer/_logging.py

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+import logging
+
+logging.basicConfig(level=logging.INFO, format='[TL] %(message)s')
+
+
+def info(fmt, *args):
+    logging.info(fmt, *args)
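The new module wraps the standard-library logger with a `[TL]` prefix. A quick usage sketch (the message text below is illustrative):

```python
from tensorlayer import _logging as logging

# Takes printf-style arguments, mirroring logging.info.
logging.info("loading %s", "model.npz")  # prints: [TL] loading model.npz
```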

tensorlayer/activation.py

Lines changed: 80 additions & 61 deletions
@@ -3,120 +3,139 @@
 
 import tensorflow as tf
 
-def identity(x, name=None):
-    """The identity activation function, Shortcut is ``linear``.
+
+def identity(x):
+    """The identity activation function.
+    Shortcut is ``linear``.
 
     Parameters
     ----------
-    x : a tensor input
-        input(s)
+    x : Tensor
+        input.
 
     Returns
-    --------
-        A `Tensor` with the same type as `x`.
+    -------
+    Tensor
+        A ``Tensor`` in the same type as ``x``.
+
     """
     return x
 
-# Shortcut
-linear = identity
 
-def ramp(x=None, v_min=0, v_max=1, name=None):
+def ramp(x, v_min=0, v_max=1, name=None):
     """The ramp activation function.
 
     Parameters
     ----------
-    x : a tensor input
-        input(s)
+    x : Tensor
+        input.
     v_min : float
-        if input(s) smaller than v_min, change inputs to v_min
+        cap input to v_min as a lower bound.
     v_max : float
-        if input(s) greater than v_max, change inputs to v_max
-    name : a string or None
-        An optional name to attach to this activation function.
+        cap input to v_max as a upper bound.
+    name : str
+        The function name (optional).
 
     Returns
-    --------
-        A `Tensor` with the same type as `x`.
+    -------
+    Tensor
+        A ``Tensor`` in the same type as ``x``.
+
     """
     return tf.clip_by_value(x, clip_value_min=v_min, clip_value_max=v_max, name=name)
 
-def leaky_relu(x=None, alpha=0.1, name="lrelu"):
+
+def leaky_relu(x, alpha=0.1, name="lrelu"):
     """The LeakyReLU, Shortcut is ``lrelu``.
 
-    Modified version of ReLU, introducing a nonzero gradient for negative
-    input.
+    Modified version of ReLU, introducing a nonzero gradient for negative input.
 
     Parameters
     ----------
-    x : A `Tensor` with type `float`, `double`, `int32`, `int64`, `uint8`,
-        `int16`, or `int8`.
-    alpha : `float`. slope.
-    name : a string or None
-        An optional name to attach to this activation function.
+    x : Tensor
+        Support input type ``float``, ``double``, ``int32``, ``int64``, ``uint8``,
+        ``int16``, or ``int8``.
+    alpha : float
+        Slope.
+    name : str
+        The function name (optional).
 
     Examples
-    ---------
-    >>> network = tl.layers.DenseLayer(network, n_units=100, name = 'dense_lrelu',
-    ...                 act= lambda x : tl.act.lrelu(x, 0.2))
+    --------
+    >>> net = tl.layers.DenseLayer(net, 100, act=lambda x : tl.act.lrelu(x, 0.2), name='dense')
+
+    Returns
+    -------
+    Tensor
+        A ``Tensor`` in the same type as ``x``.
 
     References
     ------------
-    - `Rectifier Nonlinearities Improve Neural Network Acoustic Models, Maas et al. (2013) <http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf>`_
+    - `Rectifier Nonlinearities Improve Neural Network Acoustic Models, Maas et al. (2013) <http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf>`__
+
     """
     # with tf.name_scope(name) as scope:
-    #     x = tf.nn.relu(x)
-    #     m_x = tf.nn.relu(-x)
-    #     x -= alpha * m_x
+    #   x = tf.nn.relu(x)
+    #   m_x = tf.nn.relu(-x)
+    #   x -= alpha * m_x
     x = tf.maximum(x, alpha * x, name=name)
     return x
 
-#Shortcut
-lrelu = leaky_relu
-
 
 def swish(x, name='swish'):
-    """The Swish function, see `Swish: a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941>`_.
+    """The Swish function.
+    See `Swish: a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941>`__.
 
     Parameters
     ----------
-    x : a tensor input
-        input(s)
+    x : Tensor
+        input.
+    name: str
+        function name (optional).
 
     Returns
-    --------
-        A `Tensor` with the same type as `x`.
+    -------
+    Tensor
+        A ``Tensor`` in the same type as ``x``.
+
     """
-    with tf.name_scope(name) as scope:
-        x = tf.nn.sigmoid(x) * x
+    with tf.name_scope(name):
+        x = tf.nn.sigmoid(x) * x
     return x
 
-def pixel_wise_softmax(output, name='pixel_wise_softmax'):
+
+def pixel_wise_softmax(x, name='pixel_wise_softmax'):
     """Return the softmax outputs of images, every pixels have multiple label, the sum of a pixel is 1.
     Usually be used for image segmentation.
 
     Parameters
-    ------------
-    output : tensor
-        - For 2d image, 4D tensor [batch_size, height, weight, channel], channel >= 2.
-        - For 3d image, 5D tensor [batch_size, depth, height, weight, channel], channel >= 2.
+    ----------
+    x : Tensor
+        input.
+        - For 2d image, 4D tensor (batch_size, height, weight, channel), where channel >= 2.
+        - For 3d image, 5D tensor (batch_size, depth, height, weight, channel), where channel >= 2.
+    name : str
+        function name (optional)
+
+    Returns
+    -------
+    Tensor
+        A ``Tensor`` in the same type as ``x``.
 
     Examples
-    ---------
+    --------
     >>> outputs = pixel_wise_softmax(network.outputs)
     >>> dice_loss = 1 - dice_coe(outputs, y_, epsilon=1e-5)
 
     References
-    -----------
-    - `tf.reverse <https://www.tensorflow.org/versions/master/api_docs/python/array_ops.html#reverse>`_
+    ----------
+    - `tf.reverse <https://www.tensorflow.org/versions/master/api_docs/python/array_ops.html#reverse>`__
+
     """
-    with tf.name_scope(name) as scope:
-        return tf.nn.softmax(output)
-    ## old implementation
-    # exp_map = tf.exp(output)
-    # if output.get_shape().ndims == 4:  # 2d image
-    #     evidence = tf.add(exp_map, tf.reverse(exp_map, [False, False, False, True]))
-    # elif output.get_shape().ndims == 5:  # 3d image
-    #     evidence = tf.add(exp_map, tf.reverse(exp_map, [False, False, False, False, True]))
-    # else:
-    #     raise Exception("output parameters should be 2d or 3d image, not %s" % str(output._shape))
-    # return tf.div(exp_map, evidence)
+    with tf.name_scope(name):
+        return tf.nn.softmax(x)
+
+
+# Alias
+linear = identity
+lrelu = leaky_relu
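Taken together, the refactor tightens signatures (`x` is now a required argument) and moves the `linear`/`lrelu` aliases to the bottom of the module. A short sketch of the revised API under TensorFlow 1.x (the placeholder shapes below are illustrative assumptions):

```python
import tensorflow as tf
import tensorlayer as tl

x = tf.placeholder(tf.float32, [None, 100])

a = tl.act.ramp(x, v_min=0, v_max=1)  # clip values into [0, 1]
b = tl.act.leaky_relu(x, alpha=0.2)   # max(x, 0.2 * x); `lrelu` is the alias
c = tl.act.swish(x)                   # x * sigmoid(x)

# pixel_wise_softmax expects a 4D (or 5D) tensor with channel >= 2.
img = tf.placeholder(tf.float32, [None, 32, 32, 2])
probs = tl.act.pixel_wise_softmax(img)  # per-pixel softmax over channels
```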

tensorlayer/cli/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+"""The tensorlayer.cli module provides a command-line tool for some common tasks."""

tensorlayer/cli/__main__.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+import argparse
+from tensorlayer.cli import train
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(prog='tl')
+    subparsers = parser.add_subparsers(dest='cmd')
+    train_parser = subparsers.add_parser('train', help='train a model using multiple local GPUs or CPUs.')
+    train.build_arg_parser(train_parser)
+    args = parser.parse_args()
+    if args.cmd == 'train':
+        train.main(args)
+    else:
+        parser.print_help()
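The dispatcher above delegates to `tensorlayer.cli.train`, which is part of this commit but not rendered here; all it requires is that the module expose `build_arg_parser(parser)` and `main(args)`. A hypothetical minimal shape of that contract (the `file` flag is an illustrative assumption, not the real `train.py`):

```python
# Hypothetical sketch of the interface __main__.py expects from
# tensorlayer/cli/train.py -- not the actual file in this commit.
def build_arg_parser(parser):
    # Register `tl train`-specific flags on the subparser.
    parser.add_argument('file', help='python training script to launch')


def main(args):
    print('would launch training for %s' % args.file)
```

With the package installed, the tool would then be invoked as `python -m tensorlayer.cli train <args>`.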
