network.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:GC-Net 作者: Jiankai-Sun 项目源码 文件源码
def deconv_3d(x, c):
  """Apply a 3-D transposed convolution (deconvolution) followed by a bias add.

  Each of the three spatial axes is upsampled by `c['stride']` using 'SAME'
  padding, so the output spatial sizes are exactly input_size * stride.

  Args:
    x: 5-D input tensor. Presumably laid out as (batch, depth, height, width,
      channels), the default NDHWC layout of `tf.nn.conv3d_transpose`; TODO
      confirm against callers. The channel dimension must be statically known
      because it sizes the weight variable.
    c: config dict. Reads 'ksize' (edge length of the cubic kernel), 'stride'
      (upsampling factor for the three spatial axes) and 'conv_filters_out'
      (number of output channels).

  Returns:
    Tensor of shape (batch, depth*stride, height*stride, width*stride,
    conv_filters_out).
  """
  ksize = c['ksize']
  stride = c['stride']
  filters_out = c['conv_filters_out']
  # as_list() yields plain Python ints (or None for unknown dims) rather than
  # TF Dimension objects.
  x_shape = x.get_shape().as_list()
  filters_in = x_shape[-1]

  # Build the output shape from the real input shape instead of a hard-coded
  # batch size of 1. Statically known dims stay Python ints (so the common
  # fully-static case still passes a plain list); unknown dims (e.g. a None
  # batch) fall back to the runtime shape.
  dynamic_shape = tf.shape(x)
  scales = [1, stride, stride, stride]
  out_dims = []
  for axis, scale in enumerate(scales):
    static_dim = x_shape[axis]
    if static_dim is not None:
      out_dims.append(static_dim * scale)
    else:
      out_dims.append(dynamic_shape[axis] * scale)
  out_dims.append(filters_out)
  if any(isinstance(dim, tf.Tensor) for dim in out_dims):
    output_shape = tf.stack(out_dims)
  else:
    output_shape = out_dims

  strides = [1, stride, stride, stride, 1]
  # conv3d_transpose filters are [depth, height, width, out_channels, in_channels].
  shape = [ksize, ksize, ksize, filters_out, filters_in]
  initializer = tf.contrib.layers.xavier_initializer()
  weights = _get_variable('weights',
                          shape=shape,
                          dtype='float32',
                          initializer=initializer,
                          weight_decay=CONV_WEIGHT_DECAY)
  # Bias starts at a small positive constant and is not weight-decayed.
  bias = tf.get_variable('bias', [filters_out], 'float32', tf.constant_initializer(0.05, dtype='float32'))
  x = tf.nn.conv3d_transpose(x, weights, output_shape=output_shape, strides=strides, padding='SAME')
  return tf.nn.bias_add(x, bias)

# wrapper for batch-norm op
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号