train.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:korean_restaurant_reservation 作者: JudeLee19 项目源码 文件源码
def __init__(self):
        """Set up the encoders, trackers, dataset, index splits and LSTM network.

        Stores on the instance: the bag-of-words encoder (``bow_enc``), the
        utterance embedder (``emb``), the dataset (``dataset``), the train/dev
        dialog index lists (``dialog_indices_tr`` / ``dialog_indices_dev``),
        the action templates (``action_templates``) and the network (``net``).
        """
        entity_tracker = EntityTracker()
        self.bow_enc = BoW_encoder()
        self.emb = UtteranceEmbed()
        action_tracker = ActionTracker(entity_tracker)

        # Only the first element of ``trainset`` is kept; the second one is
        # not used here because the splits are read from the files below.
        self.dataset, _ = Data(entity_tracker, action_tracker).trainset

        # Pre-computed dialog index splits persisted with joblib.
        # NOTE(review): paths are relative to the working directory — confirm
        # the script is always launched from the project root.
        tr_indices = joblib.load('data/train_test_list/train_indices_759')
        dev_indices = joblib.load('data/train_test_list/test_indices_759_949')
        self.dialog_indices_tr = tr_indices
        self.dialog_indices_dev = dev_indices

        # The network's observation size is the sum of the utterance
        # embedding dimension, the bag-of-words vocabulary size and the
        # entity tracker's feature count.
        embed_dim = self.emb.dim
        vocab_size = self.bow_enc.vocab_size
        obs_size = embed_dim + vocab_size + entity_tracker.num_features

        self.action_templates = action_tracker.get_action_templates()

        self.net = LSTM_net(
            obs_size=obs_size,
            action_size=action_tracker.action_size,
            nb_hidden=128,
        )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号