def __init__(self, batch_size, mem_size, hidden_size):
    """Store the configuration sizes and precompute lower-triangular masks.

    Args:
        batch_size: Stored unchanged as ``self.batch_size``.
        mem_size: Stored as ``self.mem_size``; also the side length of the
            square ``self.L`` / ``self.sL`` matrices built here.
        hidden_size: Stored unchanged as ``self.hidden_size``.

    Attributes set:
        L: ``(mem_size, mem_size)`` float32 array with ones on and below the
            main diagonal, zeros above. For a vector ``v`` of length
            ``mem_size``, ``L @ v`` yields the inclusive prefix sums of ``v``.
        sL: Same shape and dtype, with ones strictly below the diagonal
            (``k=-1``), so ``sL @ v`` yields the exclusive prefix sums.
    """
    self.hidden_size = hidden_size
    self.mem_size = mem_size
    self.batch_size = batch_size
    # NOTE(review): these masks are presumably consumed by other methods of
    # the enclosing class (not visible here) for cumulative sums over memory
    # slots — confirm against those callers.
    # np.tril returns a fresh array, so both masks can be derived from one
    # all-ones buffer without aliasing each other.
    ones = np.ones((mem_size, mem_size), dtype='float32')
    self.L = np.tril(ones)
    self.sL = np.tril(ones, k=-1)
评论列表
文章目录