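# Assumption: the TF 1.x cell classes are reached through tf.compat.v1 here;
# the original module imports `rnn_cell` on its own.
import tensorflow as tf

rnn_cell = tf.compat.v1.nn.rnn_cell


def lstm_cell(num_units, num_layers):
  """Constructs a `MultiRNNCell` with num_layers `BasicLSTMCell`s.

  Args:
    num_units: The number of units in the `RNNCell`.
    num_layers: The number of layers in the RNN.

  Returns:
    An initialized `MultiRNNCell`.
  """
  # Build a separate BasicLSTMCell per layer so the layers do not share
  # weights; state_is_tuple=True keeps each layer's state as a (c, h) pair.
  return rnn_cell.MultiRNNCell([
      rnn_cell.BasicLSTMCell(
          num_units=num_units, state_is_tuple=True) for _ in range(num_layers)
  ])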
Source file: state_saving_rnn_estimator.py (Python)
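The stacking that `MultiRNNCell` performs over `BasicLSTMCell`s can be illustrated without TensorFlow. The following is a minimal pure-Python sketch, not TensorFlow's implementation: `ToyLSTMCell`, `ToyMultiCell`, and `toy_lstm_cell` are hypothetical names, the weights are random placeholders, and the gates follow the usual `BasicLSTMCell` formulation (gate blocks i, j, f, o, with a `forget_bias` of 1.0 added to the forget gate). Each layer keeps its own (c, h) pair, mirroring `state_is_tuple=True`, and each layer is a distinct cell instance, as the list comprehension in `lstm_cell` produces.

```python
import math
import random
from typing import List, Tuple

Vector = List[float]
LSTMState = Tuple[Vector, Vector]  # (c, h), as with state_is_tuple=True


def _sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


class ToyLSTMCell:
    """Hypothetical single LSTM layer using BasicLSTMCell-style gates."""

    def __init__(self, num_units: int, input_size: int,
                 forget_bias: float = 1.0, seed: int = 0) -> None:
        rng = random.Random(seed)
        self.num_units = num_units
        self.forget_bias = forget_bias
        cols = input_size + num_units  # the cell reads [x, h] concatenated
        # 4 * num_units rows: one block each for the i, j, f, o gates.
        self.w = [[rng.uniform(-0.1, 0.1) for _ in range(cols)]
                  for _ in range(4 * num_units)]
        self.b = [0.0] * (4 * num_units)

    def zero_state(self) -> LSTMState:
        return [0.0] * self.num_units, [0.0] * self.num_units

    def __call__(self, x: Vector, state: LSTMState) -> Tuple[Vector, LSTMState]:
        c, h = state
        xh = x + h  # list concatenation of input and previous output
        z = [sum(w * v for w, v in zip(row, xh)) + b
             for row, b in zip(self.w, self.b)]
        n = self.num_units
        i, j, f, o = z[:n], z[n:2 * n], z[2 * n:3 * n], z[3 * n:]
        new_c = [ck * _sigmoid(fk + self.forget_bias) + _sigmoid(ik) * math.tanh(jk)
                 for ck, ik, jk, fk in zip(c, i, j, f)]
        new_h = [math.tanh(ck) * _sigmoid(ok) for ck, ok in zip(new_c, o)]
        return new_h, (new_c, new_h)


class ToyMultiCell:
    """Hypothetical stand-in for MultiRNNCell: feeds each layer's output upward."""

    def __init__(self, cells: List[ToyLSTMCell]) -> None:
        self.cells = cells

    def zero_state(self) -> Tuple[LSTMState, ...]:
        return tuple(cell.zero_state() for cell in self.cells)

    def __call__(self, x: Vector, states: Tuple[LSTMState, ...]):
        out, new_states = x, []
        for cell, st in zip(self.cells, states):
            out, new_st = cell(out, st)
            new_states.append(new_st)
        return out, tuple(new_states)


def toy_lstm_cell(num_units: int, num_layers: int, input_size: int) -> ToyMultiCell:
    # One distinct cell per layer; layer 0 reads the input, later layers
    # read the previous layer's h (size num_units).
    cells = [ToyLSTMCell(num_units, input_size if k == 0 else num_units, seed=k)
             for k in range(num_layers)]
    return ToyMultiCell(cells)


cell = toy_lstm_cell(num_units=4, num_layers=3, input_size=2)
state = cell.zero_state()
for x in ([0.5, -1.0], [1.0, 0.0]):
    output, state = cell(x, state)
print(len(output), len(state), len(state[0][0]))  # 4 3 4
```

The output at each step is the top layer's h, while the returned state is a tuple with one (c, h) pair per layer. Because every h is tanh(c) scaled by a sigmoid, each output value lies strictly between -1 and 1.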