def __init__(self, dim_input=1, dim_output=1, test_num_updates=5):
    """Configure the model from the constructor arguments and FLAGS.

    Note: construct_model() must be called after initializing MAML.

    Args:
        dim_input: stored as ``self.dim_input``; for the image data
            sources it is also used to derive ``self.img_size``.
        dim_output: stored as ``self.dim_output``.
        test_num_updates: stored as ``self.test_num_updates``.

    Raises:
        ValueError: if ``FLAGS.datasource`` is not one of 'sinusoid',
            'omniglot' or 'miniimagenet'.
    """
    self.dim_input = dim_input
    self.dim_output = dim_output
    self.update_lr = FLAGS.update_lr
    # Placeholder that falls back to FLAGS.meta_lr when nothing is fed.
    self.meta_lr = tf.placeholder_with_default(FLAGS.meta_lr, ())
    self.classification = False
    self.test_num_updates = test_num_updates

    datasource = FLAGS.datasource

    if datasource == 'sinusoid':
        # Fully connected network with an MSE loss.
        self.dim_hidden = [40, 40]
        self.loss_func = mse
        self.forward = self.forward_fc
        self.construct_weights = self.construct_fc_weights
        return

    if datasource not in ('omniglot', 'miniimagenet'):
        raise ValueError('Unrecognized data source.')

    # Image data sources: cross-entropy loss, classification mode.
    self.loss_func = xent
    self.classification = True

    if FLAGS.conv:
        self.dim_hidden = FLAGS.num_filters
        self.forward = self.forward_conv
        self.construct_weights = self.construct_conv_weights
    else:
        self.dim_hidden = [256, 128, 64, 64]
        self.forward = self.forward_fc
        self.construct_weights = self.construct_fc_weights

    # 3 channels for miniimagenet, 1 for omniglot.
    self.channels = 3 if datasource == 'miniimagenet' else 1
    # Side length computed as sqrt(dim_input / channels), truncated to int.
    self.img_size = int(np.sqrt(self.dim_input / self.channels))
# 评论列表 (comment list)
# 文章目录 (table of contents)