tflintrans.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:vampyre 作者: GAMPTeam 项目源码 文件源码
def __init__(self, x_op, y_op, sess, remove_bias=False):
    """Wrap a TensorFlow graph from ``x_op`` to ``y_op`` as a transform.

    Besides storing the graph endpoints, this builds the ops needed to apply
    the adjoint of the map via automatic differentiation and, optionally,
    evaluates the output at zero input.

    Args:
        x_op: Input tensor of the graph.  It is fed directly through
            ``feed_dict`` when the bias is computed, so it must be feedable
            (typically a placeholder).
        y_op: Output tensor.  Presumably a linear function of ``x_op`` (or
            affine, when ``remove_bias`` is set) -- TODO confirm with callers.
        sess: Session used to evaluate ``y_op`` when computing the bias.
        remove_bias: If True, run ``y_op`` once with ``x_op`` fed all zeros
            and store the result in ``self.y_bias``.  Otherwise
            ``self.y_bias`` is 0.
    """
    # Save parameters
    self.x_op = x_op
    self.y_op = y_op
    self.sess = sess
    self.remove_bias = remove_bias

    # Static shapes and data types of the input (0) and output (1) tensors.
    self.shape0 = x_op.get_shape()
    self.shape1 = y_op.get_shape()
    self.dtype0 = x_op.dtype
    self.dtype1 = y_op.dtype

    # Create the ops for the gradient.  If the operator is linear, y = F(x),
    # then z = sum(conj(ytr) * F(x)), and the gradient dz/dx is F'(ytr).
    # Feeding ytr_op and evaluating zgrad_op thus applies the adjoint without
    # forming F explicitly.
    # NOTE(review): for complex dtypes the result depends on TensorFlow's
    # complex-gradient convention -- confirm it yields F^H rather than F^T.
    self.ytr_op = tf.placeholder(self.dtype1, self.shape1)
    self.z_op = tf.reduce_sum(tf.multiply(tf.conj(self.ytr_op), self.y_op))
    self.zgrad_op = tf.gradients(self.z_op, self.x_op)[0]

    # Compute the output at zero input, presumably so other methods can
    # subtract it.
    if self.remove_bias:
        # Build the zero input with x_op's own NumPy dtype and an explicit
        # shape list.  The fed array then matches the tensor exactly, rather
        # than relying on a default float64 array being cast by the feed.
        # This requires x_op to have a fully defined static shape.
        xzero = np.zeros(self.shape0.as_list(),
                         dtype=self.dtype0.as_numpy_dtype)
        self.y_bias = self.sess.run(self.y_op, feed_dict={self.x_op: xzero})
    else:
        self.y_bias = 0
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号