testing.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:LSTM-Generative-and-Discriminative 作者: mattweidman 项目源码 文件源码
def test_generator_training():
    """Run a small end-to-end adversarial training experiment and print results.

    Steps, in order:
      1. Load the character vocabulary and the real data tensor from
         ``animals.txt`` via ``dataloader``.
      2. Draw Gaussian noise and build a ``Generator``, then produce one
         generated tensor from that noise.
      3. Build a ``Discriminator``, train it with ``train_RMS`` on
         real vs. generated data, and print its accuracy.
      4. Train the generator against that discriminator with ``train_RMS``.
      5. Regenerate from the same noise and print the discriminator's
         accuracy again, so the two printed numbers can be compared.

    Nothing is asserted; results are only printed.
    """
    # Experiment configuration.
    data_path = "animals.txt"
    generator_hidden = 10
    discriminator_hidden = 3
    disc_epochs, gen_epochs = 20, 20
    learning_rate, decay = 1, 0.9  # forwarded as the RMS lr / alpha arguments
    batch = 100

    # Real data and its dimensions (examples on axis 0, steps on axis 1).
    vocab = dataloader.get_char_list(data_path)
    real = dataloader.load_data(data_path)
    n_examples, n_steps = real.shape[0], real.shape[1]
    vocab_size = len(vocab)

    # Noise must be drawn before the Generator is constructed to keep the
    # same global-RNG consumption order as before.
    noise = np.random.randn(n_examples, vocab_size)
    generator = Generator(generator_hidden, vocab)
    fake = generator.generate_tensor(n_steps, n_examples, noise)

    # Fit the discriminator on real vs. generated samples and report.
    discriminator = Discriminator(vocab_size, discriminator_hidden)
    discriminator.train_RMS(real, fake, disc_epochs, learning_rate, decay, batch)
    print("accuracy: ", discriminator.accuracy(real, fake))

    # Fit the generator against the discriminator. The positional ``1`` is
    # passed through unchanged; its meaning is defined by
    # Generator.train_RMS (not visible here) — TODO confirm and name it.
    generator.train_RMS(noise, n_steps, discriminator, gen_epochs, 1,
                        learning_rate, decay, batch, print_progress=True)

    # Re-sample from the same noise and re-score with the discriminator.
    fake = generator.generate_tensor(n_steps, n_examples, noise)
    print("accuracy: ", discriminator.accuracy(real, fake))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号