def tensors_to_item(self, keys_to_tensors):
    """Fetch the configured tensor from ``keys_to_tensors`` and shape it.

    The item is read from ``keys_to_tensors[self._tensor_key]``. The target
    shape is ``self._shape``. If ``self._shape_keys`` is non-empty, the target
    shape is instead built at run time from the tensors stored under those
    keys. Those tensors are stacked and flattened into a rank-1 shape vector.

    Args:
      keys_to_tensors: mapping indexed by key. Values are presumably dense
        tensors or ``tf.SparseTensor`` objects; sparse values get special
        handling below.

    Returns:
      The item as a dense tensor. A sparse item is first reshaped (when a
      target shape is known). It is then converted to dense, passing
      ``self._default_value`` as the fill value. A dense item is only
      reshaped (when a target shape is known).
    """
    def _as_dense(value):
        # Shape components may arrive sparse; densify them so they can be
        # stacked together with the other components.
        if isinstance(value, tf.SparseTensor):
            return tf.sparse_tensor_to_dense(value)
        return value

    item = keys_to_tensors[self._tensor_key]
    target_shape = self._shape
    if self._shape_keys:
        dims = [_as_dense(keys_to_tensors[key]) for key in self._shape_keys]
        # Flatten the stacked components into a single rank-1 shape vector.
        target_shape = tf.reshape(tf.stack(dims), [-1])

    if not isinstance(item, tf.SparseTensor):
        return item if target_shape is None else tf.reshape(item, target_shape)

    if target_shape is not None:
        item = tf.sparse_reshape(item, target_shape)
    return tf.sparse_tensor_to_dense(item, self._default_value)
# Scraped web-page residue, not part of the code: "评论列表" ("Comment list"),
# "文章目录" ("Table of contents").