cudnn_rnn_ops_benchmark.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:DeepLearning_VirtualReality_BigData_Project 作者: rashmitripathi 项目源码 文件源码
def benchmarkTfRNNLSTMTraining(self):
    """Benchmark forward + backward passes of a stacked static-RNN LSTM.

    For every configuration returned by ``self._GetTestConfig()`` this builds
    a fresh graph pinned to ``/gpu:0``. It then:

    * feeds ``seq_length`` all-zero float32 inputs of shape
      ``[batch_size, num_units]``;
    * unrolls a ``MultiRNNCell`` of ``num_layers`` ``LSTMCell`` layers over
      those inputs with ``static_rnn``;
    * takes the gradients of the outputs and final state with respect to
      every trainable variable in the graph.

    The gradient tensors are grouped into a single op and passed to
    ``self._BenchmarkOp`` under the label
    ``"tf_rnn_lstm <config_name> <config description>"``.
    """
    test_configs = self._GetTestConfig()
    for config_name, config in test_configs.items():
      num_layers = config["num_layers"]
      num_units = config["num_units"]
      batch_size = config["batch_size"]
      seq_length = config["seq_length"]

      with ops.Graph().as_default(), ops.device("/gpu:0"):
        inputs = seq_length * [
            array_ops.zeros([batch_size, num_units], dtypes.float32)
        ]
        initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127)

        # Build a separate LSTMCell for every layer. The previous code created
        # one cell instance and then called it with no arguments
        # (`cell()`), but an RNNCell instance's __call__ expects
        # (inputs, state), so that raised a TypeError before benchmarking.
        # Distinct instances also give each layer its own weights.
        multi_cell = core_rnn_cell_impl.MultiRNNCell([
            core_rnn_cell_impl.LSTMCell(
                num_units=num_units, initializer=initializer,
                state_is_tuple=True)
            for _ in range(num_layers)
        ])
        outputs, final_state = core_rnn.static_rnn(
            multi_cell, inputs, dtype=dtypes.float32)

        # Differentiate both the per-step outputs and the final state so the
        # backward pass covers the whole unrolled network.
        trainable_variables = ops.get_collection(
            ops.GraphKeys.TRAINABLE_VARIABLES)
        gradients = gradients_impl.gradients([outputs, final_state],
                                             trainable_variables)
        # One grouped op lets a single run compute all of the gradients.
        training_op = control_flow_ops.group(*gradients)
        self._BenchmarkOp(training_op, "tf_rnn_lstm %s %s" %
                          (config_name, self._GetConfigDesc(config)))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号