test.py 文件源码

python
阅读 16 收藏 0 点赞 0 评论 0

项目:zhihu_cup 作者: Godricly 项目源码 文件源码
def sym_gen_word(bucket_key):
    """Build the unrolled network symbol for one bucket.

    ``bucket_key`` is a comma-separated string; its first two fields are
    parsed as integers and used as the unroll lengths of the ``tw`` and
    ``cw`` branches respectively.

    Returns a 3-tuple ``(loss, data_name, label_name)``. ``data_name`` and
    ``label_name`` are not computed here; they are read from the enclosing
    module scope, as are ``tw_cell``, ``cw_cell`` and ``fc_module``.
    """
    fields = bucket_key.split(',')
    tw_length, cw_length = int(fields[0]), int(fields[1])

    # Input placeholders for the two sequences and the target.
    tw_input = mx.sym.Variable('tw_array')
    cw_input = mx.sym.Variable('cw_array')
    target = mx.sym.Variable('label')

    # Split each input along axis 1 into one symbol per step.
    tw_steps = list(mx.symbol.SliceChannel(
        data=tw_input, axis=1, num_outputs=tw_length,
        squeeze_axis=True, name='tw_slice'))
    cw_steps = list(mx.symbol.SliceChannel(
        data=cw_input, axis=1, num_outputs=cw_length,
        squeeze_axis=True, name='cw_slice'))

    # Unroll each recurrent cell over its steps; the returned states are unused.
    tw_seq, _ = tw_cell.unroll(
        tw_length, inputs=tw_steps, merge_outputs=True, layout='TNC')
    cw_seq, _ = cw_cell.unroll(
        cw_length, inputs=cw_steps, merge_outputs=True, layout='TNC')

    # Permute axes so the first axis becomes the last, then global max-pool.
    # NOTE(review): presumably (T, N, C) -> (N, C, T), i.e. max over time;
    # confirm against the cells' output layout.
    tw_seq, cw_seq = [
        mx.sym.transpose(seq, (1, 2, 0)) for seq in (tw_seq, cw_seq)
    ]
    tw_pooled, cw_pooled = [
        mx.sym.Pooling(seq, kernel=(1,), global_pool=True, pool_type='max')
        for seq in (tw_seq, cw_seq)
    ]

    # Join both branches and run them through the fully connected head.
    net = mx.sym.Concat(tw_pooled, cw_pooled, name='concat')
    for layer_name, width in (('fc1', 1024), ('fc2', 1024)):
        net = fc_module(net, layer_name, num_hidden=width)
    net = mx.sym.Dropout(net, p=0.5)
    net = fc_module(net, 'feature', num_hidden=2000)

    loss = mx.sym.LogisticRegressionOutput(net, label=target, name='regression')
    return loss, data_name, label_name

#mod = mx.module.BucketingModule(sym_gen_word, default_bucket_key=ziter.max_bucket_key,context=mx.gpu(1),data_names=data_name, label_names=label_name)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号