# Module-level imports required by this method:
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable  # pre-0.4 PyTorch API (volatile flag)

def forward(self, x, lengths):
    batch_size = len(x)
    # `lengths` is recomputed from the sentences, shadowing the argument.
    lengths = [len(s) for s in x]
    # One running hidden state per sentence; volatile disables autograd
    # bookkeeping at inference time (pre-0.4 PyTorch).
    outputs = [Variable(torch.zeros(1, self.model_dim).float(),
                        volatile=not self.training)
               for _ in range(batch_size)]
    max_len = max(lengths)
    for t in range(max_len):
        # Collect the sentences active at step t; sequences are
        # right-aligned, so shorter ones join the batch later and all
        # finish together at the final step.
        batch = []
        h = []
        idx = []
        for i, (s, l) in enumerate(zip(x, lengths)):
            if l >= max_len - t:
                batch.append(s.pop())  # consume token ids from the end (reversed reading)
                h.append(outputs[i])
                idx.append(i)
        # Embed this step's token ids via the numpy embedding table.
        batch = np.concatenate(np.array(batch).reshape(-1, 1), 0)
        emb = Variable(torch.from_numpy(self.initial_embeddings.take(batch, 0)),
                       volatile=not self.training)
        # One RNN step over the active subset, then scatter the new
        # hidden states back to their sentences.
        h = torch.cat(h, 0)
        h_next = self.rnn(emb, h)
        h_next = torch.chunk(h_next, len(idx))
        for i, o in zip(idx, h_next):
            outputs[i] = o
    # Stack the final hidden states and classify with a two-layer MLP head.
    outputs = torch.cat(outputs, 0)
    h = F.relu(self.l0(F.dropout(outputs, 0.5, self.training)))
    h = F.relu(self.l1(F.dropout(h, 0.5, self.training)))
    y = F.log_softmax(h, dim=1)  # explicit dim: log-probabilities over classes
    return y
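
For context, here is a minimal, hypothetical harness showing how this forward pass could be driven. The class name ToyClassifier and every hyperparameter below are assumptions for illustration; only the attribute names the method actually uses (model_dim, rnn, l0, l1, initial_embeddings) come from the code above, and, as with the method itself, this assumes the pre-0.4 Variable/volatile API.

import numpy as np
import torch.nn as nn

class ToyClassifier(nn.Module):
    def __init__(self, vocab_size=100, embed_dim=16, model_dim=32, num_classes=3):
        super(ToyClassifier, self).__init__()
        self.model_dim = model_dim
        self.rnn = nn.RNNCell(embed_dim, model_dim)  # one step per call
        self.l0 = nn.Linear(model_dim, model_dim)
        self.l1 = nn.Linear(model_dim, num_classes)
        # Embedding table kept as a numpy array, as take() above expects.
        self.initial_embeddings = np.random.randn(vocab_size, embed_dim).astype(np.float32)

ToyClassifier.forward = forward  # attach the method shown above

model = ToyClassifier()
model.eval()
# Each sentence is a plain Python list of token ids; forward() pops them
# in place, so pass copies if the inputs are needed afterwards.
sentences = [[5, 2, 9, 1], [7, 3], [4, 4, 8]]
log_probs = model([list(s) for s in sentences], [len(s) for s in sentences])
print(log_probs.size())  # torch.Size([3, 3]): one row of log-probs per sentence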
Source file: dynamic.py (Python)