def __init__(self, features_shape, num_classes, seq_len, cell_type='lstm', reuse=False, add_bn=False,
             add_reg=False, scope="VA"):
    """Build the model configuration, input placeholders and layer options.

    Args:
        features_shape: Shape of one step's features, as a tuple. It is
            appended to two leading ``None`` dimensions to form the shape of
            ``inputs_placeholder`` (presumably batch and time -- confirm
            against callers). Also stored on ``self.config``.
        num_classes: Stored as ``self.config.num_classes``.
        seq_len: Stored as ``self.config.seq_len``.
        cell_type: Recurrent cell family, one of ``'rnn'``, ``'gru'`` or
            ``'lstm'``. The matching cell *class* (not an instance) is stored
            in ``self.cell``.
        reuse: Stored as ``self.reuse``.
        add_bn: If true, ``self.norm_fn`` is ``tf.contrib.layers.batch_norm``;
            otherwise it is ``None``.
        add_reg: If true, ``self.reg_fn`` is ``tf.nn.l2_loss``; otherwise it
            is ``None``.
        scope: Stored as ``self.scope``.

    Raises:
        ValueError: If ``cell_type`` is not one of the supported names.
    """
    # Resolve the cell class first, so an invalid ``cell_type`` fails before
    # any placeholder ops are added to the default graph.
    if cell_type == 'rnn':
        # tf.contrib.rnn.RNNCell is the abstract base class and cannot act as
        # a concrete cell; BasicRNNCell is the plain recurrent cell.
        self.cell = tf.contrib.rnn.BasicRNNCell
    elif cell_type == 'gru':
        self.cell = tf.contrib.rnn.GRUCell
    elif cell_type == 'lstm':
        self.cell = tf.contrib.rnn.LSTMCell
    else:
        raise ValueError(
            "Input correct cell type: expected 'rnn', 'gru' or 'lstm', got %r"
            % (cell_type,))

    # init_loc_size and targets_shape are not set here; they are read from
    # VisualAttentionConfig's own attributes below.
    self.config = VisualAttentionConfig()
    self.config.features_shape = features_shape
    self.config.num_classes = num_classes
    self.config.seq_len = seq_len

    # Placeholder creation order is kept as-is so auto-generated op names
    # (Placeholder, Placeholder_1, ...) do not change.
    self.inputs_placeholder = tf.placeholder(
        tf.float32, shape=tuple((None, None) + self.config.features_shape))
    self.init_loc = tf.placeholder(
        tf.float32, shape=tuple((None,) + self.config.init_loc_size))
    self.targets_placeholder = tf.placeholder(
        tf.float32, shape=tuple((None, None) + self.config.targets_shape))
    # One int32 value per leading-dimension entry; presumably the actual
    # length of each sequence -- confirm against the feed code.
    self.seq_len_placeholder = tf.placeholder(tf.int32, shape=(None,))

    self.reuse = reuse
    self.scope = scope

    # Fixed settings that are not exposed as constructor arguments.
    self.emission_num_layers = 1
    self.loss_type = 'negative_l1_dist'
    self.cumsum = False

    # Optional normalisation / regulariser functions. None presumably means
    # "not applied" -- verify in the code that consumes them.
    self.norm_fn = tf.contrib.layers.batch_norm if add_bn else None
    self.reg_fn = tf.nn.l2_loss if add_reg else None
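# Residue from the web page this code was copied from:
# "VisualAttention.py source code" (python) -- 28 reads, 0 favorites, 0 likes, 0 comments.
# Trailing page sections: "Comment list", "Table of contents".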