ctc_base.py source code

python

Project: theano_ctc, Author: mcf06
```python
import theano
import theano.tensor as T


# Method of the CTC Op class in ctc_base.py (the rest of the class is not shown).
def make_node(self, acts, labels, input_lengths=None):
    acts = T.as_tensor_variable(acts)
    labels = T.as_tensor_variable(labels)

    # Unless specified, assume every sequence has the full length acts.shape[0].
    if input_lengths is None:
        seq_len = T.cast(acts.shape[0], dtype="int32")
        input_lengths = seq_len * T.ones_like(acts[0, :, 0], dtype="int32")
    input_lengths = T.as_tensor_variable(input_lengths)

    # acts.shape = [seqLen, batchN, outputUnit]
    if acts.dtype != "float32":
        raise TypeError("acts must be float32 instead of %s" % acts.dtype)
    # labels.shape = [batchN, labelLen]
    if labels.dtype != "int32":
        raise TypeError("labels must be int32 instead of %s" % labels.dtype)
    # input_lengths.shape = [batchN]
    if input_lengths.dtype != "int32":
        raise TypeError("input_lengths must be int32 instead of %s" % input_lengths.dtype)

    # self.costs and self.gradients are output variables defined elsewhere in the class.
    # Note the input order: acts, input_lengths, labels.
    applyNode = theano.Apply(self, inputs=[acts, input_lengths, labels],
                             outputs=[self.costs, self.gradients])

    # Calling the Op returns only output 0 (the cost); the gradient output is used by grad().
    self.default_output = 0

    return applyNode
```
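For readers without Theano at hand, the input checks in `make_node` can be mirrored eagerly in NumPy. The sketch below is not part of theano_ctc: `check_ctc_inputs` is a hypothetical helper that works on concrete NumPy arrays rather than symbolic Theano variables, and it assumes `acts` is 3-D as the shape comment states. It shows the default `input_lengths` that `make_node` builds, namely `seqLen` repeated once per sequence in the batch, and raises `TypeError` on a dtype mismatch.

```python
import numpy as np


def check_ctc_inputs(acts, labels, input_lengths=None):
    """Hypothetical NumPy mirror of the checks done in make_node.

    acts:          float32, shape [seqLen, batchN, outputUnit]
    labels:        int32,   shape [batchN, labelLen]
    input_lengths: int32,   shape [batchN]; defaults to seqLen for every sequence
    """
    acts = np.asarray(acts)
    labels = np.asarray(labels)
    # Assumption of this sketch: acts must be 3-D, as in the shape comment.
    if acts.ndim != 3:
        raise ValueError("acts must be 3-D [seqLen, batchN, outputUnit], got %d-D" % acts.ndim)

    if input_lengths is None:
        # Same value as the Theano expression: seqLen repeated batchN times.
        seq_len, batch_n = acts.shape[0], acts.shape[1]
        input_lengths = np.full(batch_n, seq_len, dtype=np.int32)
    input_lengths = np.asarray(input_lengths)

    if acts.dtype != np.float32:
        raise TypeError("acts must be float32 instead of %s" % acts.dtype)
    if labels.dtype != np.int32:
        raise TypeError("labels must be int32 instead of %s" % labels.dtype)
    if input_lengths.dtype != np.int32:
        raise TypeError("input_lengths must be int32 instead of %s" % input_lengths.dtype)
    return acts, labels, input_lengths
```

For a batch of 4 sequences of length 50 with 29 output units, `check_ctc_inputs(np.zeros((50, 4, 29), np.float32), np.zeros((4, 10), np.int32))` returns default lengths of `[50 50 50 50]`.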