def extend_middle_dim(_2D, num):
    """Insert a new middle axis of length ``num`` into a 2-D Theano tensor.

    The input of shape ``(A, B)`` is first viewed as ``(A, 1, B)`` and then
    broadcast with ``T.alloc`` to ``(A, num, B)``, so every slice
    ``out[:, k, :]`` equals ``_2D``.

    :param _2D: symbolic 2-D tensor of shape ``(A, B)``.
    :param num: length of the new middle axis; anything ``T.alloc`` accepts
        as a shape entry (e.g. a Python int or a symbolic integer scalar).
    :returns: symbolic 3-D tensor of shape ``(A, num, B)``.
    :raises ValueError: if ``_2D`` is not 2-D.

    :usage:
    >>> import numpy as np
    >>> import theano
    >>> import theano.tensor as T
    >>> x = T.matrix('x')
    >>> y = extend_middle_dim(x, 3)
    >>> y.ndim
    3
    >>> y.eval({x: np.zeros((2, 4), dtype=theano.config.floatX)}).shape
    (2, 3, 4)
    """
    # Fail early with a clear message instead of an obscure DimShuffle error.
    if _2D.ndim != 2:
        raise ValueError(
            "extend_middle_dim expects a 2-D tensor, got ndim=%d" % _2D.ndim)
    # (A, B) -> (A, 1, B); the 'x' entry inserts a broadcastable length-1 axis.
    rval = _2D.dimshuffle((0, 'x', 1))
    # Broadcast the length-1 middle axis out to `num` copies: (A, num, B).
    rval = T.alloc(rval, rval.shape[0], num, rval.shape[2])
    return rval
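# NOTE: "评论列表" ("comment list") and "文章目录" ("table of contents") were
# web-page navigation labels left over from where this snippet was copied;
# they are not code.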