model.py source code

python

Project: lstm-poetry  Author: dvictor
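import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (sessions, zero_state)

# Method of the model class: a generator that yields `num` characters continuing `prime`.
def sample(self, sess, chars, vocab, num, prime, temperature):
        # Evaluate the zero state through the given session rather than relying on a default one.
        state = sess.run(self.cell.zero_state(1, tf.float32))
        # Warm up the RNN state on every prime character except the last.
        for char in prime[:-1]:
            x = np.zeros((1, 1))
            x[0, 0] = vocab[char]
            feed = {self.input_data: x, self.initial_state: state}
            [state] = sess.run([self.final_state], feed)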

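        def weighted_pick(a):
            # Re-scale the distribution by temperature, then draw one index from it.
            a = a.astype(np.float64)  # float64 keeps the sum within multinomial's tolerance
            a = a.clip(min=1e-20)  # avoid log(0)
            a = np.log(a) / temperature
            a = a - np.max(a)  # shift so exp() cannot underflow to an all-zero vector
            a = np.exp(a) / (np.sum(np.exp(a)))
            return np.argmax(np.random.multinomial(1, a, 1))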

        # Feed the last prime character, then feed each sampled character back in.
        char = prime[-1]
        for _ in range(num):
            x = np.zeros((1, 1))
            x[0, 0] = vocab[char]
            feed = {self.input_data: x, self.initial_state: state}
            [probs, state] = sess.run([self.probs, self.final_state], feed)
            p = probs[0]  # distribution over the vocabulary for the single batch entry
            idx = weighted_pick(p)
            char = chars[idx]
            yield char
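
The temperature step inside weighted_pick can be tried on its own, without TensorFlow or a trained model. The following is a minimal NumPy-only sketch, not code from the lstm-poetry project: the helper name apply_temperature, the example distribution, and the use of numpy.random.default_rng are assumptions made for illustration. It shows that a temperature below 1 concentrates probability on the most likely character, a temperature of 1 leaves the distribution unchanged, and a temperature above 1 flattens it.

```python
import numpy as np


def apply_temperature(probs, temperature):
    """Re-scale a probability vector by temperature (hypothetical helper)."""
    p = np.clip(np.asarray(probs, dtype=np.float64), 1e-20, None)  # avoid log(0)
    logits = np.log(p) / temperature
    logits -= logits.max()  # shift for numerical stability; the result is unchanged
    scaled = np.exp(logits)
    return scaled / scaled.sum()


def weighted_pick(probs, temperature, rng):
    """Draw one index from the temperature-scaled distribution."""
    scaled = apply_temperature(probs, temperature)
    return int(rng.choice(len(scaled), p=scaled))


# Example distribution over a three-character vocabulary (made up for illustration).
probs = np.array([0.6, 0.3, 0.1])
rng = np.random.default_rng(0)

for t in (0.2, 1.0, 5.0):
    print(t, np.round(apply_temperature(probs, t), 3), weighted_pick(probs, t, rng))
```

In the sample method above, the same scaling is applied to self.probs at every step, so the temperature argument controls how conservative or how varied the generated text is.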