def _make_rnn_cell(spec: RNNSpec) -> Callable[[], RNNCell]:
    """Build a zero-argument factory for the RNN cell named by ``spec``.

    The cell type is chosen by ``spec.cell_type``:

    * ``"GRU"``  -> ``OrthoGRUCell(spec.size)``
    * ``"LSTM"`` -> ``tf.contrib.rnn.LSTMCell(spec.size)``

    The returned callable constructs a new cell object on every call.
    ``spec.size`` and the cell class are only looked up when it is invoked,
    not when the factory is created. (Presumably callers want one fresh cell
    per layer; confirm against call sites.)

    Args:
        spec: RNN configuration providing ``cell_type`` and ``size``.

    Returns:
        A callable taking no arguments that returns a new RNN cell.

    Raises:
        ValueError: if ``spec.cell_type`` is neither ``"GRU"`` nor ``"LSTM"``.
    """
    if spec.cell_type == "GRU":
        return lambda: OrthoGRUCell(spec.size)

    if spec.cell_type == "LSTM":
        return lambda: tf.contrib.rnn.LSTMCell(spec.size)

    # Unrecognized types fail here, when the factory is built, rather than
    # later when a cell is instantiated.
    raise ValueError("Unknown RNN cell: {}".format(spec.cell_type))
# pylint: disable=too-many-instance-attributes
评论列表
文章目录