nn.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:pyprob 作者: probprog 项目源码 文件源码
def set_observe_embedding(self, example_observes, obs_emb, obs_emb_dim, obs_reshape=None):
        """Create the observation-embedding layer and store it on this object.

        Records ``obs_emb`` and ``obs_emb_dim`` on ``self``, builds the
        embedding layer selected by ``obs_emb`` and assigns it to
        ``self.observe_layer``.

        Args:
            example_observes: Example observation; it is wrapped in
                ``Variable`` and handed to the layer constructor.
            obs_emb: Name of the embedding to build. Supported values are
                ``'fc'``, ``'cnn1d2c'``, ``'cnn2d6c'``, ``'cnn3d4c'`` and
                ``'lstm'``.
            obs_emb_dim: Embedding dimension forwarded to the layer
                constructor.
            obs_reshape: Optional reshape argument; only forwarded to the
                ``'cnn2d6c'`` and ``'cnn3d4c'`` layers.

        Raises:
            ValueError: If ``obs_emb`` is not one of the supported names.
        """
        self.obs_emb = obs_emb
        self.obs_emb_dim = obs_emb_dim
        if obs_emb == 'fc':
            observe_layer = ObserveEmbeddingFC(Variable(example_observes), obs_emb_dim, dropout=self.dropout)
        elif obs_emb == 'cnn1d2c':
            observe_layer = ObserveEmbeddingCNN1D2C(Variable(example_observes), obs_emb_dim, dropout=self.dropout)
            # The CNN variants need an explicit configure() call after construction.
            observe_layer.configure()
        elif obs_emb == 'cnn2d6c':
            observe_layer = ObserveEmbeddingCNN2D6C(Variable(example_observes), obs_emb_dim, obs_reshape, dropout=self.dropout)
            observe_layer.configure()
        elif obs_emb == 'cnn3d4c':
            observe_layer = ObserveEmbeddingCNN3D4C(Variable(example_observes), obs_emb_dim, obs_reshape, dropout=self.dropout)
            observe_layer.configure()
        elif obs_emb == 'lstm':
            observe_layer = ObserveEmbeddingLSTM(Variable(example_observes), obs_emb_dim, dropout=self.dropout)
        else:
            # str() keeps the message buildable even when obs_emb is not a
            # string (e.g. None), where plain concatenation would raise
            # TypeError before anything was logged.
            message = 'set_observe_embedding: Unsupported observation embedding: ' + str(obs_emb)
            util.logger.log(message)
            # Fail explicitly: falling through would hit an unbound
            # ``observe_layer`` below and surface as an UnboundLocalError.
            raise ValueError(message)

        self.observe_layer = observe_layer
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号