def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, gpu=False, oracle_init=False):
    """Create the embedding, GRU and output-projection layers of the generator.

    Args:
        embedding_dim: Size of each token embedding; also the GRU input size.
        hidden_dim: Size of the GRU hidden state; also the input size of the
            output projection ``gru2out``.
        vocab_size: Number of embedding rows and number of outputs of
            ``gru2out``.
        max_seq_len: Stored as ``self.max_seq_len``; not used in this
            constructor (presumably consumed by sampling code elsewhere --
            TODO confirm).
        gpu: Stored as ``self.gpu``; not used in this constructor
            (presumably selects CUDA tensors elsewhere -- TODO confirm).
        oracle_init: If True, every parameter registered by this constructor
            is re-initialised in place from N(0, 1), replacing PyTorch's
            default initialisation.
    """
    super(Generator, self).__init__()
    self.hidden_dim = hidden_dim
    self.embedding_dim = embedding_dim
    self.max_seq_len = max_seq_len
    self.vocab_size = vocab_size
    self.gpu = gpu

    self.embeddings = nn.Embedding(vocab_size, embedding_dim)
    # batch_first is left at its default (False).
    self.gru = nn.GRU(embedding_dim, hidden_dim)
    # Maps hidden_dim features to vocab_size outputs.
    self.gru2out = nn.Linear(hidden_dim, vocab_size)

    # Initialise the oracle network with N(0, 1); otherwise the
    # initialisation variance is very small, which gives a high NLL for data
    # sampled from the same model.
    if oracle_init:
        for p in self.parameters():
            # normal_ is the in-place initialiser; the un-suffixed
            # init.normal is a deprecated alias that emits a warning.
            init.normal_(p, 0, 1)
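# NOTE(review): the two lines that stood here were scraped web-page navigation
# labels, not code: "评论列表" ("Comment list") and "文章目录" ("Table of contents").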