def __init__(self, vocab_size, embedding_dim, igru_state_dim, igru_depth, trg_dgru_depth, emitter,
             feedback_brick, merge=None, merge_prototype=None, post_merge=None, **kwargs):
    """Build the interpolator's sub-bricks and register them as children.

    Parameters
    ----------
    vocab_size
        Stored on the instance and used as the output dimension of the
        ``gru_to_softmax`` linear layer.
    embedding_dim
        Stored on the instance; not otherwise used in this constructor.
    igru_state_dim
        Dimension of every IGRU layer, input dimension of
        ``gru_to_softmax`` and dimension of the default ``post_merge``.
    igru_depth
        ``1`` builds a single ``IGRU``; any other value builds a
        ``RecurrentStack`` of one ``IGRU`` followed by
        ``igru_depth - 1`` ``UpperIGRU`` layers with skip connections.
    trg_dgru_depth
        Stored on the instance; not otherwise used in this constructor.
    emitter, feedback_brick
        Bricks stored on the instance and registered as children.
    merge
        Merge brick. When falsy, a ``Merge`` is created from
        ``kwargs['source_names']`` and ``merge_prototype``.
    merge_prototype
        Prototype handed to the default ``Merge``; unused when ``merge``
        is supplied.
    post_merge
        Brick applied after merging. When falsy, a ``Bias`` of dimension
        ``igru_state_dim`` is created.
    **kwargs
        Forwarded to the parent constructor. ``'source_names'`` must be
        present whenever ``merge`` is not supplied (a ``KeyError`` is
        raised otherwise). An existing ``'children'`` list is extended
        in place with the bricks created here.
    """
    # The merged representation has the same width as the IGRU state.
    merged_dim = igru_state_dim

    # Fall back to default merge / post-merge bricks when none are given.
    if not merge:
        merge = Merge(input_names=kwargs['source_names'], prototype=merge_prototype)
    if not post_merge:
        post_merge = Bias(dim=merged_dim)

    if igru_depth != 1:
        bottom_layer = IGRU(dim=igru_state_dim, name='igru')
        upper_layers = [
            UpperIGRU(dim=igru_state_dim, activation=Tanh(), name='upper_igru' + str(level))
            for level in range(1, igru_depth)
        ]
        igru = RecurrentStack([bottom_layer] + upper_layers, skip_connections=True)
    else:
        # Single-layer case uses a plain IGRU with its default name; the
        # original note says this is kept for compatibility (presumably
        # with previously saved models -- TODO confirm).
        igru = IGRU(dim=igru_state_dim)
    self.igru = igru

    # Plain configuration values.
    self.vocab_size = vocab_size
    self.embedding_dim = embedding_dim
    self.igru_state_dim = igru_state_dim
    self.igru_depth = igru_depth
    self.trg_dgru_depth = trg_dgru_depth
    self.merged_dim = merged_dim

    # Bricks supplied by the caller (or defaulted above).
    self.emitter = emitter
    self.feedback_brick = feedback_brick
    self.merge = merge
    self.post_merge = post_merge

    # Bricks owned by this constructor.
    self.lookup = LookupTable(name='embeddings')
    self.gru_to_softmax = Linear(input_dim=igru_state_dim, output_dim=vocab_size)

    # Fork into every sequence input of the IGRU except the mask and the
    # externally provided input states.
    fork_targets = [
        seq_name for seq_name in igru.apply.sequences
        if seq_name not in ('mask', 'input_states')
    ]
    self.gru_fork = Fork(fork_targets, prototype=Linear(), name='gru_fork')

    # Append our bricks to any children already queued in kwargs.
    registered = kwargs.setdefault('children', [])
    registered.extend([
        self.emitter,
        self.feedback_brick,
        self.merge,
        self.post_merge,
        self.igru,
        self.lookup,
        self.gru_to_softmax,
        self.gru_fork,
    ])
    super(Interpolator, self).__init__(**kwargs)
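# Comment list  (stray web-page navigation label, not part of the code)
# Article table of contents  (stray web-page navigation label, not part of the code)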