rnn.py source code


Project: LSTM-TensorSpark, author: EmanuelOverflow
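```python
import tensorflow as tf  # TensorFlow 1.x graph-mode API


def fit_next(self, data, s, last_state=True, train=True):
    """Feed one input vector through the LSTM layer and return its state.

    Method of the node class in rnn.py: `self.shape[0]` is the number of
    hidden units and `s` is the tf.Session used to run the graph.
    """
    with tf.name_scope('optimizer'):
        # Turn the 1-D input into a column vector of shape (n, 1).
        input_data_T = tf.transpose([data], name="input_data_T")

        if not self.ht:  # first call: create the state variables lazily
            # Init h_t (random normal) and C_t (ones)
            self.ht = tf.Variable(tf.random_normal([self.shape[0], 1]), trainable=False, name="ht_%d" % self.node_id)
            self.Ct = tf.Variable(tf.ones([self.shape[0], 1]), trainable=False, name="Ct_%d" % self.node_id)

            # Init the per-step layer variables: forget gate f_t, input gate i_t
            # and Cta (presumably the candidate cell state C~_t)
            self.ft = tf.Variable(tf.ones([self.shape[0], 1]), trainable=False, name="ft_%d" % self.node_id)
            self.it = tf.Variable(tf.ones([self.shape[0], 1]), trainable=False, name="it_%d" % self.node_id)
            self.Cta = tf.Variable(tf.ones([self.shape[0], 1]), trainable=False, name="Cta_%d" % self.node_id)

            # tf.variables_initializer replaces the deprecated tf.initialize_variables
            s.run(tf.variables_initializer([self.ht, self.Ct, self.ft, self.it, self.Cta]))

        with tf.name_scope('train_layer'):
            self.train_layer(input_data_T, s)
            if train:
                self.state.append((self.ht, self.Ct))  # store the state of each step
                ret = self.state[-1] if last_state else self.state
            else:
                # Inference: return the current state, then call restore_state()
                # (defined elsewhere in the class)
                ret = (self.ht, self.Ct)
                self.restore_state()
    return ret
```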
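The bookkeeping in `fit_next` (lazy state creation, one LSTM step, appending the state during training, returning the state without keeping it at inference) can be sketched without TensorFlow. The NumPy class below is a hypothetical stand-in, not the project's code: `TinyLSTMNode`, `step`, the stacked weight matrix and the output gate are assumptions, because the body of `train_layer` is not shown here. The gate equations are the standard LSTM ones that the names `ft`, `it`, `Cta`, `Ct` and `ht` suggest, and the rollback assumes `restore_state()` resets `h_t` and `C_t` to their values before the step.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


class TinyLSTMNode:
    """Hypothetical NumPy stand-in for the TensorFlow node above (assumption)."""

    def __init__(self, n_hidden, n_input, seed=0):
        self.rng = np.random.default_rng(seed)
        self.n_hidden = n_hidden
        # Assumed layout: one stacked matrix for the f, i, candidate and o gates.
        self.W = self.rng.normal(scale=0.1, size=(4 * n_hidden, n_hidden + n_input))
        self.b = np.zeros((4 * n_hidden, 1))
        self.ht = None  # created lazily on the first call, like the original
        self.Ct = None
        self.state = []

    def step(self, x):
        """One standard LSTM step on a column vector x of shape (n_input, 1)."""
        n = self.n_hidden
        if self.ht is None:
            self.ht = self.rng.normal(size=(n, 1))  # original: random normal h_t
            self.Ct = np.ones((n, 1))               # original: C_t initialised to ones
        z = self.W @ np.vstack([self.ht, x]) + self.b
        ft = sigmoid(z[:n])            # forget gate
        it = sigmoid(z[n:2 * n])       # input gate
        Cta = np.tanh(z[2 * n:3 * n])  # candidate cell state C~_t
        ot = sigmoid(z[3 * n:])        # output gate (assumed; not in the snippet)
        self.Ct = ft * self.Ct + it * Cta
        self.ht = ot * np.tanh(self.Ct)

    def fit_next(self, data, last_state=True, train=True):
        # Column vector of shape (n, 1), like tf.transpose([data]).
        x = np.asarray(data, dtype=float).reshape(-1, 1)
        saved = (self.ht, self.Ct)
        self.step(x)
        if train:
            # step() binds fresh arrays, so each entry is a distinct snapshot.
            self.state.append((self.ht, self.Ct))
            return self.state[-1] if last_state else self.state
        ret = (self.ht, self.Ct)
        self.ht, self.Ct = saved  # assumed restore_state(): undo this step
        return ret
```

Calling `fit_next(x, train=False)` on this sketch returns the state after one step while leaving `ht`, `Ct` and `state` as they were before the call.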