def sym_gen_word(bucket_key):
    """Build the two-branch recurrent symbol for one bucket.

    Args:
        bucket_key: String of the form ``"<tw_length>,<cw_length>"``; the
            first two comma-separated fields are parsed as integers and used
            as the unroll lengths of the ``tw`` and ``cw`` branches.

    Returns:
        A ``(loss, data_name, label_name)`` tuple. ``loss`` is the
        ``LogisticRegressionOutput`` symbol built here; ``data_name`` and
        ``label_name`` are not defined in this function and are taken
        unchanged from the enclosing scope.

    Note:
        ``tw_cell``, ``cw_cell`` and ``fc_module`` are also looked up in the
        enclosing scope.
    """
    fields = bucket_key.split(',')
    tw_length = int(fields[0])
    cw_length = int(fields[1])

    def _pooled_branch(cell, var_name, slice_name, num_steps):
        """Unroll ``cell`` over one input variable and max-pool its outputs."""
        source = mx.sym.Variable(var_name)
        # One symbol per position along axis 1, with that axis squeezed away.
        # NOTE(review): presumably axis 1 is the time axis of a
        # (batch, time, feature) input -- confirm against the data iterator.
        steps = list(
            mx.sym.SliceChannel(
                data=source,
                axis=1,
                num_outputs=num_steps,
                squeeze_axis=True,
                name=slice_name,
            )
        )
        merged, _ = cell.unroll(
            num_steps, inputs=steps, merge_outputs=True, layout='TNC'
        )
        # Move axis 0 of the merged output to the last position so that the
        # global max pooling below reduces over it.
        merged = mx.sym.transpose(merged, (1, 2, 0))
        return mx.sym.Pooling(
            merged, kernel=(1,), global_pool=True, pool_type='max'
        )

    # The tw branch is built before the cw branch, matching the order in
    # which the two branches are concatenated.
    tw_pooled = _pooled_branch(tw_cell, 'tw_array', 'tw_slice', tw_length)
    cw_pooled = _pooled_branch(cw_cell, 'cw_array', 'cw_slice', cw_length)

    hidden = mx.sym.Concat(tw_pooled, cw_pooled, name='concat')
    hidden = fc_module(hidden, 'fc1', num_hidden=1024)
    hidden = fc_module(hidden, 'fc2', num_hidden=1024)
    hidden = mx.sym.Dropout(hidden, p=0.5)
    hidden = fc_module(hidden, 'feature', num_hidden=2000)

    target = mx.sym.Variable('label')
    loss = mx.sym.LogisticRegressionOutput(
        hidden, label=target, name='regression'
    )
    return loss, data_name, label_name
# Example usage (left disabled):
# mod = mx.module.BucketingModule(sym_gen_word, default_bucket_key=ziter.max_bucket_key, context=mx.gpu(1), data_names=data_name, label_names=label_name)
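# Page-navigation residue from the source web page, not part of the code:
# "评论列表" (comment list) and "文章目录" (table of contents).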