def _make_tree_cell(self, i):
    """Build the dropout-wrapped tree cell associated with index ``i``.

    Only the ``"lstm"`` cell type is constructed. ``"gru"`` and
    ``"basic-tanh"`` are recognised names but have no tree implementation.

    Args:
        i: Value used to derive the dropout seed as ``8 + 33 * i``, so each
            distinct ``i`` gets its own seed.

    Returns:
        A ``TreeDropoutWrapper`` around ``TreeLSTM(self.output_size)``.

    Raises:
        NotImplementedError: If ``self._cell_type`` is ``"gru"`` or
            ``"basic-tanh"``.
        ValueError: If ``self._cell_type`` is any other non-``"lstm"`` value.
    """
    kind = self._cell_type

    # Reject everything that is not an LSTM before building anything.
    if kind != "lstm":
        if kind in ("gru", "basic-tanh"):
            raise NotImplementedError("GRU/basic-tanh tree cells not implemented yet")
        raise ValueError("Invalid RNN Cell type")

    base_cell = TreeLSTM(self.output_size)

    # NOTE(review): self._dropout is forwarded as a *keep* probability;
    # confirm the attribute really stores keep-prob rather than drop-prob.
    return TreeDropoutWrapper(
        base_cell,
        output_keep_prob=self._dropout,
        seed=8 + 33 * i,
    )
tree_encoder.py 文件源码
python
阅读 28
收藏 0
点赞 0
评论 0
评论列表
文章目录