mask_generator.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:Neural-Photo-Editor 作者: ajbrock 项目源码 文件源码
def __init__(self, input_size, hidden_sizes, l, random_seed=1234):
        """Build the shared connectivity state and compile its samplers.

        Creates a shared ``ordering`` vector (``0 .. input_size - 1``), a list
        of shared connectivity vectors, and two argument-less Theano
        functions, ``shuffle_ordering`` and ``sample_connectivity``. It then
        snapshots the state of the MRG random stream and samples an initial
        hidden-layer connectivity.

        Args:
            input_size: Length of the ``ordering`` vector.
            hidden_sizes: Sequence of hidden-layer sizes; each entry gets a
                zero-initialised connectivity vector of that length.
            l: Stored as ``self._l``; not otherwise used in this method.
            random_seed: Seed given to both random-stream generators.
        """
        self._random_seed = random_seed
        self._mrng = MRG_RandomStreams(seed=random_seed)
        self._rng = RandomStreams(seed=random_seed)

        self._hidden_sizes = hidden_sizes
        self._input_size = input_size
        self._l = l

        float_x = theano.config.floatX
        n_hidden = len(self._hidden_sizes)

        self.ordering = theano.shared(
            value=np.arange(input_size, dtype=float_x),
            name='ordering',
            borrow=False,
        )

        # Connectivity vectors: the input one (ordering + 1) first, then one
        # per hidden layer, and finally the ordering shared variable itself.
        input_connectivity = theano.shared(
            value=(self.ordering + 1).eval(),
            name='layer_connectivity_input',
            borrow=False,
        )
        self.layers_connectivity = [input_connectivity]
        for layer_idx, layer_size in enumerate(self._hidden_sizes):
            hidden_connectivity = theano.shared(
                value=np.zeros(layer_size, dtype=float_x),
                name='layer_connectivity_hidden{0}'.format(layer_idx),
                borrow=False,
            )
            self.layers_connectivity.append(hidden_connectivity)
        self.layers_connectivity.append(self.ordering)

        # Compiled function that reshuffles the ordering and keeps the input
        # connectivity equal to the new ordering + 1.
        shuffled = self._rng.shuffle_row_elements(self.ordering)
        self.shuffle_ordering = theano.function(
            name='shuffle_ordering',
            inputs=[],
            updates=[
                (self.ordering, shuffled),
                (self.layers_connectivity[0], shuffled + 1),
            ],
        )

        # NOTE: this list is deliberately grown one element at a time on the
        # instance attribute. The original author recorded that building it
        # with a single list comprehension does not work -- presumably
        # _get_hidden_layer_connectivity reads the partially built list
        # (TODO confirm against that helper).
        self.layers_connectivity_updates = []
        for layer_idx in range(n_hidden):
            hidden_update = self._get_hidden_layer_connectivity(layer_idx)
            self.layers_connectivity_updates.append(hidden_update)

        # Pair each hidden connectivity vector (positions 1..n_hidden) with
        # its update expression.
        hidden_targets = self.layers_connectivity[1:n_hidden + 1]
        self.sample_connectivity = theano.function(
            name='sample_connectivity',
            inputs=[],
            updates=list(zip(hidden_targets, self.layers_connectivity_updates)),
        )

        # Snapshot the MRG stream's state as it stands after graph building.
        self._initial_mrng_rstate = copy.deepcopy(self._mrng.rstate)
        self._initial_mrng_state_updates = [
            pair[0].get_value() for pair in self._mrng.state_updates
        ]

        # Hidden connectivity starts as zeros; sample once so it is valid.
        self.sample_connectivity()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号