data.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:magenta 作者: tensorflow 项目源码 文件源码
def _pad_or_slice_to_length(tensor, length):
  """Zero-pads or truncates a rank-2 tensor along axis 0 to `length` rows.

  Args:
    tensor: A rank-2 tensor whose first dimension may be shorter than,
      longer than, or equal to `length`.
    length: Scalar integer (int32 tensor or Python int) giving the desired
      number of rows.

  Returns:
    `tensor` with zero rows appended at the end if it had fewer than
    `length` rows, its first `length` rows if it had more, or `tensor`
    unchanged otherwise.
  """
  delta = tf.shape(tensor)[0] - length
  return tf.case(
      [(delta < 0,
        lambda: tf.pad(tensor, tf.stack([(0, -delta), (0, 0)]))),
       (delta > 0, lambda: tensor[0:-delta])],
      default=lambda: tensor)


def _provide_data(input_tensors, truncated_length, hparams):
  """Returns tensors for reading batches from provider.

  Every per-frame tensor (spec, labels, label_weights, onsets) is zero-padded
  or truncated along its first axis to a common `truncated_length`, and the
  note sequence is truncated to the same length.

  Args:
    input_tensors: Sequence of seven tensors, unpacked in the order
      (spec, labels, label_weights, length, onsets, filename, note_sequence).
    truncated_length: Maximum number of frames to keep. If falsy (e.g. 0 or
      None), the example's full `length` is used instead.
    hparams: Hyperparameters; forwarded to `hparams_frame_size` and
      `truncate_note_sequence_op`.

  Returns:
    A dict with keys:
      'spec': reshaped to (truncated_length, frame_size, 1).
      'labels', 'label_weights', 'onsets': each reshaped to
        (truncated_length, constants.MIDI_PITCHES).
      'lengths': the scalar truncated length.
      'filenames': `filename`, passed through unchanged.
      'note_sequences': the result of `truncate_note_sequence_op`.
  """
  (spec, labels, label_weights, length, onsets, filename,
   note_sequence) = input_tensors

  # tf.cast is the non-deprecated equivalent of tf.to_int32 (also in TF2).
  length = tf.cast(length, tf.int32)
  frame_size = hparams_frame_size(hparams)
  labels = tf.reshape(labels, (-1, constants.MIDI_PITCHES))
  label_weights = tf.reshape(label_weights, (-1, constants.MIDI_PITCHES))
  onsets = tf.reshape(onsets, (-1, constants.MIDI_PITCHES))
  spec = tf.reshape(spec, (-1, frame_size))

  # A falsy truncated_length means "no truncation"; otherwise never keep more
  # frames than the example actually has.
  truncated_length = (tf.reduce_min([truncated_length, length])
                      if truncated_length else length)

  # Pad or slice each tensor to truncated_length based on its OWN row count,
  # so the reshapes to (truncated_length, ...) below are always valid.
  spec = _pad_or_slice_to_length(spec, truncated_length)
  labels = _pad_or_slice_to_length(labels, truncated_length)
  label_weights = _pad_or_slice_to_length(label_weights, truncated_length)
  onsets = _pad_or_slice_to_length(onsets, truncated_length)

  truncated_note_sequence = truncate_note_sequence_op(
      note_sequence, truncated_length, hparams)

  batch_tensors = {
      'spec': tf.reshape(spec, (truncated_length, frame_size, 1)),
      'labels': tf.reshape(labels, (truncated_length, constants.MIDI_PITCHES)),
      'label_weights': tf.reshape(
          label_weights, (truncated_length, constants.MIDI_PITCHES)),
      'lengths': truncated_length,
      'onsets': tf.reshape(onsets, (truncated_length, constants.MIDI_PITCHES)),
      'filenames': filename,
      'note_sequences': truncated_note_sequence,
  }

  return batch_tensors
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号