understand.py source code

python

Project: soph  Author: Linusp
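import numpy as np
from keras.models import Sequential
from keras.layers import GRU


def understand_return_sequence():
    """Understand the `return_sequences` option of a recurrent layer."""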
    model_1 = Sequential()
    model_1.add(GRU(input_dim=256, output_dim=256, return_sequences=True))
    model_1.compile(loss='mean_squared_error', optimizer='sgd')
    train_x = np.random.randn(100, 78, 256)
    train_y = np.random.randn(100, 78, 256)
    model_1.fit(train_x, train_y, verbose=0)

    model_2 = Sequential()
    model_2.add(GRU(input_dim=256, output_dim=256, return_sequences=False))
    model_2.compile(loss='mean_squared_error', optimizer='sgd')
    train_x = np.random.randn(100, 78, 256)
    train_y = np.random.randn(100, 256)
    model_2.fit(train_x, train_y, verbose=0)

    inz = np.random.randn(100, 78, 256)
    # These models regress with MSE, so use plain predict() rather than predict_proba().
    rez_1 = model_1.predict(inz, verbose=0)
    rez_2 = model_2.predict(inz, verbose=0)

    print()
    print('=========== understand return_sequence =================')
    print('Input shape is: {}'.format(inz.shape))
    print('Output shape of model with `return_sequences=True`: {}'.format(rez_1.shape))
    print('Output shape of model with `return_sequences=False`: {}'.format(rez_2.shape))
    print('====================== end =============================')
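
The shape difference that the function above prints can be reproduced without Keras. The following is a minimal sketch, not the Keras GRU implementation: `toy_rnn` is a hypothetical helper that runs a plain tanh recurrence with random, untrained weights over a single sample, and its sizes (5 timesteps, 4 input features, 3 units) are illustrative assumptions. It only shows what the `return_sequences` flag selects: every per-timestep state, or just the last one.

```python
import math
import random


def toy_rnn(inputs, units, return_sequences, seed=0):
    """Run a toy tanh RNN over one sample shaped (timesteps, input_dim)."""
    rng = random.Random(seed)
    input_dim = len(inputs[0])
    # Hypothetical random weights; a real layer would learn these.
    w_in = [[rng.uniform(-0.1, 0.1) for _ in range(units)] for _ in range(input_dim)]
    w_rec = [[rng.uniform(-0.1, 0.1) for _ in range(units)] for _ in range(units)]

    state = [0.0] * units
    outputs = []
    for x_t in inputs:
        # New state depends on the current input and the previous state.
        state = [
            math.tanh(
                sum(x_t[i] * w_in[i][j] for i in range(input_dim))
                + sum(state[k] * w_rec[k][j] for k in range(units))
            )
            for j in range(units)
        ]
        outputs.append(state)
    # True: one vector per timestep; False: only the final vector.
    return outputs if return_sequences else outputs[-1]


rng = random.Random(1)
# One sample with timesteps=5 and input_dim=4.
sample = [[rng.gauss(0, 1) for _ in range(4)] for _ in range(5)]
full = toy_rnn(sample, units=3, return_sequences=True)
last = toy_rnn(sample, units=3, return_sequences=False)
print(len(full), len(full[0]))  # 5 3
print(len(last))                # 3
```

With the batch dimension added back, this corresponds to the Keras models above producing output of shape (100, 78, 256) with `return_sequences=True` and (100, 256) with `return_sequences=False`. That is also why the two models are trained against targets of different shapes.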