def _load_cnn_weights(self):
    """Build the character-level convolutions and load their stored weights.

    One ``torch.nn.Conv1d`` is created per ``(width, num)`` entry of
    ``self._options['char_cnn']['filters']``, with
    ``self._options['char_cnn']['embedding']['dim']`` input channels,
    ``num`` output channels and a kernel of size ``width``.  The kernel and
    bias of filter ``i`` are read from the datasets ``CNN/W_cnn_<i>`` and
    ``CNN/b_cnn_<i>`` of the HDF5 file obtained from
    ``cached_path(self._weight_file)``.

    Each convolution has its parameters frozen (``requires_grad = False``),
    is registered as the submodule ``char_conv_<i>``, and the full list is
    stored on ``self._convolutions``.

    Raises:
        ValueError: if a stored kernel, once squeezed and transposed, does
            not have the shape of the corresponding ``Conv1d`` weight.
    """
    cnn_options = self._options['char_cnn']
    filters = cnn_options['filters']
    char_embed_dim = cnn_options['embedding']['dim']

    convolutions = []
    if not filters:
        # Nothing to load: do not touch the weight file at all.
        self._convolutions = convolutions
        return

    # The weight file does not depend on the filter index, so resolve and
    # open it once for all filters instead of once per loop iteration.
    with h5py.File(cached_path(self._weight_file), 'r') as fin:
        cnn_group = fin['CNN']
        for i, (width, num) in enumerate(filters):
            conv = torch.nn.Conv1d(
                in_channels=char_embed_dim,
                out_channels=num,
                kernel_size=width,
                bias=True
            )

            # ``[...]`` reads the whole dataset into memory as an array.
            weight = cnn_group['W_cnn_{}'.format(i)][...]
            bias = cnn_group['b_cnn_{}'.format(i)][...]

            # Drop the leading singleton axis and reverse the remaining
            # three axes so the kernel lines up with Conv1d's
            # (out_channels, in_channels, kernel_size) weight layout; i.e.
            # the stored kernel must be (1, width, char_embed_dim, num).
            w_reshaped = numpy.transpose(weight.squeeze(axis=0), axes=(2, 1, 0))
            if w_reshaped.shape != tuple(conv.weight.data.shape):
                raise ValueError("Invalid weight file")
            conv.weight.data.copy_(torch.FloatTensor(w_reshaped))
            conv.bias.data.copy_(torch.FloatTensor(bias))

            # The loaded filters are kept fixed: exclude them from gradients.
            conv.weight.requires_grad = False
            conv.bias.requires_grad = False

            convolutions.append(conv)
            self.add_module('char_conv_{}'.format(i), conv)

    self._convolutions = convolutions
评论列表
文章目录