graph_utils.py source code

python

Project: lemontree  Author: khshim
from collections import OrderedDict

import theano.tensor as T
from theano.gof import graph


def get_inputs_of_variables(variables):
    """
    Return the inputs required to compute the given tensor variables.
    The inputs are returned in topologically sorted order.

    Parameters
    ----------
    variables: list
        a list of tensor variables to inspect,
        usually the outputs of a Theano function (loss, accuracy, etc.).

    Returns
    -------
    list
        a list of the inputs required to compute the variables.
    """
    # validate arguments
    assert isinstance(variables, list), 'Variables should be a list of tensor variable(s).'
    assert all(isinstance(var, T.TensorVariable) for var in variables), 'All variables should be tensor variables.'

    # collect the graph inputs, keeping only tensor variables
    variable_inputs = [var for var in graph.inputs(variables) if isinstance(var, T.TensorVariable)]
    variable_inputs = list(OrderedDict.fromkeys(variable_inputs))  # drop duplicates, preserving order
    print('Required inputs are:', variable_inputs)
    return variable_inputs
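The function hands the graph walk to Theano's `graph.inputs`. To show the idea without Theano, here is a minimal sketch, assuming a hypothetical `Var`/`Apply` node structure loosely modeled on Theano's `owner`/`inputs` attributes. A variable with no owner is a graph input, so walking back from the outputs through each owner's inputs collects them all. `OrderedDict.fromkeys` then removes repeats while keeping first-seen order. `Var`, `Apply`, `apply_op` and `graph_inputs` are illustrative names, not part of Theano or lemontree.

```python
from collections import OrderedDict
from dataclasses import dataclass
from typing import List, Optional


@dataclass(eq=False)  # eq=False keeps identity-based hashing, like graph nodes
class Apply:
    """Hypothetical operation node: the op name and the variables it consumes."""
    op: str
    inputs: List["Var"]


@dataclass(eq=False)
class Var:
    """Hypothetical symbolic variable; owner is None for a graph input."""
    name: str
    owner: Optional[Apply] = None


def apply_op(op: str, name: str, *inputs: Var) -> Var:
    """Create the variable produced by applying `op` to `inputs`."""
    return Var(name, Apply(op, list(inputs)))


def graph_inputs(outputs: List[Var]) -> List[Var]:
    """Collect the owner-less variables that `outputs` depend on."""
    leaves: List[Var] = []

    def visit(var: Var) -> None:
        if var.owner is None:
            leaves.append(var)  # a graph input: nothing computes it
        else:
            for operand in var.owner.inputs:  # walk back toward the inputs
                visit(operand)

    for out in outputs:
        visit(out)
    # Shared subexpressions are reached more than once; drop the repeats
    # while keeping first-seen order, as get_inputs_of_variables does.
    return list(OrderedDict.fromkeys(leaves))


# Usage: loss and accuracy both depend on the hidden value h.
x, w, y = Var("x"), Var("w"), Var("y")
h = apply_op("mul", "h", x, w)
loss = apply_op("add", "loss", h, y)
acc = apply_op("mean", "acc", h)
print([v.name for v in graph_inputs([loss, acc])])  # ['x', 'w', 'y']
```

This recursive walk re-expands shared subgraphs and can hit Python's recursion limit on deep graphs. Memoizing visited nodes or using an explicit stack would avoid both.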