data.py source code

python

Project: tfnn  Author: MorvanZhou
import numpy as np  # the method below refers to numpy as np


def __init__(self, xs, ys, name=None):
        """
        Store an input data set.
        :param xs: features, shape (n_samples, n_xs); numpy array, pandas object, or list
        :param ys: labels, shape (n_samples, n_ys); numpy array, pandas object, or list
        """
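        # xs and ys must be of the same kind; mixed inputs fall through to the TypeError.
        if type(xs).__module__ == np.__name__ and type(ys).__module__ == np.__name__:
            # Note: self.module is only set on this branch.
            self.module = 'numpy_data'
        elif 'pandas' in type(xs).__module__ and 'pandas' in type(ys).__module__:
            # Convert pandas objects to numpy arrays.
            xs, ys = np.asarray(xs), np.asarray(ys)
        elif isinstance(xs, list) and isinstance(ys, list):
            # Convert plain lists to numpy arrays.
            xs, ys = np.asarray(xs), np.asarray(ys)
        else:
            raise TypeError('xs and ys must both be numpy arrays, pandas objects, or lists')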
        if ys.ndim < 2:
            ys = ys[:, np.newaxis]
        if xs.ndim < 2:
            xs = xs[:, np.newaxis]

        self.n_xfeatures = xs.shape[-1]     # columns for 2-D input, channels for 3-D input
        self.n_yfeatures = ys.shape[-1]     # columns for 2-D input
        self.data = np.hstack((xs, ys))
        self.n_samples = ys.shape[0]
        self.name = name
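
The shape handling above can be tried outside the class. The following is a minimal sketch, not part of tfnn: `prepare` is a hypothetical helper name, and it assumes list input only. It promotes 1-D inputs to column vectors and stacks features and labels side by side, as the constructor does.

```python
import numpy as np


def prepare(xs, ys):
    """Hypothetical standalone helper mirroring the shape handling in __init__."""
    xs, ys = np.asarray(xs), np.asarray(ys)
    # Promote 1-D inputs to column vectors, giving shape (n_samples, 1).
    if xs.ndim < 2:
        xs = xs[:, np.newaxis]
    if ys.ndim < 2:
        ys = ys[:, np.newaxis]
    # Stack features and labels side by side, one row per sample.
    data = np.hstack((xs, ys))
    return data, xs.shape[-1], ys.shape[-1]


# Three samples with two features each, and a flat list of labels.
data, n_x, n_y = prepare([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [0, 1, 0])
print(data.shape, n_x, n_y)
```

Here the flat label list becomes a single column, so each row of `data` holds two feature values followed by one label.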