def sample(self, sess, chars, vocab, num, prime, temperature):
    """Generate ``num`` characters from the model, one character per step.

    The recurrent state starts at the cell's zero state. It is warmed up by
    feeding every character of ``prime`` except the last. The last character
    of ``prime`` is then the first input of the sampling loop, and each
    sampled character is fed back as the next input.

    This is a generator: nothing (including the errors below) happens until
    the first ``next()``.

    Args:
        sess: Session used for every graph evaluation, including computing
            the initial zero state.
        chars: Sequence indexed by class index, used to decode each sampled
            index back into a character.
        vocab: Mapping from character to its integer index.
        num: Number of characters to generate.
        prime: Non-empty sequence of characters used to prime the state.
        temperature: Non-zero temperature. Probabilities are re-weighted as
            ``p ** (1 / temperature)`` and renormalised, so values below 1
            sharpen the distribution and values above 1 flatten it.

    Yields:
        The ``num`` sampled characters, in order (``prime`` is not yielded).

    Raises:
        ValueError: if ``temperature`` is zero.
        IndexError: if ``prime`` is empty.
        KeyError: if a fed character is missing from ``vocab``.
    """
    if temperature == 0:
        raise ValueError("temperature must be non-zero")

    def feed_for(char, state):
        # One step of input: a (1, 1) batch holding the character's index,
        # plus the recurrent state carried over from the previous step.
        x = np.zeros((1, 1))
        x[0, 0] = vocab[char]
        return {self.input_data: x, self.initial_state: state}

    def weighted_pick(probs):
        # Re-weight by temperature in log space, then draw one index.
        # The clip keeps log() finite for zero probabilities.
        logits = np.log(probs.astype(np.float64).clip(min=1e-20)) / temperature
        # Subtracting the max before exp() leaves the softmax unchanged but
        # stops small temperatures from underflowing every weight to 0
        # (which previously produced 0/0 = NaN).
        logits -= logits.max()
        weights = np.exp(logits)
        weights /= weights.sum()
        return np.argmax(np.random.multinomial(1, weights, 1))

    # Evaluate the zero state in the session we were given; ``.eval()``
    # would silently require ``sess`` to be the default session.
    state = sess.run(self.cell.zero_state(1, tf.float32))
    for char in prime[:-1]:
        state = sess.run(self.final_state, feed_for(char, state))

    char = prime[-1]
    for _ in range(num):
        probs, state = sess.run(
            [self.probs, self.final_state], feed_for(char, state)
        )
        char = chars[weighted_pick(probs[0])]
        yield char
评论列表
文章目录