model_class.py 文件源码

python
阅读 37 收藏 0 点赞 0 评论 0

项目:sketch_rnn_classification 作者: payalbajaj 项目源码 文件源码
def __init__(self, hps, gpu_mode=True, reuse=False):
    """Initializer for the SketchRNN model.

    Args:
       hps: a HParams object containing model hyperparameters
       gpu_mode: a boolean that when True, uses GPU mode.
       reuse: a boolean that when true, attemps to reuse variables.
    """
    self.hps = hps
    with tf.variable_scope('vector_rnn', reuse=reuse):
      if not gpu_mode:
        with tf.device('/cpu:0'):
          tf.logging.info('Model using cpu.')
          self.build_model(hps)
      else:
        tf.logging.info('Model using gpu.')
        self.build_model(hps)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号