model.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:keras_npi 作者: mokemokechicken 项目源码 文件源码
def build(self):
        """Assemble the network, store it on ``self.model`` and compile it.

        The graph is wired from nested ``Sequential`` models:

        * ``f_enc`` concatenates the environment-observation input with the
          argument input and feeds the result through a ``MaxoutDense`` layer.
          It is also kept on ``self.f_enc``.
        * ``program_embedding`` wraps the program ``Embedding`` layer.
        * ``f_lstm`` concatenates the (time-expanded) encoder output with the
          program embedding and runs two stateful LSTM layers, each followed
          by a ReLU activation.
        * Output heads built on top of ``f_lstm``: ``f_end`` (one sigmoid
          unit), ``f_prog`` (softmax over ``PROGRAM_VEC_SIZE``) and one
          ``f_arg<i>`` softmax head of width ``IntegerArguments.depth`` for
          each ``i`` in ``1..IntegerArguments.max_arg_num``.

        After the ``Model`` is created, ``self.compile_model()`` is called and
        the graph is rendered to ``model.png`` via ``plot``.

        NOTE: layers are instantiated in a fixed order on purpose; layers
        without an explicit ``name`` get auto-generated names that depend on
        creation order.
        """
        observation_dim = self.size_of_env_observation()
        args_dim = IntegerArguments.size_of_arguments
        batch = self.batch_size

        # --- raw inputs -----------------------------------------------------
        env_input = InputLayer(batch_input_shape=(batch, observation_dim), name='input_enc')
        arg_input = InputLayer(batch_input_shape=(batch, args_dim), name='input_arg')
        prog_input = Embedding(
            input_dim=PROGRAM_VEC_SIZE,
            output_dim=PROGRAM_KEY_VEC_SIZE,
            input_length=1,
            batch_input_shape=(batch, 1),
        )

        # --- encoder over observation + arguments ---------------------------
        encoder = Sequential(name='f_enc')
        encoder.add(Merge([env_input, arg_input], mode='concat'))
        encoder.add(MaxoutDense(128, nb_feature=4))
        self.f_enc = encoder

        # --- program key embedding -------------------------------------------
        prog_embed = Sequential(name='program_embedding')
        prog_embed.add(prog_input)

        # Give the encoder output a length-1 time axis so it can be
        # concatenated with the embedding output before the LSTM.
        encoder_as_seq = Sequential(name='f_enc_convert')
        encoder_as_seq.add(encoder)
        encoder_as_seq.add(RepeatVector(1))

        # --- recurrent core: two stateful LSTM layers ------------------------
        core = Sequential(name='f_lstm')
        core.add(Merge([encoder_as_seq, prog_embed], mode='concat'))
        core.add(LSTM(256, return_sequences=False, stateful=True, W_regularizer=l2(0.0000001)))
        core.add(Activation('relu', name='relu_lstm_1'))
        # The first LSTM returns a single vector; restore the time axis for
        # the second LSTM.
        core.add(RepeatVector(1))
        core.add(LSTM(256, return_sequences=False, stateful=True, W_regularizer=l2(0.0000001)))
        core.add(Activation('relu', name='relu_lstm_2'))

        # --- termination head -------------------------------------------------
        end_head = Sequential(name='f_end')
        end_head.add(core)
        end_head.add(Dense(1, W_regularizer=l2(0.001)))
        end_head.add(Activation('sigmoid', name='sigmoid_end'))

        # --- next-program head ------------------------------------------------
        prog_head = Sequential(name='f_prog')
        prog_head.add(core)
        prog_head.add(Dense(PROGRAM_KEY_VEC_SIZE, activation="relu"))
        prog_head.add(Dense(PROGRAM_VEC_SIZE, W_regularizer=l2(0.0001)))
        prog_head.add(Activation('softmax', name='softmax_prog'))

        # --- one softmax head per argument slot (1-based names) --------------
        arg_heads = []
        for slot in range(1, IntegerArguments.max_arg_num + 1):
            head = Sequential(name='f_arg%s' % slot)
            head.add(core)
            head.add(Dense(IntegerArguments.depth, W_regularizer=l2(0.0001)))
            head.add(Activation('softmax', name='softmax_arg%s' % slot))
            arg_heads.append(head)

        # --- tie everything together -----------------------------------------
        model_inputs = [env_input.input, arg_input.input, prog_input.input]
        model_outputs = [end_head.output, prog_head.output]
        model_outputs.extend(head.output for head in arg_heads)

        self.model = Model(model_inputs, model_outputs, name="npi")
        self.compile_model()
        plot(self.model, to_file='model.png', show_shapes=True)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号