-
Notifications
You must be signed in to change notification settings - Fork 327
Expand file tree
/
Copy pathnetwork.py
More file actions
114 lines (90 loc) · 3.53 KB
/
Copy pathnetwork.py
File metadata and controls
114 lines (90 loc) · 3.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
import tensorflow as tf
import numpy as np
#----------------------------------------------------------------------------
# Get/create weight tensor for a convolutional or fully-connected layer.
def get_weight(shape, gain=np.sqrt(2)):
    """Get or create the ``'weight'`` variable with He-style normal init.

    Args:
        shape: Weight shape; the last entry is the number of outputs, e.g.
            ``[kernel, kernel, fmaps_in, fmaps_out]`` or ``[in, out]``.
        gain: Scale applied to ``1/sqrt(fan_in)`` to get the init stddev.

    Returns:
        The variable returned by ``tf.get_variable('weight', ...)``.
    """
    # All axes but the last feed a single output unit.
    fan_in = np.prod(shape[:-1])
    he_std = gain / np.sqrt(fan_in)
    init = tf.initializers.random_normal(0, he_std)
    return tf.get_variable('weight', shape=shape, initializer=init)
#----------------------------------------------------------------------------
# Layer helpers: bias, convolution, pooling and upsampling.
def apply_bias(x):
    """Add a learned per-channel bias to ``x``.

    Gets or creates a variable named ``'bias'`` with one zero-initialized
    entry per element of axis 1 of ``x``. The bias is cast to ``x.dtype`` and
    broadcast over every other axis.

    Args:
        x: Tensor of known rank >= 2 whose channel axis is axis 1
            (e.g. ``[N, C]`` or NCHW ``[N, C, H, W]``).

    Returns:
        ``x`` with the bias added; same shape and dtype as ``x``.
    """
    b = tf.get_variable('bias', shape=[x.shape[1]], initializer=tf.initializers.zeros())
    b = tf.cast(b, x.dtype)
    rank = len(x.shape)
    if rank == 2:
        return x + b
    # Put the C bias entries on axis 1 and size-1 dims everywhere else, so the
    # bias broadcasts per channel for any rank. The old hard-coded rank-4
    # shape mis-broadcast (or failed) for rank-3 inputs.
    return x + tf.reshape(b, [1, -1] + [1] * (rank - 2))
def conv2d_bias(x, fmaps, kernel, gain=np.sqrt(2)):
    """Stride-1 'SAME' NCHW convolution followed by a per-channel bias.

    Args:
        x: NCHW input tensor with a statically known channel count.
        fmaps: Number of output feature maps.
        kernel: Odd, positive square kernel size.
        gain: Init gain forwarded to ``get_weight``.

    Returns:
        Tensor of shape ``[N, fmaps, H, W]``.
    """
    assert kernel >= 1 and kernel % 2 == 1
    in_fmaps = x.shape[1].value
    filt = tf.cast(get_weight([kernel, kernel, in_fmaps, fmaps], gain=gain), x.dtype)
    y = tf.nn.conv2d(x, filt, strides=[1, 1, 1, 1], padding='SAME', data_format='NCHW')
    return apply_bias(y)
def maxpool2d(x, k=2):
    """Max-pool H and W of an NCHW tensor with a k x k window, stride k, 'SAME' padding."""
    # Leave batch and channel axes untouched; pool only over the spatial axes.
    window = [1, 1, k, k]
    return tf.nn.max_pool(x, ksize=window, strides=window, padding='SAME', data_format='NCHW')
# TODO: replace upscale2d + conv with the fused upscale+conv2d op from gan2.
def upscale2d(x, factor=2):
    """Nearest-neighbour upsample an NCHW tensor by an integer ``factor``.

    Each pixel is replicated into a ``factor x factor`` block. ``factor == 1``
    returns ``x`` unchanged.
    """
    assert isinstance(factor, int) and factor >= 1
    if factor == 1:
        return x
    with tf.variable_scope('Upscale2D'):
        shape = x.shape
        channels, rows, cols = shape[1], shape[2], shape[3]
        # Add singleton axes after H and W, tile them `factor` times, then
        # fold them back into H and W.
        expanded = tf.reshape(x, [-1, channels, rows, 1, cols, 1])
        tiled = tf.tile(expanded, [1, 1, 1, factor, 1, factor])
        return tf.reshape(tiled, [-1, channels, rows * factor, cols * factor])
def conv_lr(name, x, fmaps):
    """3x3 conv + bias, then leaky ReLU (alpha=0.1), inside variable scope ``name``."""
    with tf.variable_scope(name):
        y = conv2d_bias(x, fmaps, 3)
        return tf.nn.leaky_relu(y, alpha=0.1)
def conv(name, x, fmaps, gain):
    """3x3 conv + bias with no activation, inside variable scope ``name``."""
    with tf.variable_scope(name):
        out = conv2d_bias(x, fmaps, 3, gain)
    return out
def autoencoder(x, width=256, height=256, **_kwargs):
    """Build the U-Net style encoder/decoder graph on an NCHW image batch.

    The encoder applies five 2x max-pools, saving the input and the first four
    pooled maps as skip connections. The decoder upsamples five times,
    concatenating each upsampled map with the matching skip along the channel
    axis. Variable-scope names (``enc_conv0`` ... ``dec_conv1``) are fixed, so
    they must not change if existing checkpoints are to load.

    Args:
        x: Input tensor; its static shape is set to ``[None, 3, height, width]``.
        width: Image width; must be a multiple of 32.
        height: Image height; must be a multiple of 32.
        **_kwargs: Ignored.

    Returns:
        Tensor with 3 output channels (linear, ``gain=1.0`` final conv).

    Raises:
        ValueError: If ``height`` or ``width`` is not a multiple of 32. Five
            'SAME' 2x pools followed by 2x upsampling only reproduce each skip's
            spatial size when the input size is divisible by 2**5.
    """
    for label, dim in (('height', height), ('width', width)):
        if dim is not None and dim % 32 != 0:
            raise ValueError('autoencoder: %s must be a multiple of 32, got %r' % (label, dim))
    x.set_shape([None, 3, height, width])

    # ---- Encoder ----
    skips = [x]
    n = conv_lr('enc_conv0', x, 48)
    n = conv_lr('enc_conv1', n, 48)
    n = maxpool2d(n)
    skips.append(n)
    for i in (2, 3, 4):
        n = conv_lr('enc_conv%d' % i, n, 48)
        n = maxpool2d(n)
        skips.append(n)
    n = conv_lr('enc_conv5', n, 48)
    n = maxpool2d(n)
    n = conv_lr('enc_conv6', n, 48)

    # ---- Decoder ---- skips are consumed deepest-first.
    for i in (5, 4, 3, 2):
        n = upscale2d(n)
        n = tf.concat([n, skips.pop()], axis=1)
        n = conv_lr('dec_conv%d' % i, n, 96)
        n = conv_lr('dec_conv%db' % i, n, 96)
    n = upscale2d(n)
    n = tf.concat([n, skips.pop()], axis=1)  # concatenates the original input x
    n = conv_lr('dec_conv1a', n, 64)
    n = conv_lr('dec_conv1b', n, 32)
    return conv('dec_conv1', n, 3, gain=1.0)