nn.py source code

python

Project: snake  Author: rhinech
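import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder is not in the 2.x top-level namespace)


class NN(object):
    # Excerpt: inference(), loss() and training() are defined elsewhere in nn.py.

    def __init__(self,
                 name,
                 size_input__layer,
                 size_hidden_layer,
                 size_output_layer,
                 l2_coeff,
                 keep_prob,
                 optimizer='SGD'):
        """Create the placeholders and connect the inference, loss and training ops."""

        # Layer sizes and dropout keep probability
        self.size_input__layer = size_input__layer
        self.size_hidden_layer = size_hidden_layer
        self.size_output_layer = size_output_layer
        self.keep_prob = keep_prob

        # Values fed at run time: a batch of inputs, the target outputs, and the learning rate
        self.input__placeholder = tf.placeholder(tf.float32, shape=(None, size_input__layer))
        self.answer_placeholder = tf.placeholder(tf.float32, shape=(None, size_output_layer))
        self.learning_rate = tf.placeholder(tf.float32)

        # Forward pass (which also returns an L2 term), the loss, and the optimizer step
        self.inference_proc, l2_proc = self.inference(self.input__placeholder, name)
        self.loss_proc = NN.loss(self.inference_proc, l2_proc, l2_coeff, self.answer_placeholder)
        self.training_proc = NN.training(self.loss_proc, self.learning_rate, optimizer)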
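The constructor hands `l2_proc` and `l2_coeff` to `NN.loss`, whose body is not part of this excerpt. A minimal NumPy sketch of what such an L2-regularized loss commonly computes follows. It assumes a mean-squared-error data term and an L2 term computed like `tf.nn.l2_loss` (half the sum of squared weights); neither detail is confirmed by the source, and `l2_penalty` and `regularized_mse_loss` are hypothetical names, not functions from the snake project.

```python
import numpy as np


def l2_penalty(weights):
    """Half the sum of squared entries over all weight matrices (like tf.nn.l2_loss per tensor)."""
    return sum(0.5 * np.sum(w ** 2) for w in weights)


def regularized_mse_loss(predictions, answers, l2_term, l2_coeff):
    """Hypothetical stand-in for NN.loss: mean squared error plus l2_coeff times the L2 term."""
    data_loss = np.mean((predictions - answers) ** 2)
    return data_loss + l2_coeff * l2_term


if __name__ == "__main__":
    # Toy weights for a 2-2-1 network (hypothetical values)
    weights = [np.array([[1.0, -1.0], [0.5, 0.0]]), np.array([[2.0], [0.0]])]
    predictions = np.array([[1.0], [0.0]])  # batch of 2 examples, one output each
    answers = np.array([[0.0], [0.0]])
    l2 = l2_penalty(weights)
    print(l2)  # 3.125
    print(regularized_mse_loss(predictions, answers, l2, l2_coeff=0.1))  # approximately 0.8125
```

Scaling the penalty by a single coefficient keeps regularization strength one tunable hyperparameter, which matches how `l2_coeff` is passed straight through the constructor.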