ctc_decoder.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:neuralmonkey 作者: ufal 项目源码 文件源码
def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict:
    """Build the feed dictionary for one batch of this decoder.

    The training-mode placeholder is always fed. When the dataset provides
    the series named by ``self.data_id``, the target sentences are also fed.
    They go in as a ``tf.SparseTensorValue`` under ``self.train_targets``.

    Args:
        dataset: The batch to read the target series from.
        train: Whether the batch is used for training. It is fed to
            ``self.train_mode`` and forwarded to the vocabulary as
            ``train_mode``.

    Returns:
        A dict mapping placeholders to the values to feed.
    """
    raw_series = dataset.get_series(self.data_id, allow_none=True)
    sentences = cast(Iterable[List[str]], raw_series)

    fd = {self.train_mode: train}  # type: FeedDict

    # allow_none=True means the series may be absent (e.g. at inference
    # time); then only the train-mode flag is fed.
    if sentences is None:
        return fd

    time_major_ids, time_major_pads = self.vocabulary.sentences_to_tensor(
        list(sentences), train_mode=train)

    # The vocabulary yields time-major arrays; the sparse targets are built
    # batch-major, so transpose both before masking.
    token_ids = time_major_ids.T
    pad_weights = time_major_pads.T

    # Keep only positions whose padding weight exceeds 0.5. These are
    # presumably the real (non-padding) tokens; TODO confirm the padding
    # convention in Vocabulary.sentences_to_tensor.
    is_token = pad_weights > 0.5

    # np.where yields one coordinate array per axis; stacking them on axis 1
    # gives one (row, col) coordinate pair per kept element.
    fd[self.train_targets] = tf.SparseTensorValue(
        indices=np.stack(np.where(is_token), axis=1),
        values=token_ids[is_token],
        dense_shape=token_ids.shape)

    return fd
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号