def HAN_model_1(session, restore_only=False):
    """Build the Hierarchical Attention Network classifier and load its weights.

    A stacked h-h batchnorm LSTM cell is created and handed to
    ``HANClassifierModel`` as both the word-level and the sentence-level cell.
    If a checkpoint state is found under ``checkpoint_dir`` it is restored
    into ``session``; otherwise all global variables are freshly initialised.

    ``vocab_size``, ``classes``, ``args`` and ``checkpoint_dir`` are not
    defined in this function and are looked up in the enclosing scope.

    Args:
        session: Session passed to ``saver.restore`` or used to run the
            global-variables initializer.
        restore_only: When true, a missing checkpoint is an error instead of
            a reason to initialise fresh parameters.

    Returns:
        A ``(model, saver)`` tuple.

    Raises:
        FileNotFoundError: If ``restore_only`` is true and no checkpoint
            state was found.
    """
    import tensorflow as tf
    try:
        from tensorflow.contrib.rnn import GRUCell, MultiRNNCell, DropoutWrapper
    except ImportError:
        # Fall back to the cell classes exposed under tf.nn.rnn_cell.
        MultiRNNCell = tf.nn.rnn_cell.MultiRNNCell
        GRUCell = tf.nn.rnn_cell.GRUCell
    from bn_lstm import BNLSTMCell
    from HAN_model import HANClassifierModel

    layer_count = 5
    training_flag = tf.placeholder(dtype=tf.bool, name='is_training')

    # h-h batchnorm LSTM cell; the stack below repeats this single instance.
    # NOTE(review): the same cell object is used for every layer and for both
    # word_cell and sentence_cell -- confirm that sharing is intended.
    base_cell = BNLSTMCell(80, training_flag)
    stacked_cell = MultiRNNCell([base_cell] * layer_count)

    model_kwargs = {
        'vocab_size': vocab_size,
        'embedding_size': 200,
        'classes': classes,
        'word_cell': stacked_cell,
        'sentence_cell': stacked_cell,
        'word_output_size': 100,
        'sentence_output_size': 100,
        'device': args.device,
        'learning_rate': args.lr,
        'max_grad_norm': args.max_grad_norm,
        'dropout_keep_proba': 0.5,
        'is_training': training_flag,
    }
    model = HANClassifierModel(**model_kwargs)

    # The saver is created after the model so it covers the model's variables.
    saver = tf.train.Saver(tf.global_variables())
    ckpt_state = tf.train.get_checkpoint_state(checkpoint_dir)

    if ckpt_state:
        ckpt_path = ckpt_state.model_checkpoint_path
        print("Reading model parameters from %s" % ckpt_path)
        saver.restore(session, ckpt_path)
        return model, saver

    if restore_only:
        raise FileNotFoundError("Cannot restore model")

    print("Created model with fresh parameters")
    session.run(tf.global_variables_initializer())
    return model, saver
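# Comment list (web-page navigation text, not part of the code)
# Table of contents (web-page navigation text, not part of the code)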