insqa_lstm.py source code

python

Project: insuranceQA-cnn-lstm  Author: cszhz
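# Assumes module-level imports: theano.tensor as T, conv2d from theano.tensor.nnet, pool from theano.tensor.signal.
def _cnn_net(self, tparams, cnn_input, batch_size, sequence_len, num_filters, filter_sizes, proj_size):
    """Apply one convolution per filter size, max-pool over time, and concatenate the results."""
    outputs = []
    for filter_size in filter_sizes:
        # Each filter spans filter_size time steps and the full embedding width (proj_size).
        filter_shape = (num_filters, 1, filter_size, proj_size)
        image_shape = (batch_size, 1, sequence_len, proj_size)
        W = tparams['cnn_W_' + str(filter_size)]
        b = tparams['cnn_b_' + str(filter_size)]
        # conv_out has shape (batch_size, num_filters, sequence_len - filter_size + 1, 1).
        conv_out = conv2d(input=cnn_input, filters=W, filter_shape=filter_shape, input_shape=image_shape)
        # The pooling window covers every time position, leaving one value per filter.
        pooled_out = pool.pool_2d(input=conv_out, ds=(sequence_len - filter_size + 1, 1), ignore_border=True, mode='max')
        # Broadcast the per-filter bias over the batch and spatial axes, then apply tanh.
        pooled_active = T.tanh(pooled_out + b.dimshuffle('x', 0, 'x', 'x'))
        outputs.append(pooled_active)
    num_filters_total = num_filters * len(filter_sizes)
    # Concatenate along the filter axis and flatten to (batch_size, num_filters_total).
    output_tensor = T.reshape(T.concatenate(outputs, axis=1), [batch_size, num_filters_total])
    return output_tensor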
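
To make the tensor shapes in `_cnn_net` easier to follow, here is a minimal NumPy sketch of the same idea: one convolution per filter size, max-pooling over all time positions, then tanh and concatenation. It is not the project's code. The function name `cnn_max_over_time`, the dictionary layout of `weights` and `biases`, and the random inputs are all assumptions made for illustration. It also uses cross-correlation, whereas Theano's `conv2d` flips the kernel by default, so the two agree only up to a flip of the weights.

```python
import numpy as np

def cnn_max_over_time(x, weights, biases, filter_sizes):
    """Hypothetical NumPy analogue of _cnn_net.

    x:            (batch_size, sequence_len, proj_size)
    weights[k]:   (num_filters, k, proj_size) for each filter size k
    biases[k]:    (num_filters,)
    returns:      (batch_size, num_filters * len(filter_sizes))
    """
    seq_len = x.shape[1]
    outputs = []
    for filter_size in filter_sizes:
        W, b = weights[filter_size], biases[filter_size]
        n_windows = seq_len - filter_size + 1
        # Stack every window of filter_size consecutive time steps:
        # shape (batch, n_windows, filter_size, proj_size).
        windows = np.stack(
            [x[:, t:t + filter_size, :] for t in range(n_windows)], axis=1
        )
        # Cross-correlate each window with each filter: (batch, num_filters, n_windows).
        conv = np.einsum('bwkp,fkp->bfw', windows, W)
        # Max over all time positions leaves one value per filter.
        pooled = conv.max(axis=2)
        outputs.append(np.tanh(pooled + b))
    # Concatenate along the filter axis.
    return np.concatenate(outputs, axis=1)

# Illustrative sizes only; none of these values come from the project.
rng = np.random.default_rng(0)
batch_size, seq_len, proj_size, num_filters = 2, 10, 4, 3
filter_sizes = [1, 2, 3]
x = rng.standard_normal((batch_size, seq_len, proj_size))
weights = {k: rng.standard_normal((num_filters, k, proj_size)) for k in filter_sizes}
biases = {k: np.zeros(num_filters) for k in filter_sizes}

out = cnn_max_over_time(x, weights, biases, filter_sizes)
print(out.shape)  # (2, 9)
```

Because the pooling window spans the whole convolved sequence, the output width depends only on `num_filters * len(filter_sizes)`, not on `sequence_len`.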