model_reader.py source code

python
阅读 24 收藏 0 点赞 0 评论 0

Project: context2vec · Author: orenmel · Project source · File source
def read_lstm_model(self, params, train):
        """Build a BiLstmContext model from config params and load its saved weights.

        Only evaluation use is supported: the word counts and negative-sampling
        loss function constructed here are dummies that evaluation does not use.

        Args:
            params: Config dict. Reads the keys 'config_path', 'words_file',
                'model_file', 'unit', 'deep' and 'drop_ratio'. File paths are
                formed by plain string concatenation of 'config_path' with the
                file name, so 'config_path' presumably ends with a path
                separator (or is empty) -- confirm against the config reader.
                'deep' is treated as True only when it equals the string 'yes'.
            train: Must be falsy; reading a model to continue training is not
                supported.

        Returns:
            Tuple ``(w, word2index, index2word, model)``: ``w`` is the
            target-word embedding matrix returned by ``self.read_words`` with
            each row divided by its L2 norm (all-zero rows are left unchanged),
            ``word2index`` / ``index2word`` are the vocabulary mappings from
            ``self.read_words``, and ``model`` is the BiLstmContext instance
            with weights loaded from the model file via ``S.load_npz``.

        Raises:
            NotImplementedError: If ``train`` is truthy.
        """
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently build a training-mode model
        # wired to a dummy loss function.
        if train:
            raise NotImplementedError(
                'reading a model to continue training is currently not supported')

        words_file = params['config_path'] + params['words_file']
        model_file = params['config_path'] + params['model_file']
        unit = int(params['unit'])
        deep = params['deep'] == 'yes'
        drop_ratio = float(params['drop_ratio'])

        # Read the target-word embeddings and scale each row to unit L2 norm.
        # Assumes w is a 2-D float array with one row per word -- TODO confirm
        # against read_words (in-place division needs a float dtype).
        w, word2index, index2word = self.read_words(words_file)
        s = numpy.sqrt((w * w).sum(1))
        s[s == 0.] = 1.  # avoid division by zero: all-zero rows stay zero
        w /= s.reshape((s.shape[0], 1))

        context_word_units = unit
        lstm_hidden_units = IN_TO_OUT_UNITS_RATIO * unit
        target_word_units = IN_TO_OUT_UNITS_RATIO * unit

        # The model constructor requires a loss function, but it is not used
        # for evaluation, so build one from dummy (all-ones) word counts.
        cs = [1] * len(word2index)
        loss_func = L.NegativeSampling(target_word_units, cs, NEGATIVE_SAMPLING_NUM)

        model = BiLstmContext(deep, self.gpu, word2index, context_word_units,
                              lstm_hidden_units, target_word_units, loss_func,
                              train, drop_ratio)
        S.load_npz(model_file, model)

        return w, word2index, index2word, model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号