Networks.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:MLPractices 作者: carefree0910 项目源码 文件源码
def predict(self, x):
        x = NNDist._transfer_x(np.asarray(x))
        rs = []
        batch_size = floor(1e6 / np.prod(x.shape[1:]))
        epoch = int(ceil(len(x) / batch_size))
        output = self._sess.graph.get_tensor_by_name(self._output)
        bar = ProgressBar(max_value=epoch, name="Predict")
        bar.start()
        for i in range(epoch):
            if i == epoch - 1:
                rs.append(self._sess.run(output, {
                    self._entry: x[i * batch_size:]
                }))
            else:
                rs.append(self._sess.run(output, {
                    self._entry: x[i * batch_size:(i + 1) * batch_size]
                }))
            bar.update()
        return np.vstack(rs).astype(np.float32)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号