def __init__(self, features_shape, num_classes, cell_type='lstm', seq_len=8, reuse=False, add_bn=False,
             add_reg=False, deeper=False, loss_type='negative_l1_dist', cum_sum=False, scope='RCNN'):
    """Set up the config, input placeholders and layer options of the recurrent CNN.

    Args:
        features_shape: per-step feature shape; stored on ``self.config.features_shape``
            and appended to ``(None, None)`` to form the input placeholder shape.
            Any sequence (tuple or list) is accepted.
        num_classes: stored on ``self.config.num_classes``.
        cell_type: ``'rnn'``, ``'gru'`` or ``'lstm'``. Selects the cell *class*
            stored in ``self.cell``; the class is not instantiated here.
        seq_len: stored on ``self.config.seq_len``.
        reuse: stored as ``self.reuse``.
        add_bn: if true, ``self.norm_fn`` is ``tf.contrib.layers.batch_norm``;
            otherwise ``None``.
        add_reg: if true, ``self.reg_fn`` is ``tf.nn.l2_loss``; otherwise ``None``.
        deeper: stored as ``self.deeper``.
        loss_type: stored as ``self.loss_type``.
        cum_sum: stored as ``self.cumsum`` (note the different attribute name).
        scope: stored as ``self.scope``.

    Raises:
        ValueError: if ``cell_type`` is not one of the supported names.
    """
    cell_classes = {
        # BasicRNNCell is the concrete vanilla RNN cell; tf.contrib.rnn.RNNCell is the
        # abstract base class (no state_size/output_size/call) and cannot run.
        'rnn': tf.contrib.rnn.BasicRNNCell,
        'gru': tf.contrib.rnn.GRUCell,
        'lstm': tf.contrib.rnn.LSTMCell,
    }
    # Validate before creating any placeholders so a bad argument leaves the graph untouched.
    if cell_type not in cell_classes:
        raise ValueError('Input correct cell type')
    self.cell = cell_classes[cell_type]

    self.config = RecurrentCNNConfig()
    self.config.features_shape = features_shape
    self.config.num_classes = num_classes
    self.reuse = reuse

    # Placeholders are created in the original order so auto-generated op names are unchanged.
    # The leading None dimensions are presumably (batch, time) — TODO confirm against feed code.
    self.inputs_placeholder = tf.placeholder(
        tf.float32, shape=(None, None) + tuple(self.config.features_shape))
    # init_loc_size and targets_shape come from RecurrentCNNConfig defaults (not visible here).
    self.init_loc = tf.placeholder(
        tf.float32, shape=(None,) + tuple(self.config.init_loc_size))
    self.targets_placeholder = tf.placeholder(
        tf.float32, shape=(None, None) + tuple(self.config.targets_shape))
    self.config.seq_len = seq_len
    # Presumably one int32 sequence length per example — verify against callers.
    self.seq_len_placeholder = tf.placeholder(tf.int32, shape=(None,))

    self.deeper = deeper
    self.loss_type = loss_type
    self.cumsum = cum_sum
    self.scope = scope

    # Optional layer hooks: None disables the corresponding feature.
    self.norm_fn = tf.contrib.layers.batch_norm if add_bn else None
    self.reg_fn = tf.nn.l2_loss if add_reg else None
评论列表
文章目录