layer_test.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:anago 作者: Hironsan 项目源码 文件源码
def test_chain_crf(self):
        vocab_size = 20
        n_classes = 11
        model = Sequential()
        model.add(Embedding(vocab_size, n_classes))
        layer = ChainCRF()
        model.add(layer)
        model.compile(loss=layer.loss, optimizer='sgd')

        # Train first mini batch
        batch_size, maxlen = 2, 2
        x = np.random.randint(1, vocab_size, size=(batch_size, maxlen))
        y = np.random.randint(n_classes, size=(batch_size, maxlen))
        y = np.eye(n_classes)[y]
        model.train_on_batch(x, y)

        print(x)
        print(y)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号