utils.py source code

python

Project: DeepLearning  Author: Wanwannodao
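```python
# Written against the TensorFlow 1.x graph API (queue runners and the
# three-argument tf.name_scope). Under TensorFlow 2.x, replace the import with:
#   import tensorflow.compat.v1 as tf; tf.disable_v2_behavior()
import tensorflow as tf


def ptb_producer(raw_data, batch_size, num_steps, name=None):
    """Iterate over a sequence of word ids as (input, target) minibatches.

    Returns a pair of int32 tensors x and y, each of shape
    [batch_size, num_steps]; y is x shifted one time step ahead.
    Because the window index i is dequeued from a queue, the queue runners
    must be started (e.g. with tf.train.start_queue_runners) before x and y
    are evaluated.
    """
    with tf.name_scope(name, "PTBProducer", [raw_data, batch_size, num_steps]):
        raw_data  = tf.convert_to_tensor(raw_data, name="raw_data", dtype=tf.int32)
        data_len  = tf.size(raw_data)
        # Split the id stream into batch_size contiguous rows; the remainder
        # (data_len % batch_size trailing ids) is dropped.
        batch_len = data_len // batch_size
        data      = tf.reshape(raw_data[0 : batch_size * batch_len],
                               [batch_size, batch_len])

        # Number of num_steps-wide windows per row; the -1 leaves room for
        # the target, which is offset by one position.
        epoch_size = (batch_len - 1) // num_steps
        epoch_size = tf.identity(epoch_size, name="epoch_size")

        # Queue yielding window indices 0, 1, ..., epoch_size - 1 in order.
        i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()

        # Inputs: columns [i * num_steps, (i + 1) * num_steps) of every row.
        x = tf.strided_slice(data, [0, i * num_steps],
                             [batch_size, (i + 1) * num_steps], [1, 1])
        x.set_shape([batch_size, num_steps])
        # Targets: the same window shifted one position to the right.
        y = tf.strided_slice(data, [0, i * num_steps + 1],
                             [batch_size, (i + 1) * num_steps + 1], [1, 1])
        y.set_shape([batch_size, num_steps])
        return x, y
```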
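The slicing in `ptb_producer` can be mirrored in plain Python to see exactly which (x, y) pairs the queue yields. This is a minimal sketch for illustration only: `ptb_batches` is a hypothetical helper, not part of the original project, and it walks the windows eagerly with a generator instead of a TensorFlow queue. The `ValueError` for an empty epoch is a choice of this sketch; the original function has no such check.

```python
from typing import Iterator, List, Sequence, Tuple

Batch = List[List[int]]


def ptb_batches(raw_data: Sequence[int], batch_size: int,
                num_steps: int) -> Iterator[Tuple[Batch, Batch]]:
    """Hypothetical pure-Python mirror of ptb_producer's slicing (no queues)."""
    batch_len = len(raw_data) // batch_size
    # Row r is the contiguous chunk raw_data[r * batch_len:(r + 1) * batch_len];
    # trailing ids that do not fill a whole row are dropped, as in the TF code.
    data = [list(raw_data[r * batch_len:(r + 1) * batch_len])
            for r in range(batch_size)]
    epoch_size = (batch_len - 1) // num_steps
    # Added by this sketch: refuse to produce an empty epoch.
    if epoch_size <= 0:
        raise ValueError("epoch_size == 0, decrease batch_size or num_steps")
    for i in range(epoch_size):
        # Inputs and targets for window i; targets are offset by one id.
        x = [row[i * num_steps:(i + 1) * num_steps] for row in data]
        y = [row[i * num_steps + 1:(i + 1) * num_steps + 1] for row in data]
        yield x, y


if __name__ == "__main__":
    for x, y in ptb_batches(list(range(20)), batch_size=2, num_steps=3):
        print("x =", x, "y =", y)
```

With 20 ids, `batch_size=2` and `num_steps=3`, each row holds 10 ids and `epoch_size` is (10 - 1) // 3 = 3, so three pairs are produced; the first is x = [[0, 1, 2], [10, 11, 12]] and y = [[1, 2, 3], [11, 12, 13]]. Each row is read left to right across successive batches, which is why a recurrent model can carry its state from one batch to the next.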