VisualAttention.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:cs234_final_project 作者: nipunagarwala 项目源码 文件源码
def __init__(self, features_shape, num_classes, seq_len, cell_type='lstm', reuse=False, add_bn=False,
             add_reg=False, scope="VA"):
    """Set up configuration, input placeholders and layer options for the model.

    Args:
        features_shape: Shape of the per-step input features. It is appended
            after two leading ``None`` dimensions to form the shape of
            ``inputs_placeholder`` (presumably batch and time -- TODO confirm).
            Any sequence is accepted; it is converted with ``tuple``.
        num_classes: Stored on ``self.config.num_classes``.
        seq_len: Stored on ``self.config.seq_len``.
        cell_type: One of ``'rnn'``, ``'gru'`` or ``'lstm'``. Selects the
            recurrent cell class stored on ``self.cell``.
        reuse: Stored on ``self.reuse``.
        add_bn: If true, ``self.norm_fn`` is
            ``tf.contrib.layers.batch_norm``; otherwise ``None``.
        add_reg: If true, ``self.reg_fn`` is ``tf.nn.l2_loss``; otherwise
            ``None``.
        scope: Stored on ``self.scope``.

    Raises:
        ValueError: If ``cell_type`` is not one of the supported names.
    """
    # Validate the cell type first, so an invalid argument fails before any
    # placeholder ops are added to the default graph.
    if cell_type == 'rnn':
        # tf.contrib.rnn.RNNCell is the abstract base class (it does not
        # implement __call__ / state_size); BasicRNNCell is the concrete
        # vanilla RNN cell.
        self.cell = tf.contrib.rnn.BasicRNNCell
    elif cell_type == 'gru':
        self.cell = tf.contrib.rnn.GRUCell
    elif cell_type == 'lstm':
        self.cell = tf.contrib.rnn.LSTMCell
    else:
        raise ValueError('Input correct cell type')

    self.config = VisualAttentionConfig()
    self.config.features_shape = features_shape
    self.config.num_classes = num_classes
    self.config.seq_len = seq_len
    self.reuse = reuse
    self.scope = scope

    # Placeholders. init_loc_size and targets_shape come from
    # VisualAttentionConfig (presumably shape tuples -- TODO confirm).
    self.inputs_placeholder = tf.placeholder(
        tf.float32, shape=(None, None) + tuple(self.config.features_shape))
    self.init_loc = tf.placeholder(
        tf.float32, shape=(None,) + tuple(self.config.init_loc_size))
    self.targets_placeholder = tf.placeholder(
        tf.float32, shape=(None, None) + tuple(self.config.targets_shape))
    self.seq_len_placeholder = tf.placeholder(tf.int32, shape=(None,))

    self.emission_num_layers = 1

    self.loss_type = 'negative_l1_dist'
    self.cumsum = False

    # Optional normalization / regularization hooks; None disables them.
    self.norm_fn = tf.contrib.layers.batch_norm if add_bn else None
    self.reg_fn = tf.nn.l2_loss if add_reg else None
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号