model.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:fold 作者: tensorflow 项目源码 文件源码
def __init__(self, embedding_length):
    self._calculator_loom = CalculatorLoom(embedding_length)

    self._labels_placeholder = tf.placeholder(tf.float32)
    self._classifier_weights = tf.Variable(
        tf.truncated_normal([embedding_length, 3],
                            dtype=tf.float32,
                            stddev=1),
        name='classifier_weights')

    self._output_weights = tf.matmul(
        self._calculator_loom.output(), self._classifier_weights)
    self._loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
        logits=self._output_weights, labels=self._labels_placeholder))

    self._true_labels = tf.argmax(self._labels_placeholder, dimension=1)
    self._prediction = tf.argmax(self._output_weights, dimension=1)

    self._accuracy = tf.reduce_mean(tf.cast(
        tf.equal(self._true_labels, self._prediction),
        dtype=tf.float32))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号