def build_rnn(conv_input_var, seq_input_var, conv_shape, word_dims, n_hid, lstm_layers):
    """Build a stacked-LSTM Lasagne network whose first time step is a conv feature vector.

    The conv features are projected to ``n_hid`` units and prepended (as a
    single time step) to the per-step projection of the word sequence.  The
    combined sequence of ``seq_len + 1`` steps runs through ``lstm_layers``
    stacked ``LSTMLayer``s.  Every step is then projected back to
    ``word_dims`` units through ``log_softmax``, and the last step is dropped.

    Args:
        conv_input_var: symbolic input variable for the conv features.
        seq_input_var: symbolic input variable for the word sequence, laid
            out as ``(batch, seq_len, word_dims)`` (see the InputLayer shape).
        conv_shape: shape tuple for the conv-feature ``InputLayer``.
        word_dims: feature size of each word vector; used as both the input
            and the output size per time step.
        n_hid: number of units in the projections and in each LSTM layer.
        lstm_layers: number of stacked LSTM layers.

    Returns:
        dict mapping layer names to Lasagne layers.  ``ret['output']`` is the
        final layer, shaped ``(batch, seq_len, word_dims)``.
    """
    ret = {}

    # Word sequence: flatten (batch, seq_len) so the Dense projection is
    # applied per time step, then restore the sequence layout.
    ret['seq_input'] = seq_layer = InputLayer((None, None, word_dims), input_var=seq_input_var)
    batchsize, seqlen, _ = seq_layer.input_var.shape  # symbolic sizes
    ret['seq_resh'] = seq_layer = ReshapeLayer(seq_layer, shape=(-1, word_dims))
    ret['seq_proj'] = seq_layer = DenseLayer(seq_layer, num_units=n_hid)
    ret['seq_resh2'] = seq_layer = ReshapeLayer(seq_layer, shape=(batchsize, seqlen, n_hid))

    # Conv features become one extra "time step" of width n_hid:
    # (batch, n_hid) -> (batch, 1, n_hid).
    ret['conv_input'] = conv_layer = InputLayer(conv_shape, input_var=conv_input_var)
    ret['conv_proj'] = conv_layer = DenseLayer(conv_layer, num_units=n_hid)
    ret['conv_resh'] = conv_layer = ReshapeLayer(conv_layer, shape=([0], 1, -1))

    # Prepend the conv step along the time axis -> seq_len + 1 steps.
    ret['input_concat'] = layer = ConcatLayer([conv_layer, seq_layer], axis=1)

    # `range` instead of Python-2-only `xrange` so this also runs on Python 3.
    for lstm_layer_idx in range(lstm_layers):
        ret['lstm_{}'.format(lstm_layer_idx)] = layer = LSTMLayer(layer, n_hid)

    # Per-step output projection, again via a flatten / restore round trip.
    # NOTE(review): log_softmax is defined elsewhere; assumed to normalise
    # over the word_dims axis — confirm.
    ret['out_resh'] = layer = ReshapeLayer(layer, shape=(-1, n_hid))
    ret['output_proj'] = layer = DenseLayer(layer, num_units=word_dims, nonlinearity=log_softmax)
    # Keep the reshape under its own key; previously it was stored under
    # 'output' and immediately overwritten, so it was unreachable.
    ret['output_resh'] = layer = ReshapeLayer(layer, shape=(batchsize, seqlen + 1, word_dims))
    # Drop the last of the seq_len + 1 steps so the output has seq_len steps.
    # Because the conv step comes first, output step t has seen the conv
    # features and words 0..t-1.
    ret['output'] = layer = SliceLayer(layer, indices=slice(None, -1), axis=1)
    return ret
# Originally from:
# https://github.com/Lasagne/Recipes/blob/master/examples/styletransfer/Art%20Style%20Transfer.ipynb