def __init__(self, config, batch_size, one_hot=False,
             filename="chargan.txt", seq_length=64):
    """Build a queue-based pipeline that turns text lines into batches.

    Each line read from ``filename`` is right-padded with spaces and cut
    to exactly ``seq_length`` characters, then every character is mapped
    to its index in the vocabulary returned by ``self.get_vocabulary()``.

    Args:
        config: Not used by this constructor; kept for interface
            compatibility with callers.
        batch_size: Number of examples per batch produced by
            ``tf.train.shuffle_batch``.
        one_hot: If true, each character index is expanded to a one-hot
            vector; otherwise indices are rescaled to roughly [-1, 1).
        filename: Path of the text file to read, one example per line.
        seq_length: Fixed number of characters kept from each line.

    Raises:
        ValueError: If ``seq_length`` is smaller than 1.

    Attributes set:
        self.x: The batched float32 tensor. With ``one_hot`` each example
            has shape ``[seq_length, vocab_size, 1]``; without it,
            ``[1, seq_length, 1]``.
        self.table: The string-to-index lookup table (callers must run
            the TF table initializer before evaluating ``self.x``).
        self.one_hot: The ``one_hot`` flag as given.
        self.lookup: Initialised to ``None``; not otherwise touched here.
    """
    if seq_length < 1:
        raise ValueError("seq_length must be a positive integer")

    self.lookup = None

    reader = tf.TextLineReader()
    filename_queue = tf.train.string_input_producer([filename])
    _, line = reader.read(filename_queue)

    vocabulary = self.get_vocabulary()
    vocab_size = len(vocabulary)
    # Characters missing from the vocabulary fall back to index 0.
    table = tf.contrib.lookup.string_to_index_table_from_tensor(
        mapping=vocabulary, default_value=0)

    # Append seq_length spaces and keep the first seq_length characters,
    # so short lines are padded and long lines are truncated.
    line = tf.string_join([line, tf.constant(" " * seq_length)])
    line = tf.substr(line, [0], [seq_length])

    # An empty delimiter splits the string into single characters.
    chars = tf.string_split(line, delimiter='')
    chars = tf.sparse_tensor_to_dense(chars, default_value=' ')
    chars = tf.reshape(chars, [seq_length])
    x = table.lookup(chars)

    self.one_hot = one_hot
    if one_hot:
        x = tf.one_hot(x, vocab_size)
        x = tf.cast(x, dtype=tf.float32)
        # Leading 1 is the "many examples" axis consumed by shuffle_batch;
        # trailing 1 is a channel axis.
        x = tf.reshape(x, [1, seq_length, vocab_size, 1])
    else:
        x = tf.cast(x, dtype=tf.float32)
        # Centre and scale indices 0..vocab_size-1 into [-1, 1).
        half_vocab = vocab_size / 2.0
        x -= half_vocab
        x /= half_vocab
        x = tf.reshape(x, [1, 1, seq_length, 1])

    num_preprocess_threads = 8
    # enqueue_many=True: the first axis of x (size 1) is treated as a list
    # of examples, so each line read contributes exactly one example.
    x = tf.train.shuffle_batch(
        [x],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=5000,
        min_after_dequeue=500,
        enqueue_many=True)

    self.x = x
    self.table = table
评论列表
文章目录