rnn.py source code

python

Project: hyperchamber  Author: 255BITS
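# Assumes the pre-1.0 TensorFlow (0.x) module layout used by the original project;
# later TF 1.x releases expose these as tf.nn.rnn_cell.BasicLSTMCell and tf.nn.static_rnn.
import tensorflow as tf
from tensorflow.python.ops import rnn, rnn_cell


def create_rnn(config, x, scope='rnn'):
    """Run one LSTM step over x; return the sign of its output and the final state."""
    with tf.variable_scope(scope):
        memory = config['rnn_size']
        cell = rnn_cell.BasicLSTMCell(memory)
        state = cell.zero_state(batch_size=config['batch_size'], dtype=tf.float32)
        # rnn.rnn takes a Python list with one [batch_size, input_size] tensor per
        # time step; here the sequence has a single step.
        outputs, state = rnn.rnn(cell, [tf.cast(x, tf.float32)], initial_state=state, dtype=tf.float32)
        x = outputs[-1]  # output of the last (only) step: [batch_size, rnn_size]
        # Optional projection to the 4 output columns (disabled in the original):
        #w = tf.get_variable('w', [memory, 4])
        #b = tf.get_variable('b', [4])
        #x = tf.nn.xw_plus_b(x, w, b)
        x = tf.sign(x)  # quantize each unit to -1, 0 or +1
        return x, state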

# At each step of the graph:
# x is [BATCH_SIZE, 4], where each row is a one-hot binary vector of the form:
# [start_token end_token a b]
#
# y is [BATCH_SIZE, 4], a binary vector of the chance that each character is correct
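The comment above fixes the column order of each one-hot row as `[start_token end_token a b]`, and `rnn.rnn` consumes one `[BATCH_SIZE, 4]` tensor per step. The sketch below shows one way to build those per-step batches in plain Python. It is an illustration under assumptions, not code from hyperchamber: the token names (`<start>`, `<end>`), the framing (start token prepended, end token appended) and the helpers `encode` and `step_batches` are all hypothetical, and words in a batch are assumed to have equal length.

```python
from typing import List

# Column order from the comment: [start_token, end_token, a, b].
# The token names themselves are hypothetical labels.
VOCAB = ["<start>", "<end>", "a", "b"]


def one_hot(token: str) -> List[int]:
    """Return a 4-element one-hot row for a token in VOCAB."""
    vec = [0] * len(VOCAB)
    vec[VOCAB.index(token)] = 1  # raises ValueError for unknown tokens
    return vec


def encode(word: str) -> List[List[int]]:
    """Frame a word over {a, b} with start/end tokens (assumed framing)."""
    tokens = ["<start>"] + list(word) + ["<end>"]
    return [one_hot(t) for t in tokens]


def step_batches(words: List[str]) -> List[List[List[int]]]:
    """Regroup per-word sequences into per-step [BATCH_SIZE, 4] matrices."""
    if len({len(w) for w in words}) != 1:
        raise ValueError("this sketch assumes equal-length words in a batch")
    encoded = [encode(w) for w in words]
    return [[seq[t] for seq in encoded] for t in range(len(encoded[0]))]


steps = step_batches(["ab", "ba"])
print(len(steps))  # 4 steps: <start>, 1st char, 2nd char, <end>
print(steps[1])    # [[0, 0, 1, 0], [0, 0, 0, 1]]
```

The list-of-matrices shape matches what a static (unrolled) RNN expects: each element of `steps` would be one entry of the input list passed to `rnn.rnn`.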
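To make the computation in `create_rnn` concrete, here is a dependency-free sketch of what it does for one batch row: a single LSTM step from the zero state, followed by sign quantization. It assumes the standard `BasicLSTMCell` equations (pre-activations `[x, h] @ W + b` split into the `i, j, f, o` blocks, `forget_bias = 1.0`); the helper names `lstm_step` and `create_rnn_sketch` and the explicit `W`/`b` arguments are hypothetical stand-ins for the variables TensorFlow creates internally.

```python
import math
from typing import List, Sequence, Tuple

Vector = List[float]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def sign(z: float) -> float:
    # Same convention as tf.sign: -1.0, 0.0 or 1.0.
    return float((z > 0) - (z < 0))


def lstm_step(x: Sequence[float], h: Sequence[float], c: Sequence[float],
              W: Sequence[Sequence[float]], b: Sequence[float],
              forget_bias: float = 1.0) -> Tuple[Vector, Vector]:
    """One BasicLSTMCell-style step (assumed equations).

    W has len(x) + len(h) rows and 4 * len(h) columns; the pre-activations
    [x, h] @ W + b are split into the i, j, f, o blocks in that order.
    """
    n = len(h)
    inp = list(x) + list(h)
    gates = [sum(inp[k] * W[k][col] for k in range(len(inp))) + b[col]
             for col in range(4 * n)]
    i, j, f, o = (gates[g * n:(g + 1) * n] for g in range(4))
    new_c = [c[u] * sigmoid(f[u] + forget_bias) + sigmoid(i[u]) * math.tanh(j[u])
             for u in range(n)]
    new_h = [math.tanh(new_c[u]) * sigmoid(o[u]) for u in range(n)]
    return new_h, new_c


def create_rnn_sketch(x: Sequence[float], W, b, rnn_size: int):
    """Hypothetical mirror of create_rnn: one step from zero state, then sign."""
    h0 = [0.0] * rnn_size  # zero_state: h = 0
    c0 = [0.0] * rnn_size  # zero_state: c = 0
    h, c = lstm_step(x, h0, c0, W, b)
    return [sign(v) for v in h], (c, h)  # state ordered (c, h)


# Usage: a start-token input, rnn_size = 2, zero weights, and a bias that
# only sets the candidate block j to [1, -1].
W = [[0.0] * 8 for _ in range(6)]
b = [0.0, 0.0, 1.0, -1.0, 0.0, 0.0, 0.0, 0.0]
out, (c, h) = create_rnn_sketch([1.0, 0.0, 0.0, 0.0], W, b, rnn_size=2)
print(out)  # [1.0, -1.0]
```

One consequence worth noting: starting from the zero state, `c = sigmoid(i) * tanh(j)` and `h = tanh(c) * sigmoid(o)`, so the sign of each output unit equals the sign of its candidate pre-activation `j`. The `tf.sign` output of a single step therefore carries no information from the input, forget or output gates.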