seq_batch.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:lang2program 作者: kelvinguu 项目源码 文件源码
def embed(sequence_batch, embeds):
    mask = sequence_batch.mask
    embedded_values = tf.gather(embeds, sequence_batch.values)
    embedded_values = tf.verify_tensor_all_finite(embedded_values, 'embedded_values')

    # set all pad embeddings to zero
    broadcasted_mask = expand_dims_for_broadcast(mask, embedded_values)
    embedded_values *= broadcasted_mask

    return SequenceBatch(embedded_values, mask)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号