train_nlm.py source code

python

Project: Neural-Language-Model  Author: robosoup
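```python
import tensorflow as tf  # TensorFlow 1.x: relies on the queue-based tf.train input API


class Input(object):  # enclosing class name is not shown in the snippet (assumed)
    """Cuts a 1-D sequence of word IDs into [batch_size, num_steps] input/target pairs."""

    def __init__(self, cfg, data, name):
        # Minibatches per epoch, computed in Python for use by the training loop.
        self.steps = ((len(data) // cfg.batch_size) - 1) // cfg.num_steps
        with tf.name_scope(name, values=[data, cfg.batch_size, cfg.num_steps]):
            raw_data = tf.convert_to_tensor(data)
            data_len = tf.size(raw_data)
            batch_len = data_len // cfg.batch_size
            # Drop the tail that cannot fill a full row, then split into batch_size rows.
            data = tf.reshape(raw_data[0: cfg.batch_size * batch_len], [cfg.batch_size, batch_len])
            # Same value as self.steps, as a graph tensor; the -1 leaves room for shifted targets.
            epoch_size = (batch_len - 1) // cfg.num_steps
            epoch_size = tf.identity(epoch_size, name="epoch_size")
            # Queue yielding 0, 1, ..., epoch_size - 1 in order (queue runners must be started).
            i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()

            # Inputs: columns [i * num_steps, (i + 1) * num_steps) of every row.
            begin_x = [0, i * cfg.num_steps]
            self.inputs = tf.strided_slice(
                data, begin_x, [cfg.batch_size, (i + 1) * cfg.num_steps], tf.ones_like(begin_x))
            self.inputs.set_shape([cfg.batch_size, cfg.num_steps])

            # Targets: the same window shifted one position right (next-word prediction).
            begin_y = [0, i * cfg.num_steps + 1]
            self.targets = tf.strided_slice(
                data, begin_y, [cfg.batch_size, (i + 1) * cfg.num_steps + 1], tf.ones_like(begin_y))
            self.targets.set_shape([cfg.batch_size, cfg.num_steps])
```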
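The graph ops above are easier to check outside TensorFlow. The following is a minimal sketch, assuming NumPy is available, that mirrors the same trimming, reshaping and input/target slicing as a plain Python generator. The name `batch_iterator` is hypothetical and is not part of the robosoup project; NumPy slicing stands in for the TF 1.x queue and `tf.strided_slice` calls.

```python
import numpy as np


def batch_iterator(data, batch_size, num_steps):
    """Yield (inputs, targets) pairs, each of shape [batch_size, num_steps].

    Hypothetical NumPy mirror of the TF 1.x graph: trim, reshape into rows,
    then walk across the columns num_steps at a time.
    """
    raw_data = np.asarray(data)
    batch_len = raw_data.size // batch_size
    # Keep only what fills batch_size full rows; row r is one contiguous chunk.
    rows = raw_data[: batch_size * batch_len].reshape(batch_size, batch_len)
    # -1 because every target column sits one step to the right of its input.
    epoch_size = (batch_len - 1) // num_steps
    for i in range(epoch_size):  # same order as range_input_producer(shuffle=False)
        x = rows[:, i * num_steps : (i + 1) * num_steps]
        y = rows[:, i * num_steps + 1 : (i + 1) * num_steps + 1]
        yield x, y
```

For example, with `data = list(range(20))`, `batch_size=2` and `num_steps=3`, `batch_len` is 10 and the rows are `0..9` and `10..19`. That gives `((20 // 2) - 1) // 3 = 3` batches, matching `self.steps`. The first batch has inputs `[[0, 1, 2], [10, 11, 12]]` and targets `[[1, 2, 3], [11, 12, 13]]`. Because each row is a contiguous chunk, row r of batch i + 1 continues exactly where row r of batch i stopped. This is what allows an RNN to carry its hidden state from one batch to the next.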