def interpret_dict(a_dict, model, n_times=1, on_logits=True):
    """Convert a do_dict or cond_dict into arrays compatible with model.batch_size.

    The rules for converting arguments to numpy arrays to pass to tensorflow
    are identical for interventional (do) and conditional dictionaries, so
    either kind may be passed.

    Args:
        a_dict: mapping to convert, or None. The caller's dict is never
            modified.
        model: object whose ``batch_size`` attribute is used as an integer
            count.
        n_times: when greater than 1, a copy of ``a_dict`` gets an extra
            entry mapping a ``tf.placeholder_with_default`` (default 2.22,
            scalar) to -2.22 before the product is taken.
            NOTE(review): the purpose of this token is not evident from this
            function alone; confirm against ``take_product`` and its callers.
        on_logits: accepted for interface compatibility; not used here.

    Returns:
        ``{}`` when ``a_dict`` is None/empty or ``take_product`` yields an
        empty mapping. Otherwise it returns the ``take_product`` result.
        Let L be the length of its first value. If L < batch_size and L
        divides batch_size, every value is repeated ``batch_size // L``
        times along axis 0. If L is a multiple of batch_size, the values are
        returned as-is. Presumably all values share length L; only the first
        one is inspected (TODO confirm take_product guarantees this).

    Raises:
        ValueError: if the product is zero-length, or if L and batch_size do
            not divide one another.
    """
    if not a_dict:
        return {}

    if n_times > 1:
        # Copy so that adding the token does not mutate the caller's dict.
        a_dict = dict(a_dict)
        # placeholder_with_default requires an explicit shape; [] is a scalar.
        token = tf.placeholder_with_default(2.22, shape=[])
        a_dict[token] = -2.22

    p_a_dict = take_product(a_dict)
    if len(p_a_dict) == 0:
        # Nothing to broadcast; avoids dividing by a zero length below.
        return {}

    # next(iter(...)) works on Python 2 lists and Python 3 dict views alike.
    L = len(next(iter(p_a_dict.values())))
    if L == 0:
        raise ValueError('product of inputs has length 0; cannot fill batch_size')

    # Need a batch_size-compatible length for most models.
    batch_size = model.batch_size
    if L >= batch_size:
        if L % batch_size != 0:
            raise ValueError(
                'a_dict must be divisible by batch_size but instead '
                'product of inputs was of length {}'.format(L))
    elif batch_size % L == 0:
        # Integer division: np.repeat rejects float repeat counts.
        reps = batch_size // L
        p_a_dict = {key: np.repeat(value, reps, axis=0)
                    for key, value in p_a_dict.items()}
    else:
        raise ValueError('No. of intervened values must divide batch_size.')
    return p_a_dict
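# NOTE(review): two stray page labels ("comment list", "table of contents")
# followed this function -- likely web-scrape residue; kept here as a comment
# so the module still imports.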