def ctc_model(a_backend):
    """ Build a small test model intended for exercising the CTC loss function.

    The model has three containers:
      * an input named ``TEST_input`` with shape ``[output_timesteps, 2]``,
      * a recurrent layer of ``vocab_size + 1`` units that returns the
        full sequence,
      * a softmax activation named ``TEST_output``.

    Args:
        a_backend: the backend under test. Only its ``get_name()`` method is
            used here; the object is also forwarded to
            ``model_with_containers``.

    Returns:
        Whatever ``model_with_containers`` builds for this backend and
        container list.

    For the ``'pytorch'`` backend, ``pytest.xfail`` is called. That marks the
    test as an expected failure and stops here, so no model is built.
    """
    backend_name = a_backend.get_name()
    if backend_name == 'pytorch':
        pytest.xfail(
            'Backend "{}" does not use a CTC loss function.'.format(backend_name)
        )

    output_timesteps = 10
    vocab_size = 4
    # One output unit beyond the vocabulary size. This is presumably the CTC
    # "blank" label; confirm against the loss configuration.
    recurrent_size = vocab_size + 1

    layers = [
        {'input': {'shape': [output_timesteps, 2]}, 'name': 'TEST_input'},
        {'recurrent': {'size': recurrent_size, 'sequence': True}},
        {'activation': 'softmax', 'name': 'TEST_output'},
    ]
    return model_with_containers(backend=a_backend, containers=layers)
###############################################################################
评论列表
文章目录