def _pad_or_truncate_rows(tensor, target_length):
    """Pads or slices `tensor` along its first axis to `target_length` rows.

    Args:
      tensor: A rank-2 tensor; only the first axis is resized.
      target_length: Scalar int32 tensor (or Python int) with the desired
        number of rows.

    Returns:
      A tensor with `target_length` rows: rows are appended at the end with
      `tf.pad` (default padding value) when `tensor` is too short, trailing
      rows are dropped when it is too long, and `tensor` is passed through
      unchanged when it already has the requested length.
    """
    delta = tf.shape(tensor)[0] - target_length
    return tf.case(
        [(delta < 0,
          lambda: tf.pad(tensor, tf.stack([(0, -delta), (0, 0)]))),
         (delta > 0, lambda: tensor[0:-delta])],
        default=lambda: tensor)


def _provide_data(input_tensors, truncated_length, hparams):
    """Returns tensors for reading batches from provider.

    Args:
      input_tensors: A 7-tuple unpacked as `(spec, labels, label_weights,
        length, onsets, filename, note_sequence)`.
      truncated_length: Upper bound on the output length. A falsy value
        (e.g. `None` or `0`) disables truncation and `length` is used as is.
        NOTE(review): this is tested with a plain `if`, so it is presumably a
        Python number rather than a tensor -- confirm against callers.
      hparams: Hyperparameters forwarded to `hparams_frame_size` and
        `truncate_note_sequence_op`.

    Returns:
      A dict with keys 'spec', 'labels', 'label_weights', 'lengths', 'onsets',
      'filenames' and 'note_sequences'. 'spec' has shape
      `(lengths, frame_size, 1)`; 'labels', 'label_weights' and 'onsets' have
      shape `(lengths, constants.MIDI_PITCHES)`.
    """
    (spec, labels, label_weights, length, onsets, filename,
     note_sequence) = input_tensors

    # tf.to_int32 is deprecated; tf.cast is the documented equivalent.
    length = tf.cast(length, tf.int32)
    num_pitches = constants.MIDI_PITCHES
    frame_size = hparams_frame_size(hparams)

    # Flatten every per-row input to rank 2 so it can be resized on axis 0.
    labels = tf.reshape(labels, (-1, num_pitches))
    label_weights = tf.reshape(label_weights, (-1, num_pitches))
    onsets = tf.reshape(onsets, (-1, num_pitches))
    spec = tf.reshape(spec, (-1, frame_size))

    truncated_length = (tf.reduce_min([truncated_length, length])
                        if truncated_length else length)

    # Pad or slice specs and labels tensors to have the same lengths,
    # truncating after truncated_length. Each tensor is resized from its own
    # row count, so a length mismatch between them cannot break the reshapes
    # below.
    spec = _pad_or_truncate_rows(spec, truncated_length)
    labels = _pad_or_truncate_rows(labels, truncated_length)
    label_weights = _pad_or_truncate_rows(label_weights, truncated_length)
    onsets = _pad_or_truncate_rows(onsets, truncated_length)

    truncated_note_sequence = truncate_note_sequence_op(
        note_sequence, truncated_length, hparams)

    batch_tensors = {
        'spec': tf.reshape(spec, (truncated_length, frame_size, 1)),
        'labels': tf.reshape(labels, (truncated_length, num_pitches)),
        'label_weights': tf.reshape(
            label_weights, (truncated_length, num_pitches)),
        'lengths': truncated_length,
        'onsets': tf.reshape(onsets, (truncated_length, num_pitches)),
        'filenames': filename,
        'note_sequences': truncated_note_sequence,
    }
    return batch_tensors
评论列表
文章目录