def _build_output_graph(self, rep, t, dim_in, dim_out, do_out, FLAGS):
    """Construct the output/regression layers on top of the representation.

    Two layouts are supported, selected by ``FLAGS.split_output``:

    * split (truthy): rows where ``t < 1`` and rows where ``t > 0`` are
      gathered into two sub-batches. Each sub-batch goes through its own head
      built by ``self._build_output``. The two prediction tensors are then
      merged with ``tf.dynamic_stitch`` so each prediction lands back at its
      original row index.
    * joint (falsy): ``t`` is concatenated to ``rep`` along axis 1, and a
      single head with input width ``dim_in + 1`` is built.

    Args:
        rep: representation tensor fed to the output head(s).
        t: treatment tensor. NOTE(review): presumably a binary column of
            shape (N, 1). A value strictly between 0 and 1 would satisfy both
            masks; confirm callers only pass 0/1.
        dim_in: input width passed to ``self._build_output`` (``+ 1`` in the
            joint layout for the appended ``t`` column).
        dim_out, do_out, FLAGS: forwarded unchanged to ``self._build_output``.

    Returns:
        Tuple ``(y, weights_out, weights_pred)``. In the split layout the
        per-head weight collections are combined with ``+``, with the
        ``t < 1`` head first (presumably list concatenation; confirm against
        ``_build_output``).
    """
    if not FLAGS.split_output:
        # Joint head: the treatment becomes one extra input feature.
        # NOTE: tf.concat(axis, values) is the pre-1.0 TensorFlow argument
        # order; TF >= 1.0 expects tf.concat(values, axis).
        joint_input = tf.concat(1, [rep, t])
        y, weights_out, weights_pred = self._build_output(
            joint_input, dim_in + 1, dim_out, do_out, FLAGS)
        return y, weights_out, weights_pred

    # Row indices of each group. Column 0 of tf.where's coordinates is the row.
    control_rows = tf.to_int32(tf.where(t < 1)[:, 0])
    treated_rows = tf.to_int32(tf.where(t > 0)[:, 0])
    rep_control = tf.gather(rep, control_rows)
    rep_treated = tf.gather(rep, treated_rows)

    # One independent head per group, built in a fixed order (t < 1 first).
    y_control, out_control, pred_control = self._build_output(
        rep_control, dim_in, dim_out, do_out, FLAGS)
    y_treated, out_treated, pred_treated = self._build_output(
        rep_treated, dim_in, dim_out, do_out, FLAGS)

    # Scatter each head's predictions back to their original row positions.
    y = tf.dynamic_stitch([control_rows, treated_rows], [y_control, y_treated])
    return y, out_control + out_treated, pred_control + pred_treated
评论列表
文章目录