hier_lstm.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:aes 作者: feidong1991 项目源码 文件源码
def build_attention_model(opts, vocab_size=0, maxnum=50, maxlen=50, embedd_dim=50, embedding_weights=None, verbose=False, init_mean_value=None):
    """Build and compile a hierarchical LSTM regressor with attention pooling.

    The input is a flat sequence of ``maxnum * maxlen`` int32 word ids.
    The model proceeds in four steps:

    1. Embed the ids, apply dropout, and reshape to
       ``(maxnum, maxlen, embedd_dim)``.
    2. Run a word-level LSTM over each of the ``maxnum`` rows and average its
       outputs per row.
    3. Run a second LSTM over the resulting row vectors and pool its outputs
       with the ``Attention`` layer (defined elsewhere; its exact pooling
       behaviour is not visible here).
    4. Map the pooled vector to a single sigmoid unit.

    Args:
        opts: Options object; this function reads ``lstm_units``, ``dropout``,
            ``l2_value`` and ``init_bias``. ``l2_value`` is only logged here,
            it is not applied to any layer.
        vocab_size: Embedding vocabulary size (``input_dim`` of ``Embedding``).
        maxnum: Number of rows per sample (logged as ``max_sentnum``).
        maxlen: Number of word ids per row (logged as ``max_sentlen``).
        embedd_dim: Embedding dimension.
        embedding_weights: Passed unchanged as ``weights=`` to ``Embedding``.
            Keras presumably expects a list such as ``[matrix]``; confirm
            with the caller.
        verbose: If true, print ``model.summary()``.
        init_mean_value: Optional target mean ``m`` in (0, 1). When truthy
            and ``opts.init_bias`` is set, the output bias is initialised to
            ``log(m) - log(1 - m)``, so that ``sigmoid(bias) == m``.

    Returns:
        The compiled model (MSE loss, RMSprop optimizer).
    """
    N = maxnum
    L = maxlen

    logger = get_logger('Build attention pooling model')
    # Lazy %-args: the logging module formats the message only if it is emitted.
    logger.info("Model parameters: max_sentnum = %d, max_sentlen = %d, embedding dim = %s, lstm_units = %s, drop rate = %s, l2 = %s",
                N, L, embedd_dim, opts.lstm_units, opts.dropout, opts.l2_value)
    word_input = Input(shape=(N*L,), dtype='int32', name='word_input')
    x = Embedding(output_dim=embedd_dim, input_dim=vocab_size, input_length=N*L, weights=embedding_weights, name='x')(word_input)
    drop_x = Dropout(opts.dropout, name='drop_x')(x)

    # Restore the (rows, words-per-row) structure flattened in the input.
    resh_W = Reshape((N, L, embedd_dim), name='resh_W')(drop_x)

    # Word-level LSTM applied independently to each row, then mean over words.
    z = TimeDistributed(LSTM(opts.lstm_units, return_sequences=True), name='z')(resh_W)
    avg_z = TimeDistributed(GlobalAveragePooling1D(), name='avg_z')(z)

    # Row-level LSTM over the per-row vectors, pooled by the attention layer.
    hz = LSTM(opts.lstm_units, return_sequences=True, name='hz')(avg_z)
    attent_hz = Attention(name='attent_hz')(hz)
    y = Dense(output_dim=1, activation='sigmoid', name='output')(attent_hz)

    model = Model(input=word_input, output=y)
    if opts.init_bias and init_mean_value:
        logger.info("Initialise output layer bias with log(y_mean/1-y_mean)")
        bias_value = (np.log(init_mean_value) - np.log(1 - init_mean_value)).astype(K.floatx())
        output_bias = model.layers[-1].b
        # Why: K.set_value works on every backend, whereas a raw .set_value()
        # exists only on Theano shared variables. Broadcasting to the
        # variable's own shape lets a scalar or a shape-(1,) mean both fit
        # the Dense(1) bias.
        K.set_value(output_bias,
                    np.full(K.get_value(output_bias).shape, bias_value, dtype=K.floatx()))
    if verbose:
        model.summary()

    start_time = time.time()
    model.compile(loss='mse', optimizer='rmsprop')
    total_time = time.time() - start_time
    logger.info("Model compiled in %.4f s", total_time)

    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号