terpret_tf_runtime.py source code

python

Project: TerpreT  Author: 51alg
```python
import tensorflow as tf


def make_batch_consistent(args, set_batch_size=None):
    """
    args[i] should be either [arg_dim] or [batch_size x arg_dim].
    If rank(args[i]) == 1, tile it to [batch_size x arg_dim].
    Returns (args, is_batched).
    """
    if set_batch_size is None:
        # Infer batch_size from the shapes of the rank-2 args.
        # A list comprehension (rather than filter()) keeps len() working on
        # Python 3, where filter() returns an iterator instead of a list.
        batched_args = [x for x in args if x.get_shape().ndims > 1]
        #batched_args = filter(lambda x : x.get_shape()[0].value is None, args)
        if len(batched_args) == 0:
            batch_size = 1
            is_batched = False
        else:
            # TODO: tf.assert_equal() to check that all batch sizes are consistent?
            batch_size = tf.shape(batched_args[0])[0]
            is_batched = True
    else:
        batch_size = set_batch_size
        is_batched = True

    # Tile any rank-1 args to a consistent batch_size.
    tmp_args = []
    for arg in args:
        arg_rank = arg.get_shape().ndims
        assert_rank_1_or_2(arg_rank)  # helper not shown in this excerpt
        if arg_rank == 1:
            # [arg_dim] -> [1 x arg_dim] -> [batch_size x arg_dim]
            tmp_args.append(tf.tile(tf.expand_dims(arg, 0), [batch_size, 1]))
        else:
            tmp_args.append(arg)
    args = tmp_args
    return args, is_batched
```
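The shape logic of `make_batch_consistent` can be checked without a TensorFlow graph by mirroring it in NumPy. This is a minimal sketch, not the TerpreT code: the name `make_batch_consistent_np` is hypothetical, `np.expand_dims`/`np.tile` stand in for `tf.expand_dims`/`tf.tile`, the inline rank check replaces the unshown `assert_rank_1_or_2` helper, and the batch-size consistency check is an assumed addition that the original only mentions as a TODO.

```python
import numpy as np


def make_batch_consistent_np(args, set_batch_size=None):
    """Hypothetical NumPy mirror of make_batch_consistent.

    Each arg must be rank 1 ([arg_dim]) or rank 2 ([batch_size x arg_dim]).
    Rank-1 args are tiled to [batch_size x arg_dim]. Returns (args, is_batched).
    """
    args = [np.asarray(a) for a in args]
    # Stand-in for assert_rank_1_or_2 (assumed behaviour: reject other ranks).
    for a in args:
        if a.ndim not in (1, 2):
            raise ValueError("expected rank 1 or 2, got rank %d" % a.ndim)

    if set_batch_size is None:
        # Infer batch_size from the rank-2 args.
        batched_args = [a for a in args if a.ndim > 1]
        if not batched_args:
            batch_size, is_batched = 1, False
        else:
            batch_size, is_batched = batched_args[0].shape[0], True
            # Assumed extra check; the original code leaves this as a TODO.
            if any(a.shape[0] != batch_size for a in batched_args):
                raise ValueError("inconsistent batch sizes")
    else:
        batch_size, is_batched = set_batch_size, True

    out = []
    for a in args:
        if a.ndim == 1:
            # [arg_dim] -> [1 x arg_dim] -> [batch_size x arg_dim]
            out.append(np.tile(np.expand_dims(a, 0), (batch_size, 1)))
        else:
            out.append(a)
    return out, is_batched


a = np.array([1.0, 2.0, 3.0])  # rank 1, arg_dim = 3
b = np.zeros((4, 2))           # rank 2, batch_size = 4
out, is_batched = make_batch_consistent_np([a, b])
print(out[0].shape, is_batched)  # (4, 3) True
```

When no arg is rank 2 and no `set_batch_size` is given, rank-1 args come back as `[1 x arg_dim]` with `is_batched` set to `False`, matching the `batch_size = 1` branch of the TensorFlow version.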