RecurrentCNN.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:cs234_final_project 作者: nipunagarwala 项目源码 文件源码
def __init__(self, features_shape, num_classes, cell_type='lstm', seq_len=8, reuse=False, add_bn=False,
                add_reg=False, deeper=False, loss_type='negative_l1_dist', cum_sum=False, scope='RCNN'):
        """Set up the config, input placeholders and layer options for the model.

        Args:
            features_shape: per-step feature shape. It is appended after two
                leading dimensions of unknown size to form the input
                placeholder shape (presumably batch and time -- confirm).
                Any sequence is accepted; it is converted with ``tuple()``
                when building the placeholder shape.
            num_classes: stored on the config as ``num_classes``.
            cell_type: ``'rnn'``, ``'gru'`` or ``'lstm'``. Selects the
                recurrent cell *class* stored in ``self.cell``.
            seq_len: stored on the config as ``seq_len``.
            reuse: stored as ``self.reuse`` (presumably variable-scope
                reuse -- confirm against the graph-building code).
            add_bn: if true, ``self.norm_fn`` is
                ``tf.contrib.layers.batch_norm``; otherwise ``None``.
            add_reg: if true, ``self.reg_fn`` is ``tf.nn.l2_loss``;
                otherwise ``None``.
            deeper: stored as ``self.deeper``.
            loss_type: stored as ``self.loss_type``.
            cum_sum: stored as ``self.cumsum``.
            scope: stored as ``self.scope``.

        Raises:
            ValueError: if ``cell_type`` is not one of the supported names.
        """
        # Resolve the cell class first so an invalid cell_type fails before
        # any placeholders are added to the default graph.
        if cell_type == 'rnn':
            # RNNCell is the abstract base class and cannot be used as a
            # layer; BasicRNNCell is the concrete vanilla RNN cell.
            cell_class = tf.contrib.rnn.BasicRNNCell
        elif cell_type == 'gru':
            cell_class = tf.contrib.rnn.GRUCell
        elif cell_type == 'lstm':
            cell_class = tf.contrib.rnn.LSTMCell
        else:
            raise ValueError('Input correct cell type')

        self.config = RecurrentCNNConfig()
        self.config.features_shape = features_shape
        self.config.num_classes = num_classes
        self.config.seq_len = seq_len
        self.reuse = reuse

        # Placeholders are created in the same order as before so their
        # auto-generated graph names are unchanged. tuple() lets shape
        # settings given as lists concatenate with the leading None dims.
        self.inputs_placeholder = tf.placeholder(
            tf.float32, shape=(None, None) + tuple(self.config.features_shape))
        self.init_loc = tf.placeholder(
            tf.float32, shape=(None,) + tuple(self.config.init_loc_size))
        self.targets_placeholder = tf.placeholder(
            tf.float32, shape=(None, None) + tuple(self.config.targets_shape))
        # One int32 entry per example along the leading dimension.
        self.seq_len_placeholder = tf.placeholder(tf.int32, shape=(None,))

        self.deeper = deeper
        self.loss_type = loss_type
        self.cumsum = cum_sum
        self.scope = scope

        self.norm_fn = tf.contrib.layers.batch_norm if add_bn else None
        self.reg_fn = tf.nn.l2_loss if add_reg else None

        # A class, not an instance; presumably instantiated later with a
        # hidden size -- confirm in the graph-building code.
        self.cell = cell_class
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号