def highway(self, input_1, input_2, size_1, size_2, l2_penalty=1e-8, layer_size=1):
    """Build ``layer_size`` stacked highway-style gated layers (TF 1.x graph mode).

    For each layer ``idx``:

        H = relu(output @ W_H + b_H)            # scope 'output_lin_<idx>'
        T = sigmoid(input_1 @ W_T + b_T)        # scope 'transform_lin_<idx>'
        output = T * H + (1 - T) * input_1

    The transformed path is fed by the previous layer's output (``input_2``
    for the first layer). The transform gate and the carry path are always
    computed from ``input_1``.

    Every weight and bias created here adds ``l2_penalty * tf.nn.l2_loss(var)``
    to the ``tf.GraphKeys.REGULARIZATION_LOSSES`` collection. Presumably the
    caller sums that collection into the training loss (TODO confirm).

    Args:
        self: Not used by this method.
        input_1: 2-D tensor whose last dimension is ``size_1``; used for the
            gate and the carry path.
        input_2: 2-D tensor whose last dimension is ``size_2``; input to the
            first layer's transformed path.
        size_1: Width of ``input_1`` and of every layer's output.
        size_2: Width of ``input_2``.
        l2_penalty: Scale applied to each created variable's L2 loss.
        layer_size: Number of stacked layers. With 0, ``input_2`` is returned
            unchanged.

    Returns:
        The last layer's output (last dimension ``size_1``), or ``input_2``
        when ``layer_size`` is 0.
    """

    def _linear_params(in_dim, out_dim):
        # Create a weight/bias pair and register both L2 penalties.
        # This is called inside the caller's active tf.name_scope, so the
        # variables are still named "<scope>/W" and "<scope>/b".
        W = tf.Variable(tf.truncated_normal([in_dim, out_dim], stddev=0.1), name="W")
        b = tf.Variable(tf.constant(0.1, shape=[out_dim]), name="b")
        for var in (W, b):
            tf.add_to_collection(
                name=tf.GraphKeys.REGULARIZATION_LOSSES,
                value=l2_penalty * tf.nn.l2_loss(var),
            )
        return W, b

    output = input_2
    for idx in range(layer_size):
        # Only the first layer consumes input_2 (width size_2). Later layers
        # consume the previous output (width size_1). Using size_2 for every
        # layer would fail whenever size_1 != size_2 and layer_size > 1.
        in_dim = size_2 if idx == 0 else size_1
        with tf.name_scope('output_lin_%d' % idx):
            W, b = _linear_params(in_dim, size_1)
            output = tf.nn.relu(tf.nn.xw_plus_b(output, W, b))
        with tf.name_scope('transform_lin_%d' % idx):
            W, b = _linear_params(size_1, size_1)
            transform_gate = tf.sigmoid(tf.nn.xw_plus_b(input_1, W, b))
        # Carry gate is the complement of the transform gate; a Python scalar
        # broadcasts and adopts transform_gate's dtype.
        carry_gate = 1.0 - transform_gate
        output = transform_gate * output + carry_gate * input_1
    return output
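# 评论列表 ("Comment list") - navigation text left over from the source web page
# 文章目录 ("Table of contents") - navigation text left over from the source web page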