shalo_base.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:shalo 作者: henryre 项目源码 文件源码
def _build(self):
    """Build the TensorFlow graph for a linear model over sentence features.

    Creates the input placeholders, obtains sentence features from
    ``self._embed_sentences()``, and applies a single linear layer with
    weights ``w`` of shape ``(self.d, 1)`` and bias ``b`` of shape ``(1, 1)``.
    Sets the following attributes:

    * ``self.loss`` -- ``self._get_loss(h, self.y)`` plus
      ``self.l2_penalty * tf.nn.l2_loss(w)``.
    * ``self.prediction`` -- ``tf.sigmoid`` of the logits ``h``.
    * ``self.train_fn`` -- an Adam optimizer step (learning rate ``self.lr``)
      minimizing ``self.loss``.
    * ``self.save_dict`` -- the dict returned by ``_embed_sentences`` extended
      with the ``'w'`` and ``'b'`` variables.

    Raises:
        AssertionError: if any of ``d``, ``lr``, ``l2_penalty`` or
            ``loss_function`` is ``None`` (checks are skipped under ``-O``).
    """
    # Hyperparameters must be configured before the graph can be built.
    assert self.d is not None, "d must be set before _build()"
    assert self.lr is not None, "lr must be set before _build()"
    assert self.l2_penalty is not None, "l2_penalty must be set before _build()"
    assert self.loss_function is not None, "loss_function must be set before _build()"
    # Get input placeholders and sentence features.
    self._create_placeholders()
    # NOTE(review): sentence_feats presumably has shape (batch, d) given the
    # matmul with w below -- confirm against _embed_sentences.
    sentence_feats, save_kwargs = self._embed_sentences()
    # Define the linear model. Use distinct op-level seeds for w and b so the
    # two initializers do not draw identical values when a seed is given.
    s1, s2 = self.seed, (self.seed + 1 if self.seed is not None else None)
    # SD is a module-level constant defined outside this view (init stddev).
    w = tf.Variable(tf.random_normal((self.d, 1), stddev=SD, seed=s1))
    b = tf.Variable(tf.random_normal((1, 1), stddev=SD, seed=s2))
    # Squeeze only the size-1 output column: an unqualified squeeze would
    # collapse a batch of one to a rank-0 scalar. The axis is passed
    # positionally so it works with both the `axis` and `squeeze_dims` names.
    h = tf.squeeze(tf.matmul(sentence_feats, w) + b, [1])
    # Define the training procedure: task loss plus L2 regularization on w.
    self.loss = self._get_loss(h, self.y)
    self.loss += self.l2_penalty * tf.nn.l2_loss(w)
    self.prediction = tf.sigmoid(h)
    self.train_fn = tf.train.AdamOptimizer(self.lr).minimize(self.loss)
    # dict.update() mutates in place and returns None, so update first and
    # then store the dict itself (previously save_dict was always None).
    save_kwargs.update({'w': w, 'b': b})
    self.save_dict = save_kwargs
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号