Source code of model.py (Python)

Reads: 20 · Bookmarks: 0 · Likes: 0 · Comments: 0

Project: FindYourCandy · Author: BrainPad · Project source / file source
def __init__(self, features_size, num_classes, for_predict=False, hidden_size=3):
        """Build a small transfer-learning classifier graph.

        One ReLU hidden layer with dropout, followed by a linear softmax
        output layer. Loss, global step and summaries are added only when
        the graph is built for training.

        Args:
            features_size: int, dimensionality of the input feature vectors.
            num_classes: int, number of output classes.
            for_predict: bool, when True skip loss/optimizer/summary ops.
            hidden_size: int, width of the hidden layer.
        """
        self.hidden_size = hidden_size
        self.num_classes = num_classes

        with tf.variable_scope('transfer'):
            self.features = tf.placeholder(tf.float32, (None, features_size), name='features')
            self.label_ids = tf.placeholder(tf.int32, (None,), name='label_ids')

            # ones_initializer changed from a constant to a factory function
            # across TF releases; support both forms.
            try:
                ones_initializer = init_ops.ones_initializer()
            except TypeError:
                ones_initializer = init_ops.ones_initializer

            hidden = tf.contrib.layers.fully_connected(
                self.features,
                hidden_size,
                activation_fn=tf.nn.relu,
                weights_initializer=tf.contrib.layers.xavier_initializer(),
                biases_initializer=ones_initializer,
                trainable=True
            )

            # Dropout keep probability is fed at run time (use 1.0 at inference).
            self.keep_prob = tf.placeholder(tf.float32)
            hidden_drop = tf.nn.dropout(hidden, self.keep_prob)

            # BUG FIX: fully_connected defaults to activation_fn=tf.nn.relu,
            # which clamped all negative logits to zero before the softmax.
            # The output layer must be linear, so pass activation_fn=None.
            logits = tf.contrib.layers.fully_connected(
                hidden_drop,
                num_classes,
                activation_fn=None,
                weights_initializer=tf.contrib.layers.xavier_initializer(),
                biases_initializer=ones_initializer,
                trainable=True
            )

        if not for_predict:
            # Add loss operation only when initializing for training.
            one_hot = tf.one_hot(self.label_ids, num_classes, name='target')
            # Use keyword arguments: the positional parameter order of
            # softmax_cross_entropy_with_logits changed between TF releases
            # (TF >= 1.5 requires keywords); this form works on both.
            self.loss_op = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=one_hot)
            )

        self.softmax_op = tf.nn.softmax(logits)
        self.saver = tf.train.Saver()

        if not for_predict:
            # Add train bookkeeping and summary ops for training graphs.
            with tf.variable_scope('optimizer'):
                self.global_step = tf.Variable(0, name='global_step', trainable=False)
            with tf.variable_scope('summaries'):
                # NOTE: pre-1.0 summary API (tf.scalar_summary /
                # tf.merge_all_summaries), kept consistent with the file's
                # TF version as evidenced by the initializer workaround above.
                tf.scalar_summary('in sample loss', self.loss_op)
                self.summary_op = tf.merge_all_summaries()
Comment list

Table of contents


Questions


Interview experiences


Articles

WeChat
Official account

Scan the QR code to follow the official account