def understand_return_sequence(n_samples=100, timesteps=78, features=256):
    """Show how ``return_sequences`` changes a recurrent layer's output shape.

    Builds two single-layer ``GRU`` regression models on random data: one
    with ``return_sequences=True`` and one with ``return_sequences=False``.
    It fits each briefly, runs both on the same random input batch and
    prints the resulting output shapes.

    The training targets mirror the output shapes each model must produce:

    * ``return_sequences=True`` is fit against
      ``(n_samples, timesteps, features)`` targets (one vector per step).
    * ``return_sequences=False`` is fit against ``(n_samples, features)``
      targets (one vector per sequence).

    Args:
        n_samples: number of random sequences per batch (default 100).
        timesteps: length of each random sequence (default 78).
        features: input feature size, also used as the GRU output size
            (default 256).

    Returns:
        None. The shapes are printed to stdout.
    """

    def _fit_gru(return_sequences, target_shape):
        # Build, compile and fit one GRU model. The random inputs are drawn
        # before the targets, matching the original statement order.
        model = Sequential()
        model.add(GRU(input_dim=features, output_dim=features,
                      return_sequences=return_sequences))
        model.compile(loss='mean_squared_error', optimizer='sgd')
        train_x = np.random.randn(n_samples, timesteps, features)
        train_y = np.random.randn(*target_shape)
        model.fit(train_x, train_y, verbose=0)
        return model

    model_1 = _fit_gru(True, (n_samples, timesteps, features))
    model_2 = _fit_gru(False, (n_samples, features))

    inz = np.random.randn(n_samples, timesteps, features)
    # Use `predict`, not `predict_proba`. These are MSE regression models,
    # so their outputs are not probabilities. `predict_proba` merely wraps
    # `predict` (warning when values fall outside [0, 1]) and is not
    # available on newer Keras / tf.keras Sequential models.
    rez_1 = model_1.predict(inz, verbose=0)
    rez_2 = model_2.predict(inz, verbose=0)

    print()
    print('=========== understand return_sequence =================')
    print('Input shape is: {}'.format(inz.shape))
    print('Output shape of model with `return_sequences=True`: {}'.format(rez_1.shape))
    print('Output shape of model with `return_sequences=False`: {}'.format(rez_2.shape))
    print('====================== end =============================')
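# 评论列表 ("comment list") -- leftover page text from the source article
# 文章目录 ("table of contents") -- leftover page text from the source article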