evals.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:urnn 作者: Rand0mUsername 项目源码 文件源码
def train_urnn_for_timestep_idx(self, idx):
        """Build and train the "cm" and "ap" URNN models for one timestep index.

        Each model is a ``TFRNN`` using ``URNNCell``, built in a freshly reset
        default TensorFlow graph and optimized with RMSProp
        (``glob_learning_rate`` / ``glob_decay``). The built models are stored
        on ``self.cm_urnn`` and ``self.ap_urnn``. Each is then passed to
        ``self.train_network`` together with the ``idx``-th dataset entry and
        that task's batch size and epoch count.

        Args:
            idx: Index into ``self.cm_data`` and ``self.ap_data`` selecting the
                datasets for this timestep.
        """
        def fresh_urnn(**task_options):
            # Settings shared by both tasks; task-specific ones come via kwargs.
            # NOTE(review): resetting the default graph here means the model
            # built earlier (cm) is no longer in the default graph once the
            # next one (ap) is created — confirm TFRNN holds its own graph/session.
            tf.reset_default_graph()
            return TFRNN(
                num_target=1,
                rnn_cell=URNNCell,
                activation_hidden=None,  # modReLU
                activation_out=tf.identity,
                optimizer=tf.train.RMSPropOptimizer(
                    learning_rate=glob_learning_rate, decay=glob_decay),
                **task_options)

        print('Initializing and training URNNs for one timestep...')

        # CM: 1 input, 128 hidden units, 10 output logits scored with sparse
        # softmax cross-entropy; single_output=False (presumably an output per step).
        self.cm_urnn = fresh_urnn(
            name="cm_urnn",
            num_in=1,
            num_hidden=128,
            num_out=10,
            single_output=False,
            loss_function=tf.nn.sparse_softmax_cross_entropy_with_logits)
        cm_dataset = self.cm_data[idx]
        cm_batch = self.cm_batch_size
        cm_epochs = self.cm_epochs
        self.train_network(self.cm_urnn, cm_dataset, cm_batch, cm_epochs)

        # AP: 2 inputs, 512 hidden units, a single scalar output scored with
        # squared difference.
        self.ap_urnn = fresh_urnn(
            name="ap_urnn",
            num_in=2,
            num_hidden=512,
            num_out=1,
            single_output=True,
            loss_function=tf.squared_difference)
        ap_dataset = self.ap_data[idx]
        ap_batch = self.ap_batch_size
        ap_epochs = self.ap_epochs
        self.train_network(self.ap_urnn, ap_dataset, ap_batch, ap_epochs)

        print('Init and training URNNs for one timestep done.')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号