def fit(self,
        data_x_train,
        data_x_dev=None,
        data_x_test=None,
        n_epochs=10,
        batch_size=10):
    """Train with mini-batches and print train / validation / test costs.

    Each epoch walks ``data_x_train`` in order, in consecutive slices of
    ``batch_size`` rows, calls ``self.partial_fit`` on every slice, and
    prints the mean of the returned costs. Rows left over after the last
    full batch (``data_x_train.shape[0] % batch_size`` of them) are not
    used for training.

    Args:
        data_x_train: Training samples indexed along the first axis (must
            support ``.shape[0]`` and slicing).
        data_x_dev: Optional validation samples. After every epoch,
            ``batch_size`` rows are drawn uniformly at random *with
            replacement*, and ``self.get_cost`` on them is printed.
        data_x_test: Optional test samples. ``self.get_cost`` on the whole
            set is printed once, after the final epoch.
        n_epochs: Number of passes over the training data; must be > 0.
        batch_size: Rows per mini-batch; must satisfy
            ``1 <= batch_size <= data_x_train.shape[0]``.

    Raises:
        ValueError: If ``n_epochs`` or ``batch_size`` is out of range.
    """
    size_x_train = data_x_train.shape[0]

    # Explicit checks rather than ``assert`` so validation survives ``python -O``.
    if n_epochs <= 0:
        raise ValueError('n_epochs must be positive, got {!r}'.format(n_epochs))
    # A batch as large as the whole training set is valid (one batch per
    # epoch); zero/negative sizes or sizes that can never be filled are not.
    if not 0 < batch_size <= size_x_train:
        raise ValueError('batch_size must be in [1, {:d}], got {!r}'.format(
            size_x_train, batch_size))

    # Floor division: plain "/" yields a float on Python 3, which both
    # np.zeros and range reject. The trailing partial batch is dropped.
    n_batches = size_x_train // batch_size

    for epoch in range(n_epochs):
        epoch_costs = np.zeros(n_batches)
        bar = tqdm(range(n_batches), desc='Epoch: {:d}'.format(epoch))
        for i in bar:
            start = i * batch_size
            batch_x = data_x_train[start:start + batch_size]
            # partial_fit presumably runs one optimisation step and returns
            # the batch cost as a scalar -- TODO confirm against the model.
            epoch_costs[i] = self.partial_fit(batch_x)
        print('Train error: {:.4f}'.format(epoch_costs.mean()))

        if data_x_dev is not None:
            # Only one random batch (sampled with replacement) is scored,
            # not the full dev set, so this number is a noisy estimate.
            random_indices = np.random.randint(0, data_x_dev.shape[0], batch_size)
            err = self.get_cost(data_x_dev[random_indices])
            print('Validation data error: {:.4f}'.format(err))

    if data_x_test is not None:
        err = self.get_cost(data_x_test)
        print('Test data error: {:.4f}'.format(err))
评论列表
文章目录