models.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:twitter_caption 作者: tencia 项目源码 文件源码
def build_rnn(conv_input_var, seq_input_var, conv_shape, word_dims, n_hid, lstm_layers):
    """Build a layer graph that feeds a conv feature plus a sequence through LSTMs.

    The conv input is projected to ``n_hid`` units and reshaped into a single
    step that is concatenated in front of the projected sequence along axis 1.
    The combined sequence runs through ``lstm_layers`` stacked LSTM layers and
    is then projected back to ``word_dims`` with a ``log_softmax``
    nonlinearity. The last step along axis 1 is sliced off the result.

    NOTE(review): the layer classes used here (``InputLayer``, ``DenseLayer``,
    ``ReshapeLayer``, ``ConcatLayer``, ``LSTMLayer``, ``SliceLayer``) and
    ``log_softmax`` are defined outside this block -- presumably Lasagne;
    confirm against the file's imports.

    Args:
        conv_input_var: Input variable bound to the conv ``InputLayer``.
        seq_input_var: Input variable bound to the sequence ``InputLayer``;
            its ``shape`` is unpacked into three components.
        conv_shape: Shape passed to the conv ``InputLayer``.
        word_dims: Size of the last axis of the sequence input, and number of
            units of the output projection.
        n_hid: Number of units of both input projections and of every LSTM.
        lstm_layers: Number of LSTM layers to stack.

    Returns:
        A dict mapping layer names to the created layers; ``ret['output']`` is
        the final sliced layer.
    """
    ret = {}

    # Sequence branch: flatten the first two axes so the dense projection is
    # applied to every step, then restore the (batch, step, n_hid) layout.
    ret['seq_input'] = seq_layer = InputLayer((None, None, word_dims), input_var=seq_input_var)
    # presumably symbolic sizes taken from the input variable -- TODO confirm
    batchsize, seqlen, _ = seq_layer.input_var.shape
    ret['seq_resh'] = seq_layer = ReshapeLayer(seq_layer, shape=(-1, word_dims))
    ret['seq_proj'] = seq_layer = DenseLayer(seq_layer, num_units=n_hid)
    ret['seq_resh2'] = seq_layer = ReshapeLayer(seq_layer, shape=(batchsize, seqlen, n_hid))

    # Conv branch: project to n_hid and reshape into one step on axis 1.
    ret['conv_input'] = conv_layer = InputLayer(conv_shape, input_var=conv_input_var)
    ret['conv_proj'] = conv_layer = DenseLayer(conv_layer, num_units=n_hid)
    ret['conv_resh'] = conv_layer = ReshapeLayer(conv_layer, shape=([0], 1, -1))

    # The conv step goes first, so the combined sequence has seqlen + 1 steps.
    ret['input_concat'] = layer = ConcatLayer([conv_layer, seq_layer], axis=1)

    # `range` instead of `xrange` so the function also runs on Python 3.
    for lstm_layer_idx in range(lstm_layers):
        ret['lstm_{}'.format(lstm_layer_idx)] = layer = LSTMLayer(layer, n_hid)

    # Output projection, again applied per step on the flattened tensor.
    ret['out_resh'] = layer = ReshapeLayer(layer, shape=(-1, n_hid))
    ret['output_proj'] = layer = DenseLayer(layer, num_units=word_dims, nonlinearity=log_softmax)

    # The intermediate reshape is not stored in the dict: the original assigned
    # it to 'output' and overwrote it with the slice on the next line.
    layer = ReshapeLayer(layer, shape=(batchsize, seqlen+1, word_dims))
    # Drop the last step along axis 1.
    ret['output'] = layer = SliceLayer(layer, indices=slice(None, -1), axis=1)
    return ret

# originally from
# https://github.com/Lasagne/Recipes/blob/master/examples/styletransfer/Art%20Style%20Transfer.ipynb
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号