lstm.py 文件源码

python
阅读 37 收藏 0 点赞 0 评论 0

项目:LSTM_PIT 作者: snsun 项目源码 文件源码
def get_opt_output(self):
    cost1 = tf.reduce_sum(tf.pow(self._cleaned1-self._labels1,2),2)+tf.reduce_sum(tf.pow(self._cleaned2-self._labels2,2),2)

        cost2 = tf.reduce_sum(tf.pow(self._cleaned2-self._labels1,2),2)+tf.reduce_sum(tf.pow(self._cleaned1-self._labels2,2),2)    

    idx = tf.slice(cost1, [0, 0], [1, -1]) > tf.slice(cost2, [0, 0], [1, -1])
    idx = tf.cast(idx, tf.float32)
    idx = tf.reduce_mean(idx,reduction_indices=0)
        idx = tf.reshape(idx, [tf.shape(idx)[0], 1])    
    x1 = self._cleaned1[0,:,:] * (1-idx) + self._cleaned2[0,:, :]*idx

    x2 = self._cleaned1[0,:,:]*idx + self._cleaned2[0,:,:]*(1-idx)
    row = tf.shape(x1)[0]
    col = tf.shape(x1)[1]
    x1 = tf.reshape(x1, [1, row, col])
    x2 = tf.reshape(x2, [1, row, col])
    return x1, x2
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号