tf_tree_lstm.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:RecursiveNN 作者: sapruash 项目源码 文件源码
def add_embedding(self):

        #embed=np.load('glove{0}_uniform.npy'.format(self.emb_dim))
        with tf.variable_scope("Embed",regularizer=None):
            embedding=tf.get_variable('embedding',[self.num_emb,
                                                   self.emb_dim]
                        ,initializer=tf.random_uniform_initializer(-0.05,0.05),trainable=True,regularizer=None)
            ix=tf.to_int32(tf.not_equal(self.input,-1))*self.input
            emb_tree=tf.nn.embedding_lookup(embedding,ix)
            emb_tree=emb_tree*(tf.expand_dims(
                        tf.to_float(tf.not_equal(self.input,-1)),2))

            return emb_tree
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号