distributed_training_test.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:tensorport-template 作者: tensorport 项目源码 文件源码
def _map(self, example_serialized):
        def _parse(line):
            input = np.float32(line)
            # a simple equation of the input to generate the output
            output = input + 10 + input * 2
            # generate 2 inputs and 1 output
            return input, np.float32(input * 3), np.float32(output)

        input_1, input_2, output = tf.py_func(func=_parse,
                                              inp=[example_serialized],
                                              Tout=[tf.float32, tf.float32, tf.float32],
                                              stateful=True)
        # set shapes for data
        input_1 = tf.reshape(input_1, [1])
        input_2 = tf.reshape(input_2, [1])
        output = tf.reshape(output, [1])
        # we could perform this operation here or in the graph
        input = tf.concat([input_1, input_2], axis=0)
        return input, output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号