def __init__(self, rnn_cell, rnn_cell_dim, num_chars, bow_char, eow_char, logdir, expname, threads=1, seed=42):
    """Set up the TensorFlow session, the summary writer and the graph inputs.

    Args:
        rnn_cell: Name of the recurrent cell type, either "LSTM" or "GRU".
        rnn_cell_dim: Size passed to the selected cell constructor.
        num_chars: Not used by the code present so far; presumably meant
            for the model marked TODO below -- confirm when implementing.
        bow_char: Not used by the code present so far (see `num_chars`).
        eow_char: Not used by the code present so far (see `num_chars`).
        logdir: Base directory of the summary log path.
        expname: Experiment name, appended to the timestamped log path.
        threads: Used for both inter-op and intra-op parallelism.
        seed: Seed assigned to the freshly created graph.

    Raises:
        ValueError: If `rnn_cell` is neither "LSTM" nor "GRU".
    """
    # Fresh graph, seeded, owned by a session with a bounded thread count.
    graph = tf.Graph()
    graph.seed = seed
    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=threads,
        intra_op_parallelism_threads=threads,
    )
    self.session = tf.Session(graph=graph, config=session_config)

    # Summaries go to "<logdir>/<timestamp>-<expname>".
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S")
    log_path = "{}/{}-{}".format(logdir, timestamp, expname)
    self.summary_writer = tf.train.SummaryWriter(log_path, flush_secs=10)

    def int_vector():
        # One int32 value per item; length left unspecified.
        return tf.placeholder(tf.int32, [None])

    def int_matrix():
        # int32 matrix with both dimensions left unspecified.
        return tf.placeholder(tf.int32, [None, None])

    # Everything below is created inside the session's graph.
    with graph.as_default():
        # Turn the cell name into a cell object, rejecting unknown names.
        if rnn_cell not in ("LSTM", "GRU"):
            raise ValueError("Unknown rnn_cell {}".format(rnn_cell))
        cell_class = tf.nn.rnn_cell.LSTMCell if rnn_cell == "LSTM" else tf.nn.rnn_cell.GRUCell
        rnn_cell = cell_class(rnn_cell_dim)

        self.global_step = tf.Variable(0, dtype=tf.int64, trainable=False, name="global_step")

        # Input placeholders (created in the same order as before).
        self.sentence_lens = int_vector()
        self.form_ids = int_matrix()
        self.forms = int_matrix()
        self.form_lens = int_vector()
        self.lemma_ids = int_matrix()
        self.lemmas = int_matrix()
        self.lemma_lens = int_vector()

        # TODO
        # loss = ...
        # self.training = ...
        # self.predictions = ...
        # self.accuracy = ...

        # NOTE(review): `loss` and `self.accuracy` are only sketched in the
        # TODO above; unless they are defined there, the summary
        # construction below fails -- confirm once the model is implemented.
        self.dataset_name = tf.placeholder(tf.string, [])
        loss_summary = tf.scalar_summary(self.dataset_name+"/loss", loss)
        accuracy_summary = tf.scalar_summary(self.dataset_name+"/accuracy", self.accuracy)
        self.summary = tf.merge_summary([loss_summary, accuracy_summary])

        # Initialize variables
        self.session.run(tf.initialize_all_variables())
评论列表
文章目录