def test_discriminator():
    """Smoke-test the Discriminator on real vs. generated character sequences.

    Loads real sequences from ``animals.txt`` and builds a Generator. No
    generator training happens in this function. It asks the Generator for as
    many sequences as there are real ones, with the same sequence length.
    A Discriminator is then trained with ``train_RMS`` on the real and
    generated sets. Finally, its raw output on the combined data is printed,
    followed by its accuracy.

    Nothing is asserted; results are only printed for manual inspection.
    """
    corpus = "animals.txt"

    # Model sizes and training hyper-parameters.
    generator_units = 10
    discriminator_units = 11
    epochs = 20
    # NOTE(review): 1 looks large for an RMS-style optimiser -- confirm intended.
    learning_rate = 1
    rms_alpha = 0.9
    minibatch = 100

    # Vocabulary and real data both come from the same corpus file.
    vocab = dataloader.get_char_list(corpus)
    real = dataloader.load_data(corpus)
    # Axis 0 is treated as the example count, axis 1 as the sequence length.
    n_sequences, length = real.shape[0], real.shape[1]

    # Fake data: one generated sequence per real one, of equal length.
    gen = Generator(generator_units, vocab)
    fake = gen.generate_tensor(length, n_sequences)

    critic = Discriminator(len(vocab), discriminator_units)
    critic.train_RMS(real, fake, epochs, learning_rate, rms_alpha, minibatch,
                     print_progress=True)

    # Score real rows first, then generated rows, in a single call.
    combined = np.concatenate([real, fake], axis=0)
    print(critic.discriminate(combined))

    print("accuracy: ", critic.accuracy(real, fake))
testing.py 文件源码
python
阅读 17
收藏 0
点赞 0
评论 0
评论列表
文章目录