test_keras2_numeric.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:coremltools 作者: apple 项目源码 文件源码
def test_tiny_babi_rnn(self):
        """Build a tiny two-input recurrent model and hand it to the model test helper.

        One branch embeds the "story" input; the other embeds the "query"
        input, collapses it with an LSTM and repeats it ``story_maxlen``
        times.  The two branches are summed element-wise, passed through a
        second LSTM and finished with a softmax ``Dense`` layer of width
        ``vocab_size``.  The resulting model is passed to
        ``self._test_keras_model``.
        """
        # Model dimensions.
        vocab_size = 10
        embed_hidden_size = 8
        story_maxlen = 5
        query_maxlen = 5
        dropout_rate = 0.3

        # Story branch: embedding followed by dropout (kept as a sequence).
        story_input = Input(shape=(story_maxlen,))
        story_branch = Embedding(vocab_size, embed_hidden_size)(story_input)
        story_branch = Dropout(dropout_rate)(story_branch)

        # Query branch: embedding, dropout, LSTM down to a single vector,
        # then repeated so it can be added to the story sequence.
        query_input = Input(shape=(query_maxlen,))
        query_branch = Embedding(vocab_size, embed_hidden_size)(query_input)
        query_branch = Dropout(dropout_rate)(query_branch)
        query_branch = LSTM(embed_hidden_size, return_sequences=False)(query_branch)
        query_branch = RepeatVector(story_maxlen)(query_branch)

        # Merge both branches and classify over the vocabulary.
        merged = add([story_branch, query_branch])
        merged = LSTM(embed_hidden_size, return_sequences=False)(merged)
        merged = Dropout(dropout_rate)(merged)
        prediction = Dense(vocab_size, activation='softmax')(merged)

        model_inputs = [story_input, query_input]
        model = Model(inputs=model_inputs, outputs=[prediction])

        # NOTE(review): one flag per model input; exact semantics are defined
        # by _test_keras_model, which is not visible here — confirm there.
        self._test_keras_model(model, one_dim_seq_flags=[True, True])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号