def __init__(self, question_size, passage_size, hidden_size, attn_size=None,
             cell_type=nn.GRUCell, num_layers=1, dropout=0, residual=False, **kwargs):
    """Set up attention pooling over question/passage and the recurrent stack.

    Creates one attention-pooling module over the question (attended by the
    learned vector ``V_q``) and one over the passage (attended by the
    question), plus a stack of recurrent cells.

    Args:
        question_size: Feature size of the encoded question.
        passage_size: Feature size of the encoded passage.
        hidden_size: Accepted for the caller's signature; not referenced in
            this body — TODO confirm whether it is used elsewhere.
        attn_size: Attention projection size; defaults to ``question_size``.
        cell_type: Recurrent cell class to stack (default ``nn.GRUCell``).
        num_layers: Number of stacked recurrent layers.
        dropout: Dropout probability passed to the stacked cell.
        residual: Whether the stacked cell uses residual connections.
        **kwargs: Forwarded unchanged to ``StackedCell``.
    """
    super().__init__()
    self.num_layers = num_layers
    # Default the attention projection size to the question feature size.
    attn_size = question_size if attn_size is None else attn_size
    # TODO: what is V_q? (section 3.4)
    v_q_size = question_size
    self.question_pooling = AttentionPooling(
        question_size, v_q_size, attn_size=attn_size)
    self.passage_pooling = AttentionPooling(
        passage_size, question_size, attn_size=attn_size)
    # Learned attention query vector; leading singleton dims presumably
    # broadcast over (batch, time) — confirm against AttentionPooling.
    self.V_q = nn.Parameter(torch.randn(1, 1, v_q_size), requires_grad=True)
    self.cell = StackedCell(question_size, question_size,
                            num_layers=num_layers, dropout=dropout,
                            rnn_cell=cell_type, residual=residual, **kwargs)