def multi_target_data(name_list, shape, dtype=tf.float32):
    """ Multi Target Data.

    Create one placeholder per name in `name_list` and concatenate them.
    To be used when a regression layer uses targets from different sources.

    Each placeholder is created with `name='Y'` inside a name scope taken
    from `name_list`, and is added to the `tf.GraphKeys.TARGETS` collection.

    Arguments:
        name_list: list of `str`. The names of the target placeholders.
            Must contain at least one name.
        shape: list of `int`. The shape of the placeholders.
        dtype: `tf.type`, Placeholder data type (optional). Default: float32.

    Return:
        A `Tensor` of the placeholders concatenated along axis 0.

    Raises:
        ValueError: if `name_list` is empty.
    """
    placeholders = []
    for name in name_list:
        with tf.name_scope(name):
            p = tf.placeholder(shape=shape, dtype=dtype, name='Y')
        # `p` was created just above, so it cannot already be in the
        # collection; register it directly instead of scanning the
        # whole collection on every iteration.
        tf.add_to_collection(tf.GraphKeys.TARGETS, p)
        placeholders.append(p)

    if not placeholders:
        # tf.concat cannot concatenate zero tensors; fail early with a
        # clear message instead of an opaque op-construction error.
        raise ValueError(
            "multi_target_data: `name_list` must contain at least one name.")

    return tf.concat(placeholders, axis=0)
# Residue from the source web page (not code): "评论列表" (comment list),
# "文章目录" (table of contents).