model.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:crnn_tf 作者: liuhu-bigeye 项目源码 文件源码
def __init__(self,
               lr,
               vocabulary_size: int = 1295,
               mparams=None) -> None:
    """Create the model's input placeholders and hyper-parameters, build the
    training graph, and create a Saver over all global variables.

    Args:
      lr: Learning-rate value stored on ``self.lr``. Presumably this is the
        value fed into ``self.learning_rate`` at run time -- TODO confirm
        against the training loop.
      vocabulary_size: Forwarded to ``ModelParams`` as its
        ``vocabulary_size`` (default 1295).
      mparams: Accepted but NOT used anywhere in this method.
        NOTE(review): callers passing custom params here will have them
        silently ignored; confirm whether it was meant to replace the
        ``ModelParams(...)`` construction below.
    """

    super(Model, self).__init__()
    # Only vocabulary_size is forwarded; `mparams` is ignored (see docstring).
    self._params = ModelParams(vocabulary_size=vocabulary_size)
    self.lr = lr
    # Scalar float32 placeholder (shape []); no explicit name, so TF
    # auto-names it -- keep creation order stable to keep graph names stable.
    self.learning_rate = tf.placeholder(tf.float32, shape=[])

    # Rank-3 float32 input with a variable leading (batch) dimension.
    # Presumably (batch, height=32, width=100) grayscale images -- TODO confirm.
    self.images = tf.placeholder(dtype=tf.float32, shape=[None, 32, 100], name='images')
    # One int32 per example; presumably per-example sequence lengths for the
    # sequence loss -- verify against _build_training.
    self.seqs_length = tf.placeholder(dtype=tf.int32, shape=[None], name='seqs_length')
    # Sparse int32 targets; presumably label sequences (e.g. for a CTC loss)
    # -- confirm against _build_training.
    self.targets = tf.sparse_placeholder(tf.int32, name='targets')

    # Per-layer hyper-parameter lists, presumably read by _build_training.
    # Names suggest: ks = kernel sizes, ps = paddings, nm = number of feature
    # maps (7 conv layers), nh = hidden units of 2 recurrent layers.
    # NOTE(review): meanings inferred from names only -- confirm.
    self.ks = [3, 3, 3, 3, 3, 3, 2]
    self.ps = [1, 1, 1, 1, 1, 1, 0]
    self.nm = [64, 128, 256, 256, 512, 512, 512]
    self.nh = [256, 256]

    # Scalar bool placeholder (unnamed); presumably switches train/inference
    # behaviour inside the graph -- TODO confirm.
    self.is_training = tf.placeholder(dtype=tf.bool, shape=[])
    # Must run after all attributes above are set, and before the Saver is
    # created, so the Saver's variable list includes the variables it builds.
    self._build_training()

    # Saver over the global variables that exist at this point (the list is
    # taken when tf.global_variables() is called), placed on the CPU.
    with tf.device('/cpu:0'):
      self.saver = tf.train.Saver(tf.global_variables())
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号