def ae(x):
    """Push ``x`` through eight affine layers with a selectable activation.

    The weights ``W`` and biases ``b`` are looked up by the string keys
    ``'1'`` .. ``'8'`` and, like ``nonlinearity_name``, are read from the
    enclosing scope rather than passed in.  The activation is applied after
    layers 1-7; the output of layer 8 is returned without an activation.

    NOTE(review): the name suggests an autoencoder forward pass -- confirm
    against the caller.

    Args:
        x: input passed as the left operand of ``tf.matmul`` with ``W['1']``.

    Returns:
        ``tf.matmul(h7, W['8']) + b['8']``, where ``h7`` is the activated
        output of the seventh layer.

    Raises:
        NameError: if ``nonlinearity_name`` is not one of ``'relu'``,
            ``'elu'``, ``'gelu'`` or ``'silu'``.
    """

    def gelu_fast(z):
        # tanh-based GELU; the exact form it stands in for is
        # z * erfc(-z / sqrt(2)) / 2.
        half_z = 0.5 * z
        scale = tf.sqrt(2 / np.pi)
        inner = z + 0.044715 * tf.pow(z, 3)
        return half_z * (1 + tf.tanh(scale * inner))

    def silu(z):
        return z * tf.sigmoid(z)

    # Ordered (name, function) pairs, matched with ``==`` in the same order
    # as the original if/elif chain.
    candidates = (
        ('relu', tf.nn.relu),
        ('elu', tf.nn.elu),
        ('gelu', gelu_fast),
        ('silu', silu),
    )
    for candidate_name, candidate_fn in candidates:
        if nonlinearity_name == candidate_name:
            activation = candidate_fn
            break
    else:
        raise NameError("Need 'relu', 'elu', 'gelu', or 'silu' for nonlinearity_name")

    # Seven activated hidden layers, then a linear output layer.
    hidden = x
    for layer_key in ('1', '2', '3', '4', '5', '6', '7'):
        hidden = activation(tf.matmul(hidden, W[layer_key]) + b[layer_key])
    return tf.matmul(hidden, W['8']) + b['8']
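# Comment list  (web-page navigation residue, not part of the code)
# Article table of contents  (web-page navigation residue, not part of the code)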