RestNet.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:ResNet-deeplabV3 作者: Harvey1973 项目源码 文件源码
def output_layer (input_layer, num_labels):
    '''
    param input_layer : flattend 2D tensor
    param num_lables: number of classes
    return the output of FC layer : Y =Wx+b
    '''
    input_dim = input_layer.get_shape().as_list()[-1]
    fc_w = create_variables(name = 'fc_weight',shape = [input_dim,num_labels],is_fc_layer = True,initializer=tf.uniform_unit_scaling_initializer(factor=1.0))
    fc_b = create_variables(name = 'fc_bias',shape = [num_labels],is_fc_layer = False,initializer = tf.zeros_initializer())
    output = tf.matmul(input_layer,fc_w) + fc_b
    return output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号