def set_observe_embedding(self, example_observes, obs_emb, obs_emb_dim, obs_reshape=None):
    """Build the observation-embedding layer selected by ``obs_emb`` and store it.

    Args:
        example_observes: Example observation data. It is wrapped in
            ``Variable`` and passed to the chosen layer's constructor
            (presumably so the layer can infer its input size — TODO confirm).
        obs_emb: Name of the embedding type. Supported values are ``'fc'``,
            ``'cnn1d2c'``, ``'cnn2d6c'``, ``'cnn3d4c'`` and ``'lstm'``.
        obs_emb_dim: Embedding dimension passed to the layer constructor.
        obs_reshape: Extra constructor argument. Only ``'cnn2d6c'`` and
            ``'cnn3d4c'`` use it; every other embedding type ignores it.

    Side effects:
        Sets ``self.obs_emb``, ``self.obs_emb_dim`` and ``self.observe_layer``.
        Reads ``self.dropout`` and forwards it to the layer constructor.

    Raises:
        ValueError: If ``obs_emb`` is not a supported embedding type. The
            message is also logged through ``util.logger.log`` first.
    """
    self.obs_emb = obs_emb
    self.obs_emb_dim = obs_emb_dim

    # Map each embedding name to (layer class, takes obs_reshape, needs configure()).
    # Only the CNN variants need the extra configure() step after construction.
    layer_specs = {
        'fc': (ObserveEmbeddingFC, False, False),
        'cnn1d2c': (ObserveEmbeddingCNN1D2C, False, True),
        'cnn2d6c': (ObserveEmbeddingCNN2D6C, True, True),
        'cnn3d4c': (ObserveEmbeddingCNN3D4C, True, True),
        'lstm': (ObserveEmbeddingLSTM, False, False),
    }
    spec = layer_specs.get(obs_emb)
    if spec is None:
        message = 'set_observe_embedding: Unsupported observation embedding: ' + str(obs_emb)
        util.logger.log(message)
        # Previously this fell through to an unbound ``observe_layer`` and
        # crashed with UnboundLocalError; fail with a clear error instead.
        raise ValueError(message)

    layer_cls, takes_reshape, needs_configure = spec
    example = Variable(example_observes)
    if takes_reshape:
        observe_layer = layer_cls(example, obs_emb_dim, obs_reshape, dropout=self.dropout)
    else:
        observe_layer = layer_cls(example, obs_emb_dim, dropout=self.dropout)
    if needs_configure:
        observe_layer.configure()
    self.observe_layer = observe_layer
# (Web-page residue, not part of the code: "Comment list" / "Table of contents" labels.)