def linear(args, output_size, bias, bias_start=0.0, scope=None, squeeze=False, wd=0.0, input_keep_prob=1.0,
           is_train=None):
    """Apply a linear projection (via ``_linear``) to one or more input tensors.

    Each input is flattened with ``flatten(arg, 1)``, optionally passed through
    dropout, projected jointly by ``_linear``, and the result is restored with
    ``reconstruct(flat_out, args[0], 1)``.
    NOTE(review): presumably ``flatten`` collapses all but the last axis and
    ``reconstruct`` restores the leading axes of ``args[0]`` -- confirm against
    those helpers.

    Args:
        args: A tensor, or a non-empty sequence of tensors (as judged by
            ``nest.is_sequence``). A single tensor is wrapped in a list.
        output_size: Forwarded to ``_linear``.
        bias: Forwarded to ``_linear`` (presumably whether to add a bias term).
        bias_start: Forwarded to ``_linear``.
        scope: Forwarded to ``_linear``.
        squeeze: If True, squeeze axis ``rank(args[0]) - 1`` from the output.
        wd: If truthy, ``add_wd(wd)`` is called after the output is built.
        input_keep_prob: Dropout keep probability; dropout is only inserted
            when this is below 1.0.
        is_train: Required when ``input_keep_prob < 1.0``; used as the
            predicate of ``tf.cond`` to choose between dropout and identity.

    Returns:
        The tensor produced by ``reconstruct``, squeezed if requested.

    Raises:
        ValueError: If ``args`` is None or an empty sequence, or if dropout is
            requested (``input_keep_prob < 1.0``) without ``is_train``.
    """
    if args is None or (nest.is_sequence(args) and not args):
        raise ValueError("`args` must be specified")
    if not nest.is_sequence(args):
        args = [args]

    flat_args = [flatten(arg, 1) for arg in args]
    if input_keep_prob < 1.0:
        # Explicit exception rather than `assert`: asserts are stripped under
        # `python -O`, which would let `tf.cond(None, ...)` fail obscurely.
        if is_train is None:
            raise ValueError("`is_train` must be specified when input_keep_prob < 1.0")
        # `arg` is bound as a default so each branch owns its tensor instead
        # of closing over the comprehension variable.
        # `keep_prob=` is spelled out: it is TF1's second parameter, whereas in
        # TF2 the second positional parameter is `rate`, and a positional
        # value there would silently invert the dropout.
        # NOTE(review): confirm the TF version in use.
        flat_args = [
            tf.cond(is_train,
                    lambda arg=arg: tf.nn.dropout(arg, keep_prob=input_keep_prob),
                    lambda arg=arg: arg)
            for arg in flat_args
        ]
    flat_out = _linear(flat_args, output_size, bias, bias_start=bias_start, scope=scope)
    out = reconstruct(flat_out, args[0], 1)
    if squeeze:
        # Drops the axis whose index is rank(args[0]) - 1; requires args[0]
        # to have a statically known rank.
        out = tf.squeeze(out, [len(args[0].get_shape().as_list()) - 1])
    if wd:
        # NOTE(review): add_wd is defined elsewhere; it receives only `wd`,
        # not `scope` -- confirm which variables it regularizes.
        add_wd(wd)
    return out
评论列表
文章目录