utils.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:RFHO 作者: lucfra 项目源码 文件源码
def matmul(a, b, benchmark=True, name='mul'):  # TODO maybe put inside dot
    """
    Interface function for matmul that also works when `a` is a sparse tensor.

    If `a` is not a `tf.SparseTensor` the result is `wsr(tf.matmul(a, b))`.
    If it is, two candidate products are built: a densified one
    (`tf.sparse_tensor_to_dense` followed by `tf.matmul`) and
    `tf.sparse_tensor_dense_matmul`. With `benchmark=True` both candidates are
    evaluated 4 times each in a temporary session, their timings are printed,
    and the candidate with the lowest mean time is returned. With
    `benchmark=False` the densified product is returned without any timing.

    Benchmarking side effects: a `tf.Session` is opened with
    `CONFIG_GPU_GROWTH`, `tf.global_variables_initializer()` is run in it,
    statistics are printed to stdout, and the session is closed afterwards.
    Candidates are evaluated with no `feed_dict`, so presumably the operands
    must not depend on unfed placeholders -- TODO confirm.

    :param a: left operand; either a `tf.SparseTensor` or anything accepted
              by `tf.matmul`
    :param b: right operand, passed unchanged to `tf.matmul` /
              `tf.sparse_tensor_dense_matmul`
    :param benchmark: if True (default) and `a` is sparse, time the candidate
                      implementations and return the fastest one
    :param name: name scope under which all the ops are created
    :return: the selected product node, as returned by `wsr`
    """
    with tf.name_scope(name):
        if not isinstance(a, tf.SparseTensor):
            return wsr(tf.matmul(a, b))

        # The densified product is built first and is the fallback when no
        # benchmark is requested.
        mul = wsr(tf.matmul(tf.sparse_tensor_to_dense(a, default_value=0.), b))
        if not benchmark:
            return mul

        mul_ops = [wsr(tf.sparse_tensor_dense_matmul(a, b)),
                   mul,  # others ?
                   # TODO tf.nn.embedding_lookup_sparse might be another candidate
                   ]

        def _avg_exe_times(op, repetitions):
            """Evaluate `op` `repetitions` times; return (mean, max, min) time in seconds."""
            from timeit import default_timer  # clock meant for benchmarking

            ex_times = []
            for _ in range(repetitions):
                st = default_timer()
                op.eval()  # relies on the enclosing default session
                ex_times.append(default_timer() - st)
            # The first run is left out of the mean (presumably as warm-up);
            # max and min are taken over all the runs.
            return np.mean(ex_times[1:]), np.max(ex_times), np.min(ex_times)

        # Using the Session itself as context manager installs it as default
        # AND closes it on exit; `.as_default()` alone would leave it open.
        with tf.Session(config=CONFIG_GPU_GROWTH):
            tf.global_variables_initializer().run()  # TODO here should only initialize necessary variable
            # (downstream in the computation graph)

            statistics = {op: _avg_exe_times(op, repetitions=4) for op in mul_ops}

        for op, stats in statistics.items():
            print(op, stats)

        # Best one w.r.t. mean execution time (first element of each stats tuple).
        mul = min(statistics, key=lambda op: statistics[op][0])

        print(mul, 'selected')

        return mul


# Define a context manager to suppress stdout and stderr.
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号