utils.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:DeepLearning 作者: Wanwannodao 项目源码 文件源码
def batch_producer(enc, dec, batch_size, name=None):
    """Build TF1 queue-based ops that yield successive (x, y) mini-batches.

    Each evaluation of the returned tensors dequeues the next batch index
    ``i`` from ``tf.train.range_input_producer`` and slices rows
    ``[i * batch_size, (i + 1) * batch_size)`` out of ``enc`` and ``dec``.
    Trailing rows that do not fill a whole batch are never produced. Queue
    runners must be started (``tf.train.start_queue_runners``) before the
    returned tensors are evaluated.

    Args:
        enc: Array-like with a ``.shape`` attribute, indexed as
            ``(num_examples, seq_len, last)``; only the first 2 entries of
            the last axis are sliced out. Converted to ``float32``.
        dec: Array-like indexed as ``(num_examples, seq_len)``; presumably
            it has the same first two dimensions as ``enc`` (TODO confirm
            against callers). Converted to ``int32``.
        batch_size: Number of examples per batch; must be positive.
        name: Optional name for the enclosing name scope (default "batch").

    Returns:
        Tuple ``(x, y)`` with static shapes ``[batch_size, seq_len, 2]``
        and ``[batch_size, seq_len]``.

    Raises:
        ValueError: if ``batch_size`` is not positive, or if ``enc`` holds
            fewer than ``batch_size`` examples (no full batch exists).
    """
    data_len = enc.shape[0]
    seq_len = enc.shape[1]

    if batch_size <= 0:
        raise ValueError("batch_size must be positive, got %r" % (batch_size,))
    epoch_size = data_len // batch_size
    if epoch_size == 0:
        # range_input_producer(0) would never yield an index, so fail early.
        raise ValueError(
            "batch_size (%d) exceeds the number of examples (%d)"
            % (batch_size, data_len))

    print("epoch size: %d " % epoch_size)

    with tf.name_scope(name, "batch", [enc, dec, batch_size]):
        enc = tf.convert_to_tensor(enc, name="enc", dtype=tf.float32)
        dec = tf.convert_to_tensor(dec, name="dec", dtype=tf.int32)

        # Dequeues the next batch index in [0, epoch_size); shuffle=False
        # keeps the batches in dataset order.
        i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()

        # Row window of the current batch. The original code ignored `i` and
        # always sliced rows [0, batch_size), i.e. the same batch every step.
        start = i * batch_size
        stop = (i + 1) * batch_size

        x = tf.strided_slice(enc, [start, 0, 0],
                             [stop, seq_len, 2],
                             [1, 1, 1])
        # The begin/end values depend on the dequeued tensor, so the static
        # shape is unknown; restore it for downstream graph construction.
        x.set_shape([batch_size, seq_len, 2])

        y = tf.strided_slice(dec, [start, 0],
                             [stop, seq_len],
                             [1, 1])
        y.set_shape([batch_size, seq_len])

        return x, y

# for test

#if __name__ == "__main__":
#    enc_in, dec_out = _load_data("./convex_hull_50_train.txt")
#    print(enc_in.shape)
#    print(dec_out.shape)
#    #print(enc_in)
#    x_batch, y_batch = batch_producer(enc_in, dec_out, batch_size=20)

#    with tf.Session() as sess:
#        coord = tf.train.Coordinator()
#        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

#        print(sess.run([x_batch, y_batch]))

#        coord.request_stop()
#        coord.join(threads)


# ====================
# visualization
# ====================
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号