def get_variable(name, shape, initializer=None, dtype=tf.float32, device=None):
    """Create (or fetch) a variable via ``tf.get_variable``, pinned to a device.

    The variable is placed on ``device``; when no device is given it is
    placed on the CPU ('/cpu:0').

    Args:
        name: name of the variable.
        shape: list of ints giving the variable's shape.
        initializer: initializer for the variable. If None, it is forwarded
            as None and ``tf.get_variable`` applies its own default
            initializer (see the ``tf.get_variable`` documentation).
        dtype: data type, defaults to tf.float32.
        device: device specification the variable is created under;
            defaults to '/cpu:0' when None.

    Returns:
        The variable returned by ``tf.get_variable``.
    """
    if device is None:
        device = '/cpu:0'
    # tf.get_variable's own default for `initializer` is None, so forwarding
    # it unconditionally is equivalent to omitting it when it is None. This
    # removes the need for two separate, otherwise identical call sites.
    with tf.device(device):
        var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype)
    return var
评论列表
文章目录