understand.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:soph 作者: Linusp 项目源码 文件源码
def try_variable_length_train():
    """????????

    ?????????? train_x ? train_y ? dtype ? object ???
    ?? shape ???? (100,) ?????????
    """
    model = Sequential()
    model.add(GRU(input_dim=256, output_dim=256, return_sequences=True))
    model.compile(loss='mean_squared_error', optimizer='sgd')

    train_x = []
    train_y = []
    for i in range(100):
        seq_length = np.random.randint(78, 87 + 1)
        sequence = []
        for _ in range(seq_length):
            sequence.append([np.random.randn() for _ in range(256)])

        train_x.append(np.array(sequence))
        train_y.append(np.array(sequence))

    train_x = np.array(train_x)
    train_y = np.array(train_y)

    model.fit(np.array(train_x), np.array(train_y))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号