def __init__(self, *, nb_hidden=128,
             train_indices_path='data/train_test_list/train_indices_759',
             dev_indices_path='data/train_test_list/test_indices_759_949'):
    """Build the feature encoders, load the dataset and split, and create the network.

    All new arguments are keyword-only. Their defaults equal the values that were
    previously hard-coded, so ``__init__(self)`` behaves exactly as before.

    Keyword Args:
        nb_hidden: hidden size passed to ``LSTM_net``.
        train_indices_path: file read with ``joblib.load`` to obtain the training
            dialog indices, stored in ``self.dialog_indices_tr``.
        dev_indices_path: file read with ``joblib.load`` to obtain the held-out
            dialog indices, stored in ``self.dialog_indices_dev``.
    """
    et = EntityTracker()
    self.bow_enc = BoW_encoder()
    self.emb = UtteranceEmbed()
    at = ActionTracker(et)

    # ``trainset`` is unpacked into two values. The second (dialog indices) is not
    # used here: the train/dev split is taken from the index files loaded below.
    self.dataset, _ = Data(et, at).trainset

    # NOTE(review): the default file names suggest 759 training dialogs and
    # dialogs 759-949 for evaluation -- confirm against the data-prep script.
    self.dialog_indices_tr = joblib.load(train_indices_path)
    self.dialog_indices_dev = joblib.load(dev_indices_path)

    # Input size is the sum of the utterance-embedding dim, the BoW vocabulary
    # size and the entity tracker's feature count -- presumably the per-turn
    # features are concatenated in that way; confirm against the featurizer.
    obs_size = self.emb.dim + self.bow_enc.vocab_size + et.num_features
    self.action_templates = at.get_action_templates()
    action_size = at.action_size
    self.net = LSTM_net(obs_size=obs_size,
                        action_size=action_size,
                        nb_hidden=nb_hidden)
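# 评论列表 (web-page "comment list" label -- scraping residue, not code)
# 文章目录 (web-page "article table of contents" label -- scraping residue, not code)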