nn_test.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:imperative 作者: yaroslavvb 项目源码 文件源码
def _Inputs(self, x=None, y=None, q=3.0, dtype=tf.float64, sizes=None):
    """Build logits/targets constants plus the reference losses for them.

    Args:
      x: Logit values. When None, a fixed 8-element list is used.
      y: Target values, paired element-wise with ``x``. When None, a fixed
        8-element list is used.
      q: Weight forwarded to ``self._WeightedCrossEntropy`` and returned
        unchanged (callers treat it as the positive-class weight).
      dtype: dtype for both ``tf.constant`` tensors.
      sizes: Shape for both tensors and for the reference losses. A falsy
        value (None or an empty sequence) means the flat shape
        ``[len(x)]``.

    Returns:
      A ``(logits, targets, q, losses)`` tuple. ``losses`` is a numpy array
      built from ``self._WeightedCrossEntropy(x, y, q)`` and reshaped to the
      same shape as the tensors.
    """
    if x is None:
        x = [-100, -2, -2, 0, 2, 2, 2, 100]
    if y is None:
        y = [0, 0, 1, 0, 0, 1, 0.5, 1]
    # Inputs are consumed pairwise, so both sequences must be the same length.
    assert len(x) == len(y)
    shape = sizes or [len(x)]

    # Create the logits constant first, then the targets constant (same order
    # as before, so graph op creation order is unchanged).
    logits, targets = (
        tf.constant(values, shape=shape, dtype=dtype, name=label)
        for values, label in ((x, "logits"), (y, "targets"))
    )

    reference = self._WeightedCrossEntropy(x, y, q)
    losses = np.array(reference).reshape(*shape)
    return logits, targets, q, losses

  # def testConstructionNamed(self):
  #   with self.test_session():
  #     logits, targets, pos_weight, _ = self._Inputs()
  #     loss = tf.nn.weighted_cross_entropy_with_logits(logits, targets,
  #                                                     pos_weight, name="mybce")
  #   self.assertEqual("mybce", loss.op.name)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号