models.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:self-driving-truck 作者: aleju 项目源码 文件源码
def __init__(self):
        super(SuccessorPredictor, self).__init__()

        def identity(v):
            return lambda x: x
        bn2d = nn.InstanceNorm2d
        bn1d = identity

        self.input_size = 9
        self.hidden_size = 512
        self.nb_layers = 1

        self.hidden_fc1 = nn.Linear(512, self.nb_layers*2*self.hidden_size)
        self.hidden_fc1_bn = bn1d(self.nb_layers*2*self.hidden_size)

        self.rnn = nn.LSTM(self.input_size, self.hidden_size, self.nb_layers, dropout=0.1, batch_first=False)

        self.fc1 = nn.Linear(self.hidden_size, 512)

        init_weights(self)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号