testing.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

Project: LSTM-Generative-and-Discriminative — Author: mattweidman (project source · file source)
def test_discriminator():
    """Train a Discriminator on real vs. generated sequences and report results.

    Steps:
    1. Load the character list and the real data from ``animals.txt`` via
       ``dataloader``.
    2. Build a ``Generator`` (not trained here) and use it to produce a
       generated tensor. The tensor uses the same example count (axis 0 of
       the real data) and sequence length (axis 1).
    3. Train a ``Discriminator`` on real vs. generated data with ``train_RMS``.
    4. Print the discriminator's output on the real and generated data
       concatenated along axis 0 (real first), then print its accuracy.

    Takes no arguments and returns None. Results are reported only via
    ``print``.
    """
    # Hyperparameters for this run.
    data_file = "animals.txt"
    generator_hidden = 10
    discriminator_hidden = 11
    epochs = 20
    learning_rate = 1
    rms_alpha = 0.9  # passed straight to train_RMS; presumably the RMS decay factor — confirm
    minibatch = 100

    # Real data: the vocabulary plus the loaded sequences.
    vocab = dataloader.get_char_list(data_file)
    real = dataloader.load_data(data_file)
    n_samples, length = real.shape[0], real.shape[1]

    # Generated data: same count and length as the real set.
    fake = Generator(generator_hidden, vocab).generate_tensor(length, n_samples)

    # Fit the discriminator on real vs. generated examples.
    disc = Discriminator(len(vocab), discriminator_hidden)
    disc.train_RMS(real, fake, epochs, learning_rate, rms_alpha, minibatch,
                   print_progress=True)

    # Show raw discriminator output on every example, real rows first.
    combined = np.concatenate((real, fake), axis=0)
    print(disc.discriminate(combined))

    # Show overall accuracy on the real/generated split.
    print("accuracy: ", disc.accuracy(real, fake))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号