PhasedLSTMCell.py source code

Project: PLSTM  Author: Enny1991
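```python
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope as vs


def _get_concat_variable(name, shape, dtype, num_shards):
    """Get a sharded variable concatenated into one tensor."""
    # _get_sharded_variable is not shown here; it is assumed to be defined
    # elsewhere in PhasedLSTMCell.py and to return a list of shard variables.
    sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
    if len(sharded_variable) == 1:
        return sharded_variable[0]

    # Reuse a concatenation already built under the same scoped name.
    concat_name = name + "/concat"
    concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
    for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
        if value.name == concat_full_name:
            return value

    # TF >= 1.0 signature is concat(values, axis); the original call used the
    # pre-1.0 order concat(axis, values), which fails on current versions.
    concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name)
    ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                          concat_variable)
    return concat_variable
```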
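The pattern above (split a weight matrix into row shards, concatenate them along axis 0, and cache the result by name so later calls reuse it) can be illustrated without TensorFlow. The following is a minimal sketch, assuming plain Python lists stand in for tensors and a module-level dict stands in for the `CONCATENATED_VARIABLES` graph collection; `shard_row_sizes`, `make_shards`, `get_concat_variable`, and `_CONCAT_CACHE` are hypothetical names, not part of PLSTM or TensorFlow. The split rule (the first `num_rows % num_shards` shards get one extra row) is an assumption about how `_get_sharded_variable`, which is not shown here, divides the rows, and the cache is keyed by the unscoped name rather than the full variable-scope name.

```python
from typing import Dict, List

Row = List[float]
Matrix = List[Row]

# Hypothetical stand-in for the CONCATENATED_VARIABLES graph collection.
_CONCAT_CACHE: Dict[str, Matrix] = {}


def shard_row_sizes(num_rows: int, num_shards: int) -> List[int]:
    """Split num_rows across num_shards; the first num_rows % num_shards
    shards get one extra row (assumed split rule)."""
    if not 0 < num_shards <= num_rows:
        raise ValueError(f"cannot split {num_rows} rows into {num_shards} shards")
    unit, remaining = divmod(num_rows, num_shards)
    return [unit + 1 if i < remaining else unit for i in range(num_shards)]


def make_shards(shape: List[int], num_shards: int) -> List[Matrix]:
    """Create zero-filled 2-D shards whose row counts sum to shape[0]."""
    num_rows, num_cols = shape
    return [[[0.0] * num_cols for _ in range(size)]
            for size in shard_row_sizes(num_rows, num_shards)]


def get_concat_variable(name: str, shards: List[Matrix]) -> Matrix:
    """Join shards along axis 0, reusing a result cached under name."""
    if len(shards) == 1:
        return shards[0]  # nothing to concatenate
    concat_name = name + "/concat"
    if concat_name in _CONCAT_CACHE:
        return _CONCAT_CACHE[concat_name]  # built on an earlier call
    concat = [row for shard in shards for row in shard]  # axis-0 concat
    _CONCAT_CACHE[concat_name] = concat
    return concat


shards = make_shards([10, 4], num_shards=3)
print([len(s) for s in shards])  # [4, 3, 3]
W = get_concat_variable("W", shards)
print(len(W), len(W[0]))  # 10 4
print(get_concat_variable("W", shards) is W)  # True
```

Keying the cache by name mirrors the collection lookup in the original function: a second request for the same name returns the existing concatenation instead of building a new one.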