test_cross_entropy_graph.py source code

Language: Python

Project: dnn-quant    Author: euclidjda
import tensorflow as tf

# BATCH_SIZE, INPUT_SIZE, OUTPUT_SIZE, UNROLLS, and INIT_SCALE are module-level
# constants defined elsewhere in the original test file. The code targets the
# pre-1.0 TensorFlow API (tf.concat(dim, values), tf.unpack, tf.mul, ...).

def create_graph(g):
    initer = tf.random_uniform_initializer(0.0, INIT_SCALE)

    with tf.variable_scope("graph", reuse=None, initializer=initer):
        # Per-step placeholders: inputs x, targets y, per-example weights s.
        g['x'] = list()
        g['y'] = list()
        g['s'] = list()
        g['seq_lengths'] = tf.placeholder(tf.int64, shape=[BATCH_SIZE])

        for _ in range(UNROLLS):
            g['x'].append(tf.placeholder(tf.float32, shape=[BATCH_SIZE, INPUT_SIZE]))
            g['y'].append(tf.placeholder(tf.float32, shape=[BATCH_SIZE, INPUT_SIZE]))
            g['s'].append(tf.placeholder(tf.float32, shape=[BATCH_SIZE]))

        num_inputs  = INPUT_SIZE * UNROLLS
        # num_outputs = OUTPUT_SIZE * UNROLLS

        # One softmax layer applied to all UNROLLS steps concatenated together.
        g['w'] = tf.get_variable("softmax_w", [num_inputs, OUTPUT_SIZE])
        g['b'] = tf.get_variable("softmax_b", [OUTPUT_SIZE])

        # Concatenate the unrolled inputs along the feature axis:
        # shape [BATCH_SIZE, INPUT_SIZE * UNROLLS].
        g['cat_x'] = tf.concat(1, g['x'])

        g['logits'] = tf.nn.xw_plus_b(g['cat_x'], g['w'], g['b'])

        # Select each sequence's last valid target: reshape the concatenated
        # targets to [BATCH_SIZE, UNROLLS, OUTPUT_SIZE] (which requires
        # OUTPUT_SIZE == INPUT_SIZE here), reverse each row over the time
        # axis by its seq_length, then take slice 0 of the time axis.
        g['cat_y'] = tf.unpack(tf.reverse_sequence(tf.reshape(tf.concat(1, g['y']),
            [BATCH_SIZE, UNROLLS, OUTPUT_SIZE]), g['seq_lengths'], 1, 0), axis=1)[0]

        # Pre-1.0 signature: softmax_cross_entropy_with_logits(logits, labels).
        g['loss'] = tf.nn.softmax_cross_entropy_with_logits(g['logits'], g['cat_y'])

        # The same last-valid-step selection, applied to the weights s.
        g['r_s'] = tf.unpack(tf.reverse_sequence(tf.transpose(
            tf.reshape(tf.concat(0, g['s']), [UNROLLS, BATCH_SIZE])),
            g['seq_lengths'], 1, 0), axis=1)[0]

        # Per-example cross-entropy weighted by the selected s values.
        g['train_loss'] = tf.mul(g['loss'], g['r_s'])

        g['preds'] = tf.nn.softmax(g['logits'])

        # Threshold softmax outputs at 0.5 to get hard 0/1 class predictions.
        g['class_preds'] = tf.floor(g['preds'] + 0.5)

        # Element-wise agreement with the (one-hot) targets, unweighted and
        # weighted by r_s.
        g['accy'] = tf.mul(g['class_preds'], g['cat_y'])

        g['w_accy'] = tf.mul(g['accy'], tf.reshape(g['r_s'], shape=[BATCH_SIZE, 1]))
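
A minimal smoke-test sketch of how this graph could be fed and run. It is not part of the original file: the constant values are hypothetical stand-ins for the real module-level definitions, it assumes a pre-1.0 TensorFlow runtime matching the API calls above, and it uses one-hot targets since softmax cross-entropy expects distributions. Note the reshape of cat_y forces OUTPUT_SIZE == INPUT_SIZE.

import numpy as np
import tensorflow as tf

# Hypothetical constant values, for illustration only; the real ones are
# defined elsewhere in the dnn-quant test file.
BATCH_SIZE, INPUT_SIZE, OUTPUT_SIZE, UNROLLS = 4, 3, 3, 5
INIT_SCALE = 0.1

g = dict()
create_graph(g)

with tf.Session() as sess:
    # tf.initialize_all_variables() on TF releases older than 0.12
    sess.run(tf.global_variables_initializer())

    # Every example gets the full UNROLLS steps here; a shorter length would
    # make reverse_sequence select an earlier step as the "last" one.
    feed = {g['seq_lengths']: np.full([BATCH_SIZE], UNROLLS, dtype=np.int64)}
    for t in range(UNROLLS):
        feed[g['x'][t]] = np.random.rand(BATCH_SIZE, INPUT_SIZE).astype(np.float32)
        # One-hot target rows, drawn from random class labels.
        labels = np.random.randint(OUTPUT_SIZE, size=BATCH_SIZE)
        feed[g['y'][t]] = np.eye(OUTPUT_SIZE, dtype=np.float32)[labels]
        feed[g['s'][t]] = np.ones(BATCH_SIZE, dtype=np.float32)

    train_loss, w_accy = sess.run([g['train_loss'], g['w_accy']], feed_dict=feed)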
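
For readers on newer TensorFlow: the deprecated calls above have direct renamed equivalents since TF 1.0 (and under TF 2.x the placeholder/get_variable/Session pieces additionally live behind tf.compat.v1). A sketch of the renames, not a tested port of the whole graph:

g['cat_x'] = tf.concat(g['x'], axis=1)               # was tf.concat(1, g['x'])
g['loss']  = tf.nn.softmax_cross_entropy_with_logits(
    labels=g['cat_y'], logits=g['logits'])           # keyword arguments required since 1.0
g['train_loss'] = tf.multiply(g['loss'], g['r_s'])   # tf.mul was renamed tf.multiply
# tf.unpack is now tf.unstack, and tf.reverse_sequence takes the keyword
# arguments seq_axis= and batch_axis= in place of positional seq_dim/batch_dim.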