train_catastrophe_model_human.py source code

python
Project: human-rl  Author: gsastry
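import tensorflow as tf  # argument order below matches pre-1.0 TensorFlow (tf.contrib era)


def model(self, features, labels):
    # Two strided 3x3 convolutions (2 filters each, ELU) over the observation.
    x = features["observation"]
    x = tf.contrib.layers.convolution2d(x, 2, kernel_size=[3, 3], stride=[2, 2], activation_fn=tf.nn.elu)
    x = tf.contrib.layers.convolution2d(x, 2, kernel_size=[3, 3], stride=[2, 2], activation_fn=tf.nn.elu)
    # One-hot encode the action (depth 6) and append it to the flattened conv features.
    actions = tf.one_hot(tf.reshape(features["action"], [-1]), depth=6, on_value=1.0, off_value=0.0, axis=1)
    # Pre-1.0 signature: tf.concat(concat_dim, values).
    x = tf.concat(1, [tf.contrib.layers.flatten(x), actions])
    # Two hidden layers of 100 ELU units, then a single linear output (the logit).
    x = tf.contrib.layers.fully_connected(x, 100, activation_fn=tf.nn.elu)
    x = tf.contrib.layers.fully_connected(x, 100, activation_fn=tf.nn.elu)
    logits = tf.contrib.layers.fully_connected(x, 1, activation_fn=None)
    # The sigmoid maps the logit to a probability in (0, 1).
    prediction = tf.sigmoid(logits, name="prediction")
    # Mean sigmoid cross-entropy; pre-1.0 positional order is (logits, targets).
    loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits, tf.expand_dims(labels, axis=1)),
        name="loss")
    # Adam step that also increments the global step.
    train_op = tf.contrib.layers.optimize_loss(
        loss, tf.contrib.framework.get_global_step(), optimizer='Adam',
        learning_rate=self.learning_rate)
    # Register the tensors in graph collections so they can be looked up by name later.
    tf.add_to_collection('prediction', prediction)
    tf.add_to_collection('loss', loss)
    return prediction, loss, train_op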
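
The two less obvious steps in the function above, the one-hot action encoding and the sigmoid cross-entropy loss, can be sketched in plain Python without TensorFlow. This is only an illustration under stated assumptions: the helper names (`one_hot`, `sigmoid_cross_entropy`, `mean_loss`) and the sample logits and labels are hypothetical and not part of the human-rl project. The loss uses the numerically stable form max(x, 0) - x*z + log(1 + exp(-|x|)), which is mathematically equivalent to the textbook cross-entropy on sigmoid(x).

```python
import math


def one_hot(index, depth):
    """Return a list of length `depth` with 1.0 at `index` and 0.0 elsewhere."""
    return [1.0 if i == index else 0.0 for i in range(depth)]


def sigmoid(x):
    """Logistic function, written to avoid overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def sigmoid_cross_entropy(logit, label):
    """Stable cross-entropy for one logit x and one label z in {0, 1}."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def mean_loss(logits, labels):
    """Average the per-example losses, mirroring the reduce_mean step."""
    losses = [sigmoid_cross_entropy(x, z) for x, z in zip(logits, labels)]
    return sum(losses) / len(losses)


# Hypothetical inputs: action index 2 out of 6, and three made-up (logit, label) pairs.
action_vector = one_hot(2, depth=6)
example_loss = mean_loss([0.0, 3.0, -3.0], [1.0, 1.0, 0.0])
```

A logit of 0 corresponds to a predicted probability of 0.5, so its loss is log(2) whatever the label; confident and correct logits (3.0 with label 1, -3.0 with label 0) contribute much smaller terms.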