def __init__(self, shared_resources: SharedResources):
    """Build the FastQA module's sub-modules and parameters from the shared config.

    Args:
        shared_resources: Holds ``config`` (reads ``repr_dim_input``,
            ``repr_dim``, and optional ``with_char_embeddings``) and
            ``char_vocab`` (used only when char embeddings are enabled).
    """
    super(FastQAPyTorchModule, self).__init__()
    self._shared_resources = shared_resources

    cfg = shared_resources.config
    in_dim = cfg["repr_dim_input"]
    repr_dim = cfg["repr_dim"]
    self._size = repr_dim
    self._with_char_embeddings = cfg.get("with_char_embeddings", False)

    # Sub-modules and learnable parameters.
    if self._with_char_embeddings:
        self._conv_char_embedding = embedding.ConvCharEmbeddingModule(
            len(shared_resources.char_vocab), repr_dim)
        self._embedding_projection = nn.Linear(repr_dim + in_dim, repr_dim)
        self._embedding_highway = Highway(repr_dim, 1)
        self._v_wiq_w = nn.Parameter(torch.ones(1, 1, in_dim + repr_dim))
        # Downstream layers see the projected (repr_dim-sized) embeddings.
        in_dim = repr_dim
    else:
        self._v_wiq_w = nn.Parameter(torch.ones(1, 1, in_dim))

    # NOTE(review): the two extra input features are presumably the
    # word-in-question scores appended per token — confirm against the caller.
    self._bilstm = BiLSTM(in_dim + 2, repr_dim)
    self._answer_layer = FastQAAnswerModule(shared_resources)

    # Two independent [size, 2 * size] projections, each initialized as a
    # pair of identity matrices side by side. Built separately so the two
    # parameters do not share storage.
    def _stacked_identity():
        return torch.cat([torch.eye(repr_dim), torch.eye(repr_dim)], dim=1)

    self._question_projection = nn.Parameter(_stacked_identity())
    self._support_projection = nn.Parameter(_stacked_identity())