def extract_params_as_shared_arrays(model):
    """Copy every weight array of *model* into its own ``multiprocessing.RawArray``.

    ``model.get_all_weights()`` must return a mapping whose values are
    iterables of array-like weights (anything with ``.ravel()``). Presumably
    the groups are e.g. shared / policy / value parameter lists -- TODO
    confirm against the model class.

    Each weight is flattened with ``.ravel()`` and stored in a fresh
    ``RawArray`` of C ``float`` (typecode ``'f'``, single precision). Only the
    flattened data is kept, not the original shape, so callers must reshape
    themselves. Float64 inputs are rounded to float32.

    Args:
        model: Object exposing ``get_all_weights()`` as described above.

    Returns:
        list: One ``RawArray`` per weight array. The order follows the
        mapping's iteration order, then the order within each group.
        ``RawArray`` objects carry no lock.
    """
    # Local import keeps this block self-contained; numpy is only needed for the
    # bulk copy into shared memory.
    import numpy as np

    # NOTE(review): ``mp`` is expected to be ``multiprocessing`` imported at
    # module level -- confirm.
    shared_arrays = []
    for group in model.get_all_weights().values():
        for weights in group:
            flat = weights.ravel()
            # Allocate a zero-filled buffer of the right length, then fill it in
            # one vectorized copy. Passing the array itself as the initializer
            # makes multiprocessing unpack it element by element via
            # ``__init__(*values)``, which is very slow for large weight tensors.
            raw = mp.RawArray('f', len(flat))
            if len(flat):
                # Assignment casts to float32, matching what c_float stores.
                np.frombuffer(raw, dtype=np.float32)[:] = flat
            shared_arrays.append(raw)
    return shared_arrays
评论列表
文章目录