def get_nets(name):
    """Return the recurrent layer class selected by *name*.

    Args:
        name: Layer identifier. ``'LSTM'`` selects ``recurrent.LSTM`` and
            ``'GRU'`` selects ``recurrent.GRU``.

    Returns:
        The matching attribute of the module-level ``recurrent`` name.
        NOTE(review): ``recurrent`` is presumably Keras'
        ``keras.layers.recurrent`` module; confirm against the file's imports.

    Any other value, including differently-cased spellings such as
    ``'lstm'``, silently falls back to ``recurrent.SimpleRNN``; no error is
    raised for unrecognized names.
    """
    # Resolve only the single attribute that is needed, checking 'LSTM'
    # before 'GRU', so the names are compared in a fixed order.
    layer_attr = 'LSTM' if name == 'LSTM' else 'GRU' if name == 'GRU' else 'SimpleRNN'
    return getattr(recurrent, layer_attr)