def understand_variable_length_handle():
    """Show that a recurrent layer accepts sequences of different lengths.

    Builds a one-layer GRU model with ``return_sequences=True`` and trains it
    briefly on random data whose sequences have 78 timesteps. It then predicts
    on two random inputs with 78 and 87 timesteps. The input and output shapes
    of both predictions are printed, so you can see that the output's time
    axis follows the input's time axis rather than the training length.

    Relies on ``Sequential`` and ``GRU`` (presumably imported from Keras) and
    ``np`` (NumPy) being available at module level.
    """
    feature_dim = 256

    model = Sequential()
    # Leaving the time axis as None in input_shape is what lets one model
    # consume sequences of any length; only the feature axis is fixed.
    # This positional-units + input_shape form works on both Keras 1 and 2,
    # unlike the legacy input_dim/output_dim keywords.
    model.add(GRU(feature_dim, input_shape=(None, feature_dim),
                  return_sequences=True))
    model.compile(loss='mean_squared_error', optimizer='sgd')

    train_x = np.random.randn(100, 78, feature_dim)
    train_y = np.random.randn(100, 78, feature_dim)
    model.fit(train_x, train_y, verbose=0)

    # This is a regression model (MSE loss, unbounded outputs), so use
    # predict(); predict_proba() is meant for probability outputs and no
    # longer exists in newer Keras releases.
    inz_1 = np.random.randn(1, 78, feature_dim)   # same length as training
    rez_1 = model.predict(inz_1, verbose=0)
    inz_2 = np.random.randn(1, 87, feature_dim)   # longer than training
    rez_2 = model.predict(inz_2, verbose=0)

    # print('') rather than print(): under a Python 2 print statement,
    # print() would output "()" instead of a blank line.
    print('')
    print('=========== understand variable length =================')
    print('With `return_sequences=True`')
    print('Input shape is: {}, output shape is {}'.format(inz_1.shape, rez_1.shape))
    print('Input shape is: {}, output shape is {}'.format(inz_2.shape, rez_2.shape))
    print('====================== end =============================')
评论列表
文章目录