layers.py source code

python
Project: text-gan-tensorflow  Author: tokestermw
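import tensorflow as tf  # TensorFlow 1.x; tf.contrib was removed in 2.x
from tensorflow.contrib import seq2seq  # dynamic_rnn_decoder lives here in early TF 1.x contrib releases


# _global_keep_prob is a helper defined elsewhere in the project (not shown in this snippet).
def recurrent_layer(tensor, cell=None, hidden_dims=128, sequence_length=None, decoder_fn=None,
                    activation=tf.nn.tanh, initializer=tf.orthogonal_initializer(), initial_state=None,
                    keep_prob=1.0,
                    return_final_state=False, return_next_cell_input=True, **opts):
    """Run an RNN over `tensor`, or a seq2seq decoder when `decoder_fn` is given.

    Note: `initializer` and `return_next_cell_input` are accepted but not used below.
    """
    if cell is None:
        # Default to a vanilla RNN cell; an LSTM cell is kept as a commented-out alternative.
        cell = tf.contrib.rnn.BasicRNNCell(hidden_dims, activation=activation)
        # cell = tf.contrib.rnn.LSTMCell(hidden_dims, activation=activation)

    if keep_prob < 1.0:
        keep_prob = _global_keep_prob(keep_prob)
        # The two positional arguments are input_keep_prob and output_keep_prob.
        cell = tf.contrib.rnn.DropoutWrapper(cell, keep_prob, keep_prob)

    if opts.get("name"):
        # Register the cell in a graph collection named after the layer so it can be fetched later.
        tf.add_to_collection(opts.get("name"), cell)

    if decoder_fn is None:
        # Plain RNN over the input sequence. When sequence_length is given, outputs past
        # each example's length are zeroed and the last valid state is copied through.
        outputs, final_state = tf.nn.dynamic_rnn(cell, tensor,
            sequence_length=sequence_length, initial_state=initial_state, dtype=tf.float32)
        final_context_state = None
    else:
        # Decoder mode: inputs=None, so decoder_fn must supply each step's input itself.
        # TODO: turn off sequence_length?
        outputs, final_state, final_context_state = seq2seq.dynamic_rnn_decoder(
            cell, decoder_fn, inputs=None, sequence_length=sequence_length)

    if return_final_state:
        return final_state
    else:
        return outputs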
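To make the default path (`decoder_fn=None`) concrete, here is a minimal pure-Python sketch of what `BasicRNNCell` plus `tf.nn.dynamic_rnn` compute for a single, unbatched sequence. It assumes the standard vanilla-RNN update h_t = tanh(x_t W_x + h_{t-1} W_h + b), where the new state is also the step's output; `hidden_dims` corresponds to the length of `b`. The names `basic_rnn_step` and `run_rnn` and the list-of-lists weights are hypothetical and are not part of the project or of TensorFlow. Like `dynamic_rnn`, the sketch zeroes outputs past `sequence_length` and carries the last valid state through, and `return_final_state` selects which of the two is returned.

```python
import math
from typing import Callable, List, Optional, Sequence

Vector = List[float]
Matrix = List[List[float]]  # indexed as [input_index][unit_index]


def basic_rnn_step(x: Vector, h: Vector, w_x: Matrix, w_h: Matrix, b: Vector,
                   activation: Callable[[float], float] = math.tanh) -> Vector:
    """One vanilla RNN step: h' = activation(x @ w_x + h @ w_h + b).

    The new state doubles as the step's output, as in BasicRNNCell.
    """
    units = len(b)
    return [
        activation(
            sum(x[i] * w_x[i][j] for i in range(len(x)))
            + sum(h[k] * w_h[k][j] for k in range(len(h)))
            + b[j]
        )
        for j in range(units)
    ]


def run_rnn(inputs: Sequence[Vector], w_x: Matrix, w_h: Matrix, b: Vector,
            sequence_length: Optional[int] = None,
            initial_state: Optional[Vector] = None,
            return_final_state: bool = False):
    """Unroll the RNN over one sequence, mimicking dynamic_rnn's length handling."""
    units = len(b)
    h = list(initial_state) if initial_state is not None else [0.0] * units
    valid = len(inputs) if sequence_length is None else sequence_length
    outputs: List[Vector] = []
    for t, x in enumerate(inputs):
        if t < valid:
            h = basic_rnn_step(x, h, w_x, w_h, b)
            outputs.append(h)
        else:
            # Past the valid length: emit zeros and keep the last state unchanged.
            outputs.append([0.0] * units)
    return h if return_final_state else outputs


inputs = [[1.0], [2.0], [3.0]]
outs = run_rnn(inputs, [[0.5]], [[0.1]], [0.0], sequence_length=2)
state = run_rnn(inputs, [[0.5]], [[0.1]], [0.0], sequence_length=2, return_final_state=True)
print(outs[2], state == outs[1])
```

With `sequence_length=2` on a three-step input, the third output is all zeros and the returned final state equals the second output, mirroring the copy-through behavior of `dynamic_rnn`.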
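The `decoder_fn` branch hands control to `seq2seq.dynamic_rnn_decoder`. As described in the TF 1.0 contrib documentation, that function calls `decoder_fn(time, cell_state, cell_input, cell_output, context_state)` at every step and expects back `(done, next_cell_state, next_cell_input, emit_output, next_context_state)`. Because the snippet passes `inputs=None`, each next input has to be produced by `decoder_fn`, as in inference-style decoding. The plain-Python loop below is a simplified, unbatched sketch of that contract, not TensorFlow's implementation: the real function runs inside a `tf.while_loop` over batched tensors, `done` is a per-example vector, and its step bookkeeping may differ. `run_decoder`, `add_cell`, `feedback_fn`, and the `max_steps` cap are hypothetical names chosen for illustration.

```python
from typing import Any, Callable, List, Optional, Tuple

# Hypothetical types for this sketch:
#   cell:       (input, state) -> (output, new_state)
#   decoder_fn: (time, cell_state, cell_input, cell_output, context_state)
#               -> (done, next_cell_state, next_cell_input, emit_output, next_context_state)
Cell = Callable[[Any, Any], Tuple[Any, Any]]
DecoderFn = Callable[[int, Any, Any, Any, Any], Tuple[bool, Any, Any, Any, Any]]


def run_decoder(cell: Cell, decoder_fn: DecoderFn,
                max_steps: int = 50) -> Tuple[List[Any], Any, Any]:
    """Simplified, unbatched sketch of a decoder_fn-driven RNN loop."""
    # Time 0: no cell has run yet, so decoder_fn sees None for state and output
    # and returns the initial state, the first input, and the initial context.
    done, state, cell_input, _, context = decoder_fn(0, None, None, None, None)
    outputs: List[Any] = []
    time = 0
    while not done and time < max_steps:
        cell_output, state = cell(cell_input, state)
        time += 1
        # decoder_fn decides whether to stop and chooses the next input
        # (here derived from the previous output, since no inputs are given).
        done, state, cell_input, emit, context = decoder_fn(
            time, state, cell_input, cell_output, context)
        outputs.append(emit)
    return outputs, state, context


# Toy example: the "cell" adds its input to its state; the decoder feeds each
# output back in as the next input and records every output in its context.
def add_cell(x: int, h: int) -> Tuple[int, int]:
    new_h = h + x
    return new_h, new_h


def feedback_fn(time: int, cell_state: Optional[int], cell_input: Optional[int],
                cell_output: Optional[int], context_state: Optional[List[int]]):
    if cell_output is None:  # initial call: state 0, first input 1, empty context
        return False, 0, 1, None, []
    context = context_state + [cell_output]
    return time >= 4, cell_state, cell_output, cell_output, context


outputs, final_state, final_context = run_decoder(add_cell, feedback_fn)
print(outputs, final_state, final_context)  # [1, 2, 4, 8] 8 [1, 2, 4, 8]
```

Here `final_context` plays the role of the `final_context_state` that `recurrent_layer` receives from the decoder but never returns.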