def define_model(x,
                 keep_prob,
                 number_of_classes,
                 number_of_filters,
                 number_of_fc_features):
    """Build a weight-shared multi-branch network with max pooling across branches.

    The input is split into slices along axis 4, and each slice is fed through
    ``single_branch`` inside the ``'branches'`` variable scope. After the first
    branch is built, the scope is switched to reuse mode, so later branches
    reuse the variables created by the first one. The branch outputs are
    stacked on a new axis 2 and reduced with an element-wise max over that
    axis. The result goes through dropout and then a final ``fc`` layer in the
    ``'fc2'`` scope.

    Args:
        x: Input tensor of rank >= 5. It is unstacked along axis 4, so that
            axis must have a statically known size (a ``tf.unstack``
            requirement). Presumably each slice is one transformed copy of
            the same input. TODO confirm against the data pipeline.
        keep_prob: Keep probability passed to ``tf.nn.dropout`` (scalar
            float or placeholder).
        number_of_classes: Output width of the final ``fc`` layer.
        number_of_filters: Forwarded unchanged to ``single_branch``.
        number_of_fc_features: Forwarded to ``single_branch``. Also used as
            the input width of the final ``fc`` layer, so it presumably must
            equal the feature width each branch returns. Verify.

    Returns:
        The tensor returned by ``fc`` in the ``'fc2'`` scope, presumably
        unnormalised class scores of width ``number_of_classes``.
    """
    # tf.pack/tf.unpack were renamed tf.stack/tf.unstack, and the old names
    # were removed in TensorFlow 1.0. Prefer the new names and fall back to
    # the old ones so the function works on both old and new releases.
    unstack = getattr(tf, 'unstack', None) or tf.unpack
    stack = getattr(tf, 'stack', None) or tf.pack

    branch_inputs = unstack(x, axis=4)
    branches = []
    with tf.variable_scope('branches') as scope:
        for index, tensor_slice in enumerate(branch_inputs):
            branches.append(single_branch(tensor_slice,
                                          number_of_filters,
                                          number_of_fc_features))
            if index == 0:
                # The first call creates the branch variables. Switching to
                # reuse mode makes later calls fetch those same variables
                # (assuming single_branch uses tf.get_variable -- confirm),
                # so every branch shares one set of weights.
                scope.reuse_variables()
        concatenated = stack(branches, axis=2)
        # Element-wise max over the stacked branch axis. The axis is passed
        # positionally because the keyword is `reduction_indices` on old TF
        # and `axis` on newer releases.
        ti_pooled = tf.reduce_max(concatenated, [2])
        drop = tf.nn.dropout(ti_pooled, keep_prob)
    with tf.variable_scope('fc2'):
        # Presumably [weight shape] and [bias shape] for the helper `fc`.
        logits = fc(drop,
                    [number_of_fc_features, number_of_classes],
                    [number_of_classes])
    return logits
评论列表
文章目录