def test_generator_training():
    """Manually exercise training of a Generator against a Discriminator.

    Steps, in order:
      1. Load the character vocabulary and the real data tensor from
         ``animals.txt``.
      2. Draw one Gaussian noise matrix and generate a fake tensor from it.
      3. Train a fresh Discriminator on the real vs. generated tensors and
         print its accuracy.
      4. Train the generator against that discriminator, regenerate from the
         same noise matrix, and print the discriminator's accuracy again.

    Nothing is asserted; the two printed accuracies are meant to be compared
    by eye. No random seed is set in this function, so the printed numbers
    may differ between runs unless a seed is fixed elsewhere.
    """
    # Hyperparameters for this run.
    data_file = "animals.txt"
    gen_hidden = 10
    disc_hidden = 3
    disc_epochs = 20
    gen_epochs = 20
    learning_rate = 1
    rms_alpha = 0.9
    batch = 100

    # Real data: the vocabulary, plus a tensor whose first two dimensions are
    # used below as (number of examples, sequence length).
    vocab = dataloader.get_char_list(data_file)
    real = dataloader.load_data(data_file)
    n_examples, length = real.shape[0], real.shape[1]

    def report(discriminator, fake_data):
        # Score real vs. generated data with `discriminator` and print it.
        print("accuracy: ", discriminator.accuracy(real, fake_data))

    # One noise matrix (n_examples rows, one column per vocabulary entry) is
    # drawn once and reused for both generation passes below.
    noise = np.random.randn(n_examples, len(vocab))
    gen = Generator(gen_hidden, vocab)
    fake = gen.generate_tensor(length, n_examples, noise)

    # Phase 1: fit a fresh discriminator to real vs. generated data.
    disc = Discriminator(len(vocab), disc_hidden)
    disc.train_RMS(real, fake, disc_epochs, learning_rate, rms_alpha, batch)
    report(disc, fake)

    # Phase 2: train the generator against the discriminator from phase 1.
    # NOTE(review): the meaning of the positional ``1`` is not visible here;
    # confirm against Generator.train_RMS.
    gen.train_RMS(noise, length, disc, gen_epochs, 1, learning_rate, rms_alpha,
                  batch, print_progress=True)

    # Regenerate from the same noise and re-score with the same discriminator
    # instance, so the two printed accuracies are directly comparable.
    fake = gen.generate_tensor(length, n_examples, noise)
    report(disc, fake)
testing.py 文件源码
python
阅读 17
收藏 0
点赞 0
评论 0
评论列表
文章目录