def discriminatorResNet(x, hidden_num, output_dim, kern_size, in_channels, reuse,
                        num_res_blocks=5):
    """Build the residual discriminator network inside the "D" variable scope.

    A kernel-size-1 convolution maps ``x`` to ``hidden_num`` channels. A stack
    of ``num_res_blocks`` residual blocks follows. The result is flattened and
    projected to a single output by a fully connected layer with no activation.

    Args:
        x: Input tensor. The original code notes data_format 'NWC', so this is
            presumably (batch, width, channels) -- TODO confirm with callers.
        hidden_num: Number of output channels of the stem convolution; also
            forwarded to every ``resBlock`` call.
        output_dim: Used only in the flatten step: each example is reshaped to
            ``output_dim * hidden_num`` features. Presumably the spatial width
            of the feature map, assuming ``resBlock`` keeps ``hidden_num``
            channels -- TODO confirm.
        kern_size: Kernel size forwarded to each ``resBlock`` call.
        in_channels: Unused by this function; kept so existing call sites
            continue to work.
        reuse: If truthy, ``reuse_variables()`` is called on the "D" scope so
            the existing variables are shared instead of newly created.
        num_res_blocks: Number of stacked ``resBlock`` calls. Defaults to 5,
            the depth this function previously hard-coded.

    Returns:
        A tuple ``(disc_out, d_vars)``. ``disc_out`` is the output of a 1-unit
        fully connected layer with ``activation_fn=None``. ``d_vars`` holds the
        variables found under the "D" scope via
        ``tf.contrib.framework.get_variables``.
    """
    with tf.variable_scope("D") as vs:
        if reuse:
            vs.reuse_variables()
        # Kernel-size-1 convolution: lifts the input to hidden_num channels.
        features = tcl.conv2d(x, hidden_num, kernel_size=1)
        # Same call sequence as the former hand-unrolled res1..res5, so the
        # auto-generated layer/variable names are unchanged for the default depth.
        for _ in range(num_res_blocks):
            features = resBlock(features, hidden_num, kern_size)
        # data_format: 'NWC' -- flatten width x channels per example.
        # NOTE(review): the -1 batch axis absorbs any mismatch. If the feature
        # map's width is not output_dim, this silently changes the batch size
        # instead of raising an error.
        flat = tf.reshape(features, [-1, output_dim * hidden_num])
        disc_out = tcl.fully_connected(flat, 1, activation_fn=None)
    # Collect every variable created under this scope (e.g. for a D-only optimizer).
    d_vars = tf.contrib.framework.get_variables(vs)
    return disc_out, d_vars
评论列表
文章目录