def __init__(self, hidden_size, vocab_size, num_hop=3, qa=None):
    """Build the DMN+ model: shared word embedding, loss, and the four sub-modules.

    Args:
        hidden_size: Width of the word embeddings; also passed to every
            sub-module constructor.
        vocab_size: Number of rows in the embedding table; also passed to the
            input, question and answer modules.
        num_hop: Stored as ``self.num_hop``; presumably the number of episodic
            memory passes performed elsewhere (e.g. in ``forward``) -- confirm.
        qa: Stored unchanged as ``self.qa``; its type and use are not visible
            here (presumably a dataset/vocabulary object) -- TODO confirm.
    """
    super(DMNPlus, self).__init__()
    self.num_hop = num_hop
    self.qa = qa

    # Row 0 is reserved for padding. NOTE(review): `.cuda()` moves only the
    # embedding to the GPU and raises on CPU-only machines; the sub-modules
    # below are not moved here, so callers presumably move the whole model
    # afterwards anyway -- confirm before dropping this call.
    self.word_embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0, sparse=True).cuda()

    # Uniform init in [-sqrt(3), sqrt(3)], i.e. unit variance. Writing through
    # `.weight.data` avoids the deprecated `init.uniform` and works across
    # PyTorch versions.
    bound = 3 ** 0.5
    self.word_embedding.weight.data.uniform_(-bound, bound)
    # The fill above also overwrites the all-zero padding row that
    # `padding_idx` sets up, and that row never receives gradient, so it would
    # remain a fixed random vector for all of training. Re-zero it so padding
    # positions embed to zeros (nn.Embedding.reset_parameters does the same).
    self.word_embedding.weight.data[self.word_embedding.padding_idx].zero_()

    # Summed (not averaged) cross-entropy over the batch. `size_average=False`
    # is deprecated in newer PyTorch, where the equivalent is `reduction='sum'`.
    self.criterion = nn.CrossEntropyLoss(size_average=False)

    self.input_module = InputModule(vocab_size, hidden_size)
    self.question_module = QuestionModule(vocab_size, hidden_size)
    self.memory = EpisodicMemory(hidden_size)
    self.answer_module = AnswerModule(vocab_size, hidden_size)
babi_main.py 文件源码
python
阅读 24
收藏 0
点赞 0
评论 0
评论列表
文章目录