understand.py source file

python

Project: soph  Author: Linusp
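import numpy as np
from keras.models import Sequential
from keras.layers import GRU


def understand_variable_length_handle():
    """Show how a recurrent layer handles variable-length input sequences."""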
    model = Sequential()
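    # Only the feature size (256) is fixed; the number of time steps is left open.
    # `input_dim`/`output_dim` are pre-2.0 Keras argument names (`output_dim` later became `units`).
    model.add(GRU(input_dim=256, output_dim=256, return_sequences=True))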
    model.compile(loss='mean_squared_error', optimizer='sgd')
    # Fit on random data: 100 sequences of 78 time steps with 256 features each.
    train_x = np.random.randn(100, 78, 256)
    train_y = np.random.randn(100, 78, 256)
    model.fit(train_x, train_y, verbose=0)

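    # Input with the same sequence length as the training data (78 steps).
    inz_1 = np.random.randn(1, 78, 256)
    rez_1 = model.predict_proba(inz_1, verbose=0)

    # Input with a different sequence length (87 steps), fed to the same model.
    inz_2 = np.random.randn(1, 87, 256)
    rez_2 = model.predict_proba(inz_2, verbose=0)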

    print()
    print('=========== understand variable length =================')
    print('With `return_sequences=True`')
    print('Input shape is: {}, output shape is {}'.format(inz_1.shape, rez_1.shape))
    print('Input shape is: {}, output shape is {}'.format(inz_2.shape, rez_2.shape))
    print('====================== end =============================')
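
The function above relies on the fact that a recurrent layer's weights depend only on the feature and hidden sizes, not on the number of time steps. As a rough illustration of why that holds, here is a minimal NumPy sketch of a GRU forward pass. It is not the Keras implementation: biases are omitted, the weight initialisation is arbitrary, and the names `init_params`, `sigmoid`, and `gru_forward` are hypothetical helpers introduced only for this example.

```python
import numpy as np


def sigmoid(a):
    """Element-wise logistic function."""
    return 1.0 / (1.0 + np.exp(-a))


def init_params(input_dim, hidden_dim, rng):
    """Create small random weights (hypothetical init, biases omitted).

    Every matrix is sized by feature/hidden dimensions only;
    the sequence length never appears.
    """
    def w(rows, cols):
        return 0.1 * rng.standard_normal((rows, cols))

    return {
        "W_z": w(input_dim, hidden_dim), "U_z": w(hidden_dim, hidden_dim),
        "W_r": w(input_dim, hidden_dim), "U_r": w(hidden_dim, hidden_dim),
        "W_h": w(input_dim, hidden_dim), "U_h": w(hidden_dim, hidden_dim),
    }


def gru_forward(x, p):
    """Run a GRU over x of shape (batch, steps, features).

    Returns the hidden state at every step, shape (batch, steps, hidden),
    which corresponds to the `return_sequences=True` behaviour.
    """
    batch, steps, _ = x.shape
    hidden_dim = p["U_z"].shape[0]
    h = np.zeros((batch, hidden_dim))
    outputs = []
    for t in range(steps):  # the same weights are reused at every step
        x_t = x[:, t, :]
        z = sigmoid(x_t @ p["W_z"] + h @ p["U_z"])            # update gate
        r = sigmoid(x_t @ p["W_r"] + h @ p["U_r"])            # reset gate
        h_cand = np.tanh(x_t @ p["W_h"] + (r * h) @ p["U_h"])  # candidate state
        h = z * h + (1.0 - z) * h_cand
        outputs.append(h)
    return np.stack(outputs, axis=1)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    params = init_params(256, 256, rng)
    for steps in (78, 87):
        x = rng.standard_normal((1, steps, 256))
        print("Input shape is: {}, output shape is {}".format(
            x.shape, gru_forward(x, params).shape))
```

Because the loop simply runs once per time step, the same parameters handle 78-step and 87-step inputs, and the output keeps the time dimension of whatever input it was given.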