def init_weight_descriptor(fn, weight):
    """Create a cuDNN filter descriptor that describes ``weight``.

    Args:
        fn: Accepted for call-site compatibility; not read by this function.
        weight: Tensor-like object exposing ``view``. Its elements are
            presented to cuDNN as a rank-3 ``(numel, 1, 1)`` view.
            NOTE(review): ``view`` presumably needs ``weight`` to be
            contiguous. Confirm that callers pass a contiguous buffer.

    Returns:
        The ``cudnn.FilterDescriptor`` after ``set`` has been called with
        the reshaped view.
    """
    descriptor = cudnn.FilterDescriptor()
    # Pad with two trailing unit dims. The original author observed that
    # filter descriptors seem to require at least 3 dimensions.
    as_filter = weight.view(-1, 1, 1)
    descriptor.set(as_filter)
    return descriptor