trajmodel.py source file

python

Project: RNN-TrajModel  Author: wuhao5688
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.SparseTensorValue)


def feed(self, batch):
    """
    Build the feed dict that maps the model's placeholders to one batch.
    :param batch: a Batch object
    :return: feed dict of inputs
    """
    input_feed = {}
    # Dense placeholders, keyed by tensor name.
    input_feed[self.inputs_.name] = batch.inputs
    input_feed[self.targets_.name] = batch.targets
    input_feed[self.mask_.name] = batch.masks
    input_feed[self.dests_label_.name] = batch.dests
    input_feed[self.seq_len_.name] = batch.seq_lens
    if self.logits_mask__ is not None:
        # Sparse mask over the logits: one row per input position, one column
        # per state; every entry listed in batch.adj_indices is set to 1.
        values = np.ones(len(batch.adj_indices), np.float32)
        shape = np.array([np.size(batch.inputs), self.config.state_size], dtype=np.int32)
        input_feed[self.logits_mask__] = tf.SparseTensorValue(batch.adj_indices, values, shape)
    # Learning rate is fed on every step from the config.
    input_feed[self.lr_] = self.config.lr
    if self.sub_onehot_targets_ is not None:
        input_feed[self.sub_onehot_targets_] = batch.sub_onehot_target
    return input_feed
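
The logits mask above is passed as an (indices, values, dense shape) triple. As a rough illustration of what such a triple encodes, the sketch below expands it into a dense 0/1 matrix in plain Python. The helper name `densify_mask` and the toy indices and shape are hypothetical and not part of RNN-TrajModel; it assumes each index is a `[row, col]` pair, as a 2-D dense shape requires.

```python
def densify_mask(indices, values, dense_shape):
    """Expand a sparse (indices, values, dense_shape) triple into nested lists.

    Hypothetical helper for illustration only; positions that are not
    listed in `indices` stay at 0.0.
    """
    rows, cols = dense_shape
    dense = [[0.0] * cols for _ in range(rows)]
    for (row, col), value in zip(indices, values):
        dense[row][col] = value
    return dense


# Toy data (assumed): 3 flattened input positions, 4 states.
toy_indices = [[0, 1], [0, 2], [1, 0], [2, 3]]
toy_values = [1.0] * len(toy_indices)  # all ones, as in feed()
toy_mask = densify_mask(toy_indices, toy_values, (3, 4))

for mask_row in toy_mask:
    print(mask_row)
```

Each row of the dense matrix corresponds to one input position, and the ones mark the states listed for that position. Keeping the mask sparse avoids materializing the zeros when the number of states is large.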