layer.py source code

python

Project: relaax  Author: deeplearninc
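import numpy as np
import tensorflow as tf  # TensorFlow 1.x: tf.contrib is not available in 2.x

# `graph` is relaax's own helper module (Placeholder, TfNode); its import is not shown in this excerpt.


def build_graph(self, x, batch_size=1, n_units=256):
    # Two placeholders carry the LSTM state between calls: cell state c and hidden state h.
    self.phs = [graph.Placeholder(np.float32, [batch_size, n_units]) for _ in range(2)]
    self.ph_state = graph.TfNode(tuple(ph.node for ph in self.phs))
    self.ph_state.checked = tuple(ph.checked for ph in self.phs)

    # All-zero (c, h) pair to feed when no previous state exists yet.
    self.zero_state = tuple(np.zeros([batch_size, n_units]) for _ in range(2))

    state = tf.contrib.rnn.LSTMStateTuple(*self.ph_state.checked)

    lstm = tf.contrib.rnn.BasicLSTMCell(n_units, state_is_tuple=True)

    # x.node is batch-major: [batch, time, features]. The slice [1:2] yields a
    # length-1 vector holding the time dimension, which matches the required
    # [batch_size] shape of sequence_length only when batch_size is 1.
    outputs, self.state = tf.nn.dynamic_rnn(lstm, x.node, initial_state=state,
                                            sequence_length=tf.shape(x.node)[1:2], time_major=False)

    # Wrap the final state and the trainable variables of the current scope as graph nodes.
    self.state = graph.TfNode(self.state)
    self.weight = graph.TfNode(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                                 tf.get_variable_scope().name))
    return outputs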
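
The method above exposes the LSTM state through placeholders, so the caller has to feed `zero_state` on the first call and the returned `state` on every later call. The following is a minimal pure-Python sketch of that feed-back pattern, not code from relaax or TensorFlow. It assumes the usual LSTM update with a forget bias of 1.0. The names `lstm_step` and `run_episode`, the gate order (i, j, f, o) and the weight layout are all hypothetical and chosen for illustration only.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def lstm_step(x, state, weights, bias, forget_bias=1.0):
    """One LSTM step for a single example (illustrative, not the relaax code).

    x: list of input features; state: (c, h) tuple of lists, each n_units long.
    weights: (len(x) + n_units) rows by 4 * n_units columns; bias: 4 * n_units.
    The gate order i, j, f, o is an assumption of this sketch.
    """
    c, h = state
    n_units = len(c)
    inputs = list(x) + list(h)
    # Pre-activations of all four gates from the concatenated [x, h] vector.
    pre = [
        sum(inputs[r] * weights[r][k] for r in range(len(inputs))) + bias[k]
        for k in range(4 * n_units)
    ]
    i, j, f, o = (pre[g * n_units:(g + 1) * n_units] for g in range(4))
    new_c = [c[u] * sigmoid(f[u] + forget_bias) + sigmoid(i[u]) * math.tanh(j[u])
             for u in range(n_units)]
    new_h = [math.tanh(new_c[u]) * sigmoid(o[u]) for u in range(n_units)]
    return new_h, (new_c, new_h)


def run_episode(observations, n_units, weights, bias):
    # Start from the all-zero (c, h) pair, the role zero_state plays above.
    state = ([0.0] * n_units, [0.0] * n_units)
    outputs = []
    for x in observations:
        # Feed the previous state back in, as the state placeholders allow.
        out, state = lstm_step(x, state, weights, bias)
        outputs.append(out)
    return outputs, state
```

Calling `run_episode([[1.0, 0.5], [0.2, -0.3]], 3, weights, bias)` with a 5 x 12 weight matrix returns one output per observation plus the final `(c, h)` pair, which a caller would keep and pass into the next episode segment.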