isan_cell.py source file

python

Project: tensorflow-isan-rnn  Author: philipperemy
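# Imports this excerpt relies on (TensorFlow 1.x API; `orth` is assumed
# to come from scipy.linalg, since the excerpt does not show its import).
import numpy as np
import tensorflow as tf
from scipy.linalg import orth

def call(self, step_inputs, state, scope=None, initialization='gaussian'):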
        """
        Make one step of ISAN transition.

        Args:
          step_inputs: one-hot encoded inputs, shape bs x n
          state: previous hidden state, shape bs x d
          scope: current scope
          initialization: how to initialize the transition matrices:
            orthogonal: orthogonalize Gaussian matrices; usually speeds up training
            gaussian: sample Gaussian matrices with stddev 1 / sqrt(d)

        Returns:
          A (output, new_state) pair; both are the new hidden state, shape bs x d.
        """
        d = self._num_units
        n = step_inputs.shape[1].value

        if initialization == 'orthogonal':
            wx_ndd_init = np.zeros((n, d * d), dtype=np.float32)
            for i in range(n):
                wx_ndd_init[i, :] = orth(np.random.randn(d, d)).astype(np.float32).ravel()
            wx_ndd_initializer = tf.constant_initializer(wx_ndd_init)
        elif initialization == 'gaussian':
            wx_ndd_initializer = tf.random_normal_initializer(stddev=1.0 / np.sqrt(d))
        else:
            raise ValueError('Unknown init type: %s' % initialization)

        wx_ndd = tf.get_variable('Wx', shape=[n, d * d],
                                 initializer=wx_ndd_initializer)
        bx_nd = tf.get_variable('bx', shape=[n, d],
                                initializer=tf.zeros_initializer())

        # Multiplication with a 1-hot is just row selection.
        # As of Jan '17 this is faster than doing gather.
        Wx_bdd = tf.reshape(tf.matmul(step_inputs, wx_ndd), [-1, d, d])
        bx_bd = tf.reshape(tf.matmul(step_inputs, bx_nd), [-1, 1, d])

        # Reshape the state so that matmul multiplies different matrices
        # for each batch element.
        single_state = tf.reshape(state, [-1, 1, d])
        new_state = tf.reshape(tf.matmul(single_state, Wx_bdd) + bx_bd, [-1, d])
        return new_state, new_state
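
The transition above can be checked outside TensorFlow. The following is a minimal NumPy sketch, not part of the original project: the function name `isan_step` and the toy sizes are assumptions made for illustration. It mirrors the reshape-and-matmul steps of `call` and compares the result with explicit per-example row selection.

```python
import numpy as np

def isan_step(x_onehot, h_prev, wx_ndd, bx_nd):
    """One ISAN transition in NumPy: h_t = h_{t-1} @ W[x_t] + b[x_t].

    Hypothetical helper for illustration; it mirrors the TensorFlow code above.
    """
    d = h_prev.shape[1]
    # Multiplying by a one-hot row selects one row of the parameter table.
    w_bdd = (x_onehot @ wx_ndd).reshape(-1, d, d)
    b_b1d = (x_onehot @ bx_nd).reshape(-1, 1, d)
    # Batched matmul: every batch element is multiplied by its own d x d matrix.
    h_b1d = h_prev.reshape(-1, 1, d)
    return (h_b1d @ w_bdd + b_b1d).reshape(-1, d)

# Toy sizes (assumed): batch of 4, vocabulary of 5 symbols, 3 hidden units.
rng = np.random.default_rng(0)
bs, n, d = 4, 5, 3
wx_ndd = rng.normal(scale=1.0 / np.sqrt(d), size=(n, d * d)).astype(np.float32)
bx_nd = rng.normal(size=(n, d)).astype(np.float32)
h_prev = rng.normal(size=(bs, d)).astype(np.float32)
symbols = np.array([0, 3, 3, 1])
x_onehot = np.eye(n, dtype=np.float32)[symbols]

h_new = isan_step(x_onehot, h_prev, wx_ndd, bx_nd)

# Reference: pick each example's matrix and bias by index, then apply them.
h_ref = np.stack([h_prev[i] @ wx_ndd[s].reshape(d, d) + bx_nd[s]
                  for i, s in enumerate(symbols)])
```

Both paths should agree up to floating-point rounding, which illustrates why the cell can treat the one-hot matmul as a lookup of an input-specific affine map.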