def forward(self, x, prevState = None ):
    """Unroll a recurrent cell over the time axis of ``x`` (unfinished stub).

    NOTE(review): presumably the ``forward`` of a torch ``nn.Module`` in an
    LSTM network -- the enclosing class is not visible here; confirm. The
    original indentation of this block was lost, so the nesting shown here
    is a reconstruction. Relies on a ``torch`` import not visible here.

    Args:
        x: input queried with ``x.size()`` / ``x.unsqueeze``. A 2-D input
            gets a leading axis of size 1 added. After that, dim 0 is read
            as the batch size and dim 1 as the number of time steps. Because
            the loop indexes ``x[:, t, :]``, ``x`` needs at least 3 dims once
            a 2-D input has been expanded.
        prevState: optional previous state. Defaults to an empty dict, yet
            the loop reads it as ``prevState.h`` (attribute access). See the
            review note inside the loop.

    Returns:
        Nothing. There is no ``return`` statement, so the call yields None.

    Raises:
        KeyError: on the first loop iteration whenever ``x.size(1) > 0``,
            because ``hs`` is still empty when ``hs[t-1]`` is read.
    """
    # Add a leading axis of size 1 to 2-D input so dim 0 can be read as batch.
    if len(x.size()) == 2: x = x.unsqueeze(0)
    batch = x.size(0)  # assigned but not used anywhere below yet
    steps = x.size(1)  # number of time steps the loop walks over
    # NOTE(review): `is None` is the idiomatic test; `==` also works for a
    # None/dict argument.
    if prevState == None: prevState = {}
    # Per-step states, presumably keyed by time index t (hidden `hs`, cell
    # `cs`). Nothing in this body writes to either dict yet.
    hs = {}
    cs = {}
    for t in range(steps):
        # Slice of the input at time step t, across the whole batch.
        xt = x[:,t,:]
        # Previous hidden state h (the previous cell state c is not fetched yet).
        # NOTE(review): this line cannot work as written:
        #   - `hs` is still empty at t == 0, so `hs[t-1]` raises KeyError(-1)
        #     before either fallback is tried (`or` does not catch exceptions);
        #   - `prevState` defaults to a dict, which has no attribute `h`;
        #   - `or` truth-tests its operands, which is ambiguous (RuntimeError)
        #     for tensors with more than one element;
        #   - `torch.zeros()` is called without any size argument.
        hp = hs[t-1] or prevState.h or torch.zeros()
        # NOTE(review): unused placeholder. The cell update and the writes into
        # hs/cs appear to be missing, and this line's original nesting level
        # (inside vs. after the loop) is unknown.
        a = 0
recurrentLSTMNetwork.py 文件源码
python
阅读 23
收藏 0
点赞 0
评论 0
评论列表
文章目录