Module note_seq.midi_io_test

Test to ensure correct midi input and output.

Expand source code
# Copyright 2021 The Magenta Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test to ensure correct midi input and output."""
import collections
import os.path
import tempfile

from absl.testing import absltest
import mido
from note_seq import constants
from note_seq import midi_io
from note_seq import testing_lib
from note_seq.protobuf import music_pb2
import pretty_midi

# self.midi_simple_filename contains a c-major scale of 8 quarter notes each
# with a sustain of .95 of the entire note. Here are the first two notes dumped
# using mididump.py:
#   midi.NoteOnEvent(tick=0, channel=0, data=[60, 100]),
#   midi.NoteOnEvent(tick=209, channel=0, data=[60, 0]),
#   midi.NoteOnEvent(tick=11, channel=0, data=[62, 100]),
#   midi.NoteOnEvent(tick=209, channel=0, data=[62, 0]),
# In the dump, each note sounds for 209 ticks and is followed by 11 ticks of
# silence, i.e. 209 / (209 + 11) = .95 of the note's slot.
_SIMPLE_MIDI_FILE_VELO = 100  # Note-on velocity shown in the dump above.
_SIMPLE_MIDI_FILE_NUM_NOTES = 8  # Number of notes in the scale.
_SIMPLE_MIDI_FILE_SUSTAIN = .95  # Sounding fraction of each note's slot.

# self.midi_complex_filename contains many instruments including percussion as
# well as control change and pitch bend events.

# self.midi_is_drum_filename contains 41 tracks, two of which are on channel 9.

# self.midi_event_order_filename contains notes ordered
# non-monotonically by pitch.  Here are relevant events as printed by
# mididump.py:
#   midi.NoteOnEvent(tick=0, channel=0, data=[1, 100]),
#   midi.NoteOnEvent(tick=0, channel=0, data=[3, 100]),
#   midi.NoteOnEvent(tick=0, channel=0, data=[2, 100]),
#   midi.NoteOnEvent(tick=4400, channel=0, data=[3, 0]),
#   midi.NoteOnEvent(tick=0, channel=0, data=[1, 0]),
#   midi.NoteOnEvent(tick=0, channel=0, data=[2, 0]),


class MidiIoTest(absltest.TestCase):
  """Tests translating between MIDI data and NoteSequence protos."""

  def setUp(self):
    """Resolves the paths of the MIDI fixtures used by the tests."""
    super().setUp()
    testdata_dir = testing_lib.get_testdata_dir()
    self.midi_simple_filename = os.path.join(testdata_dir, 'example.mid')
    self.midi_complex_filename = os.path.join(
        testdata_dir, 'example_complex.mid')
    self.midi_is_drum_filename = os.path.join(
        testdata_dir, 'example_is_drum.mid')
    self.midi_event_order_filename = os.path.join(
        testdata_dir, 'example_event_order.mid')

  def CheckPrettyMidiAndSequence(self, midi, sequence_proto):
    """Asserts that a PrettyMIDI object and a NoteSequence proto agree.

    Time signatures, key signatures, tempos, and the notes and pitch bends of
    each instrument are compared. Control changes only take part in deciding
    how many instruments are expected; their values are not compared.

    Args:
      midi: A pretty_midi.PrettyMIDI object.
      sequence_proto: A NoteSequence proto.
    """
    # Time signature changes.
    proto_time_signatures = sequence_proto.time_signatures
    self.assertLen(midi.time_signature_changes, len(proto_time_signatures))
    for midi_signature, proto_signature in zip(midi.time_signature_changes,
                                               proto_time_signatures):
      self.assertEqual(midi_signature.numerator, proto_signature.numerator)
      self.assertEqual(midi_signature.denominator,
                       proto_signature.denominator)
      self.assertAlmostEqual(midi_signature.time, proto_signature.time)

    # Key signature changes. The PrettyMIDI key number is split into the
    # proto's key (remainder modulo 12) and mode (quotient).
    proto_key_signatures = sequence_proto.key_signatures
    self.assertLen(midi.key_signature_changes, len(proto_key_signatures))
    for midi_signature, proto_signature in zip(midi.key_signature_changes,
                                               proto_key_signatures):
      mode, key = divmod(midi_signature.key_number, 12)
      self.assertEqual(key, proto_signature.key)
      self.assertEqual(mode, proto_signature.mode)
      self.assertAlmostEqual(midi_signature.time, proto_signature.time)

    # Tempo changes.
    tempo_times, tempo_qpms = midi.get_tempo_changes()
    proto_tempos = sequence_proto.tempos
    self.assertLen(tempo_times, len(proto_tempos))
    self.assertLen(tempo_qpms, len(proto_tempos))
    for tempo_time, tempo_qpm, proto_tempo in zip(
        tempo_times, tempo_qpms, proto_tempos):
      self.assertAlmostEqual(tempo_qpm, proto_tempo.qpm)
      self.assertAlmostEqual(tempo_time, proto_tempo.time)

    # Group the proto's events per (instrument, program, is_drum) so they can
    # be lined up with the PrettyMIDI instruments.
    events_by_instrument = collections.defaultdict(
        lambda: collections.defaultdict(list))
    for event_kind, proto_events in (
        ('notes', sequence_proto.notes),
        ('bends', sequence_proto.pitch_bends),
        ('controls', sequence_proto.control_changes)):
      for proto_event in proto_events:
        instrument_key = (
            proto_event.instrument, proto_event.program, proto_event.is_drum)
        events_by_instrument[instrument_key][event_kind].append(proto_event)

    if not events_by_instrument:
      # With no events in the proto, a single empty instrument is expected.
      self.assertLen(midi.instruments, 1)
      only_instrument = midi.instruments[0]
      self.assertEmpty(only_instrument.notes)
      self.assertEmpty(only_instrument.pitch_bends)
    else:
      self.assertLen(midi.instruments, len(events_by_instrument))

    # PrettyMIDI instruments are paired with the instrument keys in sorted
    # order.
    for midi_instrument, instrument_key in zip(
        midi.instruments, sorted(events_by_instrument)):
      expected_events = events_by_instrument[instrument_key]

      expected_notes = expected_events['notes']
      self.assertLen(midi_instrument.notes, len(expected_notes))
      for midi_note, proto_note in zip(midi_instrument.notes, expected_notes):
        self.assertEqual(midi_note.pitch, proto_note.pitch)
        self.assertEqual(midi_note.velocity, proto_note.velocity)
        self.assertAlmostEqual(midi_note.start, proto_note.start_time)
        self.assertAlmostEqual(midi_note.end, proto_note.end_time)

      expected_bends = expected_events['bends']
      self.assertLen(midi_instrument.pitch_bends, len(expected_bends))
      for midi_bend, proto_bend in zip(midi_instrument.pitch_bends,
                                       expected_bends):
        self.assertEqual(midi_bend.pitch, proto_bend.bend)
        self.assertAlmostEqual(midi_bend.time, proto_bend.time)

  def CheckMidiToSequence(self, filename):
    """Loads a MIDI file and checks its NoteSequence translation against it."""
    loaded_midi = pretty_midi.PrettyMIDI(filename)
    self.CheckPrettyMidiAndSequence(
        loaded_midi, midi_io.midi_to_sequence_proto(loaded_midi))

  def CheckSequenceToPrettyMidi(self, filename):
    """Checks a NoteSequence against the PrettyMIDI object built from it."""
    note_sequence = midi_io.midi_to_sequence_proto(
        pretty_midi.PrettyMIDI(filename))
    self.CheckPrettyMidiAndSequence(
        midi_io.sequence_proto_to_pretty_midi(note_sequence), note_sequence)

  def CheckReadWriteMidi(self, filename):
    """Writes a NoteSequence to a MIDI file and checks the file against it."""

    # TODO(deck): pretty-midi only sanitizes MIDI data (reordering note ons,
    # etc.) when writing a file, so the input is first re-written to a temp
    # file and loaded again. Issue 85 in the pretty-midi GitHub
    # (http://github.com/craffel/pretty-midi/issues/85) asks for that
    # sanitization to be available without a file write. If that is
    # implemented, this rewrite code should be modified or deleted.

    # pretty-midi is handed the open file object rather than its name to
    # avoid the permission error on Windows.
    with tempfile.NamedTemporaryFile(prefix='MidiIoTest') as sanitized_file:
      pretty_midi.PrettyMIDI(filename).write(sanitized_file)
      # Rewind so that the data just written can be loaded again.
      sanitized_file.seek(0)
      sanitized_midi = pretty_midi.PrettyMIDI(sanitized_file)
      note_sequence = midi_io.midi_to_sequence_proto(sanitized_midi)

    # Write the NoteSequence out as MIDI, then load the result for comparison.
    with tempfile.NamedTemporaryFile(prefix='MidiIoTest') as output_file:
      midi_io.sequence_proto_to_midi_file(note_sequence, output_file.name)
      written_midi = pretty_midi.PrettyMIDI(output_file)

    self.CheckPrettyMidiAndSequence(written_midi, note_sequence)

  def testSimplePrettyMidiToSequence(self):
    # MIDI -> NoteSequence on the simple fixture.
    self.CheckMidiToSequence(self.midi_simple_filename)

  def testSimpleSequenceToPrettyMidi(self):
    # NoteSequence -> PrettyMIDI on the simple fixture.
    self.CheckSequenceToPrettyMidi(self.midi_simple_filename)

  def testSimpleSequenceToPrettyMidi_DefaultTicksAndTempo(self):
    # A proto without tempos or ticks_per_quarter is expected to translate as
    # if it carried the default values for both.
    simple_midi = pretty_midi.PrettyMIDI(self.midi_simple_filename)
    stripped_sequence = midi_io.midi_to_sequence_proto(simple_midi)
    stripped_sequence.ClearField('ticks_per_quarter')
    del stripped_sequence.tempos[:]

    expected_sequence = music_pb2.NoteSequence()
    expected_sequence.CopyFrom(stripped_sequence)
    expected_sequence.ticks_per_quarter = constants.STANDARD_PPQ
    expected_sequence.tempos.add(qpm=constants.DEFAULT_QUARTERS_PER_MINUTE)

    self.CheckPrettyMidiAndSequence(
        midi_io.sequence_proto_to_pretty_midi(stripped_sequence),
        expected_sequence)

  def testSimpleSequenceToPrettyMidi_MultipleTempos(self):
    # Extra tempo changes added to the proto must show up in the MIDI.
    note_sequence = midi_io.midi_to_sequence_proto(
        pretty_midi.PrettyMIDI(self.midi_simple_filename))
    for tempo_time, tempo_qpm in ((1.0, 60), (2.0, 120)):
      note_sequence.tempos.add(time=tempo_time, qpm=tempo_qpm)

    self.CheckPrettyMidiAndSequence(
        midi_io.sequence_proto_to_pretty_midi(note_sequence), note_sequence)

  def testSimpleSequenceToPrettyMidi_FirstTempoNotAtZero(self):
    late_tempos = ((1.0, 60), (2.0, 120))
    note_sequence = midi_io.midi_to_sequence_proto(
        pretty_midi.PrettyMIDI(self.midi_simple_filename))
    del note_sequence.tempos[:]
    for tempo_time, tempo_qpm in late_tempos:
      note_sequence.tempos.add(time=tempo_time, qpm=tempo_qpm)

    translated_midi = midi_io.sequence_proto_to_pretty_midi(note_sequence)

    # Translating to MIDI adds an implicit DEFAULT_QUARTERS_PER_MINUTE tempo
    # at time 0, so the expected tempo list is rebuilt with it in front.
    implicit_first_tempo = (0.0, constants.DEFAULT_QUARTERS_PER_MINUTE)
    del note_sequence.tempos[:]
    for tempo_time, tempo_qpm in (implicit_first_tempo,) + late_tempos:
      note_sequence.tempos.add(time=tempo_time, qpm=tempo_qpm)

    self.CheckPrettyMidiAndSequence(translated_midi, note_sequence)

  def testSimpleSequenceToPrettyMidi_DropEventsAfterLastNote(self):
    note_sequence = midi_io.midi_to_sequence_proto(
        pretty_midi.PrettyMIDI(self.midi_simple_filename))
    # A tempo change placed long after the last note.
    note_sequence.tempos.add(time=600.0, qpm=120)

    # Without dropping, the late tempo is kept.
    self.CheckPrettyMidiAndSequence(
        midi_io.sequence_proto_to_pretty_midi(note_sequence), note_sequence)

    # Dropping everything after the last note removes the late tempo, so it
    # is removed from the expectation as well.
    trimmed_midi = midi_io.sequence_proto_to_pretty_midi(
        note_sequence, drop_events_n_seconds_after_last_note=0)
    del note_sequence.tempos[-1]
    self.CheckPrettyMidiAndSequence(trimmed_midi, note_sequence)

    # A tempo 15 seconds after the last note lies inside a 30 second window
    # and therefore has to be preserved.
    final_note_end = max(note.end_time for note in note_sequence.notes)
    note_sequence.tempos.add(time=final_note_end + 15, qpm=120)
    windowed_midi = midi_io.sequence_proto_to_pretty_midi(
        note_sequence, drop_events_n_seconds_after_last_note=30)
    self.CheckPrettyMidiAndSequence(windowed_midi, note_sequence)

  def testEmptySequenceToPrettyMidi_DropEventsAfterLastNote(self):
    empty_sequence = music_pb2.NoteSequence()

    # Translate once without dropping and once dropping anything after 30
    # seconds; both must give a single instrument without notes.
    for translate_kwargs in (
        {}, {'drop_events_n_seconds_after_last_note': 30}):
      translated_midi = midi_io.sequence_proto_to_pretty_midi(
          empty_sequence, **translate_kwargs)
      self.assertLen(translated_midi.instruments, 1)
      self.assertEmpty(translated_midi.instruments[0].notes)

  def testNonEmptySequenceWithNoNotesToPrettyMidi_DropEventsAfterLastNote(self):
    tempo_only_sequence = music_pb2.NoteSequence()
    for tempo_time, tempo_qpm in ((0, 120), (10, 160), (40, 240)):
      tempo_only_sequence.tempos.add(time=tempo_time, qpm=tempo_qpm)

    # Without dropping, every tempo is kept.
    self.CheckPrettyMidiAndSequence(
        midi_io.sequence_proto_to_pretty_midi(tempo_only_sequence),
        tempo_only_sequence)

    # Dropping anything after 30 seconds removes the tempo at 40 seconds.
    trimmed_midi = midi_io.sequence_proto_to_pretty_midi(
        tempo_only_sequence, drop_events_n_seconds_after_last_note=30)
    del tempo_only_sequence.tempos[-1]
    self.CheckPrettyMidiAndSequence(trimmed_midi, tempo_only_sequence)

  def testSimpleReadWriteMidi(self):
    # File round trip on the simple fixture.
    self.CheckReadWriteMidi(self.midi_simple_filename)

  def testComplexPrettyMidiToSequence(self):
    # MIDI -> NoteSequence on the complex fixture.
    self.CheckMidiToSequence(self.midi_complex_filename)

  def testComplexSequenceToPrettyMidi(self):
    # NoteSequence -> PrettyMIDI on the complex fixture.
    self.CheckSequenceToPrettyMidi(self.midi_complex_filename)

  def testIsDrumDetection(self):
    """Verify that is_drum instruments are properly tracked.

    The fixture at self.midi_is_drum_filename has two tracks on channel 9
    (is_drum == True), each holding one NoteOn. The test guards against a
    bug where the second track would lose is_drum and be remapped from a
    drum track to an instrument track.
    """
    note_sequence = midi_io.midi_file_to_sequence_proto(
        self.midi_is_drum_filename)
    with tempfile.NamedTemporaryFile(prefix='MidiDrumTest') as output_file:
      midi_io.sequence_proto_to_midi_file(note_sequence, output_file.name)
      original_file = mido.MidiFile(filename=self.midi_is_drum_filename)
      # The temp file is read back through its open file object to avoid a
      # permission error.
      written_file = mido.MidiFile(file=output_file)

    # Number of channel 9 Note Ons (with non-zero velocity) per file.
    drum_note_on_counts = [
        sum(1 for event in midi_file
            if (event.type == 'note_on' and
                event.velocity > 0 and event.channel == 9))
        for midi_file in (original_file, written_file)]
    self.assertEqual(drum_note_on_counts, [2, 2])

  def testInstrumentInfo_NoteSequenceToPrettyMidi(self):
    # Instrument names must survive NoteSequence -> PrettyMIDI -> NoteSequence.
    source_sequence = music_pb2.NoteSequence()
    for number, name, start_time, end_time in (
        (0, 'inst_0', 0.0, 0.5), (1, 'inst_1', 0.5, 1.0)):
      source_sequence.notes.add(
          pitch=60, start_time=start_time, end_time=end_time, velocity=80,
          instrument=number)
      source_sequence.instrument_infos.add(name=name, instrument=number)

    round_tripped_sequence = midi_io.midi_to_note_sequence(
        midi_io.sequence_proto_to_pretty_midi(source_sequence))

    self.assertLen(round_tripped_sequence.instrument_infos,
                   len(source_sequence.instrument_infos))
    for source_info, round_tripped_info in zip(
        source_sequence.instrument_infos,
        round_tripped_sequence.instrument_infos):
      self.assertEqual(source_info.name, round_tripped_info.name)

  def testComplexReadWriteMidi(self):
    # File round trip on the complex fixture.
    self.CheckReadWriteMidi(self.midi_complex_filename)

  def testEventOrdering(self):
    # File round trip on the fixture whose notes are not ordered by pitch.
    self.CheckReadWriteMidi(self.midi_event_order_filename)


# Hand control to absltest's test runner when the module is run as a script.
if __name__ == '__main__':
  absltest.main()

Classes

class MidiIoTest (*args, **kwargs)

Extension of unittest.TestCase providing more power.

Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.

Expand source code
class MidiIoTest(absltest.TestCase):
  """Tests translating between MIDI data and NoteSequence protos."""

  def setUp(self):
    """Resolves the paths of the MIDI fixtures used by the tests."""
    super().setUp()
    testdata_dir = testing_lib.get_testdata_dir()
    self.midi_simple_filename = os.path.join(testdata_dir, 'example.mid')
    self.midi_complex_filename = os.path.join(
        testdata_dir, 'example_complex.mid')
    self.midi_is_drum_filename = os.path.join(
        testdata_dir, 'example_is_drum.mid')
    self.midi_event_order_filename = os.path.join(
        testdata_dir, 'example_event_order.mid')

  def CheckPrettyMidiAndSequence(self, midi, sequence_proto):
    """Asserts that a PrettyMIDI object and a NoteSequence proto agree.

    Time signatures, key signatures, tempos, and the notes and pitch bends of
    each instrument are compared. Control changes only take part in deciding
    how many instruments are expected; their values are not compared.

    Args:
      midi: A pretty_midi.PrettyMIDI object.
      sequence_proto: A NoteSequence proto.
    """
    # Time signature changes.
    proto_time_signatures = sequence_proto.time_signatures
    self.assertLen(midi.time_signature_changes, len(proto_time_signatures))
    for midi_signature, proto_signature in zip(midi.time_signature_changes,
                                               proto_time_signatures):
      self.assertEqual(midi_signature.numerator, proto_signature.numerator)
      self.assertEqual(midi_signature.denominator,
                       proto_signature.denominator)
      self.assertAlmostEqual(midi_signature.time, proto_signature.time)

    # Key signature changes. The PrettyMIDI key number is split into the
    # proto's key (remainder modulo 12) and mode (quotient).
    proto_key_signatures = sequence_proto.key_signatures
    self.assertLen(midi.key_signature_changes, len(proto_key_signatures))
    for midi_signature, proto_signature in zip(midi.key_signature_changes,
                                               proto_key_signatures):
      mode, key = divmod(midi_signature.key_number, 12)
      self.assertEqual(key, proto_signature.key)
      self.assertEqual(mode, proto_signature.mode)
      self.assertAlmostEqual(midi_signature.time, proto_signature.time)

    # Tempo changes.
    tempo_times, tempo_qpms = midi.get_tempo_changes()
    proto_tempos = sequence_proto.tempos
    self.assertLen(tempo_times, len(proto_tempos))
    self.assertLen(tempo_qpms, len(proto_tempos))
    for tempo_time, tempo_qpm, proto_tempo in zip(
        tempo_times, tempo_qpms, proto_tempos):
      self.assertAlmostEqual(tempo_qpm, proto_tempo.qpm)
      self.assertAlmostEqual(tempo_time, proto_tempo.time)

    # Group the proto's events per (instrument, program, is_drum) so they can
    # be lined up with the PrettyMIDI instruments.
    events_by_instrument = collections.defaultdict(
        lambda: collections.defaultdict(list))
    for event_kind, proto_events in (
        ('notes', sequence_proto.notes),
        ('bends', sequence_proto.pitch_bends),
        ('controls', sequence_proto.control_changes)):
      for proto_event in proto_events:
        instrument_key = (
            proto_event.instrument, proto_event.program, proto_event.is_drum)
        events_by_instrument[instrument_key][event_kind].append(proto_event)

    if not events_by_instrument:
      # With no events in the proto, a single empty instrument is expected.
      self.assertLen(midi.instruments, 1)
      only_instrument = midi.instruments[0]
      self.assertEmpty(only_instrument.notes)
      self.assertEmpty(only_instrument.pitch_bends)
    else:
      self.assertLen(midi.instruments, len(events_by_instrument))

    # PrettyMIDI instruments are paired with the instrument keys in sorted
    # order.
    for midi_instrument, instrument_key in zip(
        midi.instruments, sorted(events_by_instrument)):
      expected_events = events_by_instrument[instrument_key]

      expected_notes = expected_events['notes']
      self.assertLen(midi_instrument.notes, len(expected_notes))
      for midi_note, proto_note in zip(midi_instrument.notes, expected_notes):
        self.assertEqual(midi_note.pitch, proto_note.pitch)
        self.assertEqual(midi_note.velocity, proto_note.velocity)
        self.assertAlmostEqual(midi_note.start, proto_note.start_time)
        self.assertAlmostEqual(midi_note.end, proto_note.end_time)

      expected_bends = expected_events['bends']
      self.assertLen(midi_instrument.pitch_bends, len(expected_bends))
      for midi_bend, proto_bend in zip(midi_instrument.pitch_bends,
                                       expected_bends):
        self.assertEqual(midi_bend.pitch, proto_bend.bend)
        self.assertAlmostEqual(midi_bend.time, proto_bend.time)

  def CheckMidiToSequence(self, filename):
    """Loads a MIDI file and checks its NoteSequence translation against it."""
    loaded_midi = pretty_midi.PrettyMIDI(filename)
    self.CheckPrettyMidiAndSequence(
        loaded_midi, midi_io.midi_to_sequence_proto(loaded_midi))

  def CheckSequenceToPrettyMidi(self, filename):
    """Checks a NoteSequence against the PrettyMIDI object built from it."""
    note_sequence = midi_io.midi_to_sequence_proto(
        pretty_midi.PrettyMIDI(filename))
    self.CheckPrettyMidiAndSequence(
        midi_io.sequence_proto_to_pretty_midi(note_sequence), note_sequence)

  def CheckReadWriteMidi(self, filename):
    """Writes a NoteSequence to a MIDI file and checks the file against it."""

    # TODO(deck): pretty-midi only sanitizes MIDI data (reordering note ons,
    # etc.) when writing a file, so the input is first re-written to a temp
    # file and loaded again. Issue 85 in the pretty-midi GitHub
    # (http://github.com/craffel/pretty-midi/issues/85) asks for that
    # sanitization to be available without a file write. If that is
    # implemented, this rewrite code should be modified or deleted.

    # pretty-midi is handed the open file object rather than its name to
    # avoid the permission error on Windows.
    with tempfile.NamedTemporaryFile(prefix='MidiIoTest') as sanitized_file:
      pretty_midi.PrettyMIDI(filename).write(sanitized_file)
      # Rewind so that the data just written can be loaded again.
      sanitized_file.seek(0)
      sanitized_midi = pretty_midi.PrettyMIDI(sanitized_file)
      note_sequence = midi_io.midi_to_sequence_proto(sanitized_midi)

    # Write the NoteSequence out as MIDI, then load the result for comparison.
    with tempfile.NamedTemporaryFile(prefix='MidiIoTest') as output_file:
      midi_io.sequence_proto_to_midi_file(note_sequence, output_file.name)
      written_midi = pretty_midi.PrettyMIDI(output_file)

    self.CheckPrettyMidiAndSequence(written_midi, note_sequence)

  def testSimplePrettyMidiToSequence(self):
    # MIDI -> NoteSequence on the simple fixture.
    self.CheckMidiToSequence(self.midi_simple_filename)

  def testSimpleSequenceToPrettyMidi(self):
    # NoteSequence -> PrettyMIDI on the simple fixture.
    self.CheckSequenceToPrettyMidi(self.midi_simple_filename)

  def testSimpleSequenceToPrettyMidi_DefaultTicksAndTempo(self):
    # A proto without tempos or ticks_per_quarter is expected to translate as
    # if it carried the default values for both.
    simple_midi = pretty_midi.PrettyMIDI(self.midi_simple_filename)
    stripped_sequence = midi_io.midi_to_sequence_proto(simple_midi)
    stripped_sequence.ClearField('ticks_per_quarter')
    del stripped_sequence.tempos[:]

    expected_sequence = music_pb2.NoteSequence()
    expected_sequence.CopyFrom(stripped_sequence)
    expected_sequence.ticks_per_quarter = constants.STANDARD_PPQ
    expected_sequence.tempos.add(qpm=constants.DEFAULT_QUARTERS_PER_MINUTE)

    self.CheckPrettyMidiAndSequence(
        midi_io.sequence_proto_to_pretty_midi(stripped_sequence),
        expected_sequence)

  def testSimpleSequenceToPrettyMidi_MultipleTempos(self):
    # Extra tempo changes added to the proto must show up in the MIDI.
    note_sequence = midi_io.midi_to_sequence_proto(
        pretty_midi.PrettyMIDI(self.midi_simple_filename))
    for tempo_time, tempo_qpm in ((1.0, 60), (2.0, 120)):
      note_sequence.tempos.add(time=tempo_time, qpm=tempo_qpm)

    self.CheckPrettyMidiAndSequence(
        midi_io.sequence_proto_to_pretty_midi(note_sequence), note_sequence)

  def testSimpleSequenceToPrettyMidi_FirstTempoNotAtZero(self):
    late_tempos = ((1.0, 60), (2.0, 120))
    note_sequence = midi_io.midi_to_sequence_proto(
        pretty_midi.PrettyMIDI(self.midi_simple_filename))
    del note_sequence.tempos[:]
    for tempo_time, tempo_qpm in late_tempos:
      note_sequence.tempos.add(time=tempo_time, qpm=tempo_qpm)

    translated_midi = midi_io.sequence_proto_to_pretty_midi(note_sequence)

    # Translating to MIDI adds an implicit DEFAULT_QUARTERS_PER_MINUTE tempo
    # at time 0, so the expected tempo list is rebuilt with it in front.
    implicit_first_tempo = (0.0, constants.DEFAULT_QUARTERS_PER_MINUTE)
    del note_sequence.tempos[:]
    for tempo_time, tempo_qpm in (implicit_first_tempo,) + late_tempos:
      note_sequence.tempos.add(time=tempo_time, qpm=tempo_qpm)

    self.CheckPrettyMidiAndSequence(translated_midi, note_sequence)

  def testSimpleSequenceToPrettyMidi_DropEventsAfterLastNote(self):
    note_sequence = midi_io.midi_to_sequence_proto(
        pretty_midi.PrettyMIDI(self.midi_simple_filename))
    # A tempo change placed long after the last note.
    note_sequence.tempos.add(time=600.0, qpm=120)

    # Without dropping, the late tempo is kept.
    self.CheckPrettyMidiAndSequence(
        midi_io.sequence_proto_to_pretty_midi(note_sequence), note_sequence)

    # Dropping everything after the last note removes the late tempo, so it
    # is removed from the expectation as well.
    trimmed_midi = midi_io.sequence_proto_to_pretty_midi(
        note_sequence, drop_events_n_seconds_after_last_note=0)
    del note_sequence.tempos[-1]
    self.CheckPrettyMidiAndSequence(trimmed_midi, note_sequence)

    # A tempo 15 seconds after the last note lies inside a 30 second window
    # and therefore has to be preserved.
    final_note_end = max(note.end_time for note in note_sequence.notes)
    note_sequence.tempos.add(time=final_note_end + 15, qpm=120)
    windowed_midi = midi_io.sequence_proto_to_pretty_midi(
        note_sequence, drop_events_n_seconds_after_last_note=30)
    self.CheckPrettyMidiAndSequence(windowed_midi, note_sequence)

  def testEmptySequenceToPrettyMidi_DropEventsAfterLastNote(self):
    empty_sequence = music_pb2.NoteSequence()

    # Translate once without dropping and once dropping anything after 30
    # seconds; both must give a single instrument without notes.
    for translate_kwargs in (
        {}, {'drop_events_n_seconds_after_last_note': 30}):
      translated_midi = midi_io.sequence_proto_to_pretty_midi(
          empty_sequence, **translate_kwargs)
      self.assertLen(translated_midi.instruments, 1)
      self.assertEmpty(translated_midi.instruments[0].notes)

  def testNonEmptySequenceWithNoNotesToPrettyMidi_DropEventsAfterLastNote(self):
    tempo_only_sequence = music_pb2.NoteSequence()
    for tempo_time, tempo_qpm in ((0, 120), (10, 160), (40, 240)):
      tempo_only_sequence.tempos.add(time=tempo_time, qpm=tempo_qpm)

    # Without dropping, every tempo is kept.
    self.CheckPrettyMidiAndSequence(
        midi_io.sequence_proto_to_pretty_midi(tempo_only_sequence),
        tempo_only_sequence)

    # Dropping anything after 30 seconds removes the tempo at 40 seconds.
    trimmed_midi = midi_io.sequence_proto_to_pretty_midi(
        tempo_only_sequence, drop_events_n_seconds_after_last_note=30)
    del tempo_only_sequence.tempos[-1]
    self.CheckPrettyMidiAndSequence(trimmed_midi, tempo_only_sequence)

  def testSimpleReadWriteMidi(self):
    # File round trip on the simple fixture.
    self.CheckReadWriteMidi(self.midi_simple_filename)

  def testComplexPrettyMidiToSequence(self):
    # MIDI -> NoteSequence on the complex fixture.
    self.CheckMidiToSequence(self.midi_complex_filename)

  def testComplexSequenceToPrettyMidi(self):
    # NoteSequence -> PrettyMIDI on the complex fixture.
    self.CheckSequenceToPrettyMidi(self.midi_complex_filename)

  def testIsDrumDetection(self):
    """Verify that is_drum instruments are properly tracked.

    The fixture at self.midi_is_drum_filename has two tracks on channel 9
    (is_drum == True), each holding one NoteOn. The test guards against a
    bug where the second track would lose is_drum and be remapped from a
    drum track to an instrument track.
    """
    note_sequence = midi_io.midi_file_to_sequence_proto(
        self.midi_is_drum_filename)
    with tempfile.NamedTemporaryFile(prefix='MidiDrumTest') as output_file:
      midi_io.sequence_proto_to_midi_file(note_sequence, output_file.name)
      original_file = mido.MidiFile(filename=self.midi_is_drum_filename)
      # The temp file is read back through its open file object to avoid a
      # permission error.
      written_file = mido.MidiFile(file=output_file)

    # Number of channel 9 Note Ons (with non-zero velocity) per file.
    drum_note_on_counts = [
        sum(1 for event in midi_file
            if (event.type == 'note_on' and
                event.velocity > 0 and event.channel == 9))
        for midi_file in (original_file, written_file)]
    self.assertEqual(drum_note_on_counts, [2, 2])

  def testInstrumentInfo_NoteSequenceToPrettyMidi(self):
    # Instrument names must survive NoteSequence -> PrettyMIDI -> NoteSequence.
    source_sequence = music_pb2.NoteSequence()
    for number, name, start_time, end_time in (
        (0, 'inst_0', 0.0, 0.5), (1, 'inst_1', 0.5, 1.0)):
      source_sequence.notes.add(
          pitch=60, start_time=start_time, end_time=end_time, velocity=80,
          instrument=number)
      source_sequence.instrument_infos.add(name=name, instrument=number)

    round_tripped_sequence = midi_io.midi_to_note_sequence(
        midi_io.sequence_proto_to_pretty_midi(source_sequence))

    self.assertLen(round_tripped_sequence.instrument_infos,
                   len(source_sequence.instrument_infos))
    for source_info, round_tripped_info in zip(
        source_sequence.instrument_infos,
        round_tripped_sequence.instrument_infos):
      self.assertEqual(source_info.name, round_tripped_info.name)

  def testComplexReadWriteMidi(self):
    # File round trip on the complex fixture.
    self.CheckReadWriteMidi(self.midi_complex_filename)

  def testEventOrdering(self):
    # File round trip on the fixture whose notes are not ordered by pitch.
    self.CheckReadWriteMidi(self.midi_event_order_filename)

Ancestors

  • absl.testing.absltest.TestCase
  • absl.third_party.unittest3_backport.case.TestCase
  • unittest.case.TestCase

Methods

def CheckMidiToSequence(self, filename)

Test the translation from PrettyMIDI to Sequence proto.

Expand source code
def CheckMidiToSequence(self, filename):
  """Loads a MIDI file, converts it to a proto, and cross-checks the two.

  Args:
    filename: Path of the MIDI file to load with pretty_midi.
  """
  loaded_midi = pretty_midi.PrettyMIDI(filename)
  self.CheckPrettyMidiAndSequence(
      loaded_midi, midi_io.midi_to_sequence_proto(loaded_midi))
def CheckPrettyMidiAndSequence(self, midi, sequence_proto)

Compares PrettyMIDI object against a sequence proto.

Args

midi
A pretty_midi.PrettyMIDI object.
sequence_proto
A NoteSequence proto.
Expand source code
def CheckPrettyMidiAndSequence(self, midi, sequence_proto):
  """Compares PrettyMIDI object against a sequence proto.

  Asserts that time signatures, key signatures, tempos, and the per-instrument
  notes and pitch bends of `midi` match the corresponding fields of
  `sequence_proto`. Counts and integer fields are compared exactly; times and
  qpm values are compared with assertAlmostEqual.

  Args:
    midi: A pretty_midi.PrettyMIDI object.
    sequence_proto: A NoteSequence proto.
  """
  # Test time signature changes.
  self.assertEqual(len(midi.time_signature_changes),
                   len(sequence_proto.time_signatures))
  for midi_time, sequence_time in zip(midi.time_signature_changes,
                                      sequence_proto.time_signatures):
    self.assertEqual(midi_time.numerator, sequence_time.numerator)
    self.assertEqual(midi_time.denominator, sequence_time.denominator)
    self.assertAlmostEqual(midi_time.time, sequence_time.time)

  # Test key signature changes.
  self.assertEqual(len(midi.key_signature_changes),
                   len(sequence_proto.key_signatures))
  for midi_key, sequence_key in zip(midi.key_signature_changes,
                                    sequence_proto.key_signatures):
    # The PrettyMIDI key_number is split into a key (mod 12) and a mode
    # (integer-divided by 12) for comparison with the proto fields.
    self.assertEqual(midi_key.key_number % 12, sequence_key.key)
    self.assertEqual(midi_key.key_number // 12, sequence_key.mode)
    self.assertAlmostEqual(midi_key.time, sequence_key.time)

  # Test tempos.
  midi_times, midi_qpms = midi.get_tempo_changes()
  self.assertEqual(len(midi_times),
                   len(sequence_proto.tempos))
  self.assertEqual(len(midi_qpms),
                   len(sequence_proto.tempos))
  for midi_time, midi_qpm, sequence_tempo in zip(
      midi_times, midi_qpms, sequence_proto.tempos):
    self.assertAlmostEqual(midi_qpm, sequence_tempo.qpm)
    self.assertAlmostEqual(midi_time, sequence_tempo.time)

  # Test instruments.
  # Group the proto's events by (instrument, program, is_drum); each group
  # holds 'notes', 'bends' and 'controls' lists created on first access.
  seq_instruments = collections.defaultdict(
      lambda: collections.defaultdict(list))
  for seq_note in sequence_proto.notes:
    seq_instruments[
        (seq_note.instrument, seq_note.program, seq_note.is_drum)][
            'notes'].append(seq_note)
  for seq_bend in sequence_proto.pitch_bends:
    seq_instruments[
        (seq_bend.instrument, seq_bend.program, seq_bend.is_drum)][
            'bends'].append(seq_bend)
  # NOTE(review): the 'controls' lists are populated here but never compared
  # against midi_instrument.control_changes below; they only influence how
  # many instrument groups exist.
  for seq_control in sequence_proto.control_changes:
    seq_instruments[
        (seq_control.instrument, seq_control.program, seq_control.is_drum)][
            'controls'].append(seq_control)

  sorted_seq_instrument_keys = sorted(seq_instruments.keys())

  if seq_instruments:
    self.assertEqual(len(midi.instruments), len(seq_instruments))
  else:
    # A proto with no notes, bends or control changes is expected to map to
    # exactly one PrettyMIDI instrument with no notes and no pitch bends.
    self.assertLen(midi.instruments, 1)
    self.assertEmpty(midi.instruments[0].notes)
    self.assertEmpty(midi.instruments[0].pitch_bends)

  # Instruments are paired positionally with the sorted group keys.
  # NOTE(review): this assumes midi.instruments is ordered consistently with
  # the sorted (instrument, program, is_drum) keys -- confirm against midi_io.
  for midi_instrument, seq_instrument_key in zip(
      midi.instruments, sorted_seq_instrument_keys):

    seq_instrument_notes = seq_instruments[seq_instrument_key]['notes']

    self.assertEqual(len(midi_instrument.notes), len(seq_instrument_notes))
    for midi_note, sequence_note in zip(midi_instrument.notes,
                                        seq_instrument_notes):
      self.assertEqual(midi_note.pitch, sequence_note.pitch)
      self.assertEqual(midi_note.velocity, sequence_note.velocity)
      self.assertAlmostEqual(midi_note.start, sequence_note.start_time)
      self.assertAlmostEqual(midi_note.end, sequence_note.end_time)

    seq_instrument_pitch_bends = seq_instruments[seq_instrument_key]['bends']
    self.assertEqual(len(midi_instrument.pitch_bends),
                     len(seq_instrument_pitch_bends))
    for midi_pitch_bend, sequence_pitch_bend in zip(
        midi_instrument.pitch_bends,
        seq_instrument_pitch_bends):
      self.assertEqual(midi_pitch_bend.pitch, sequence_pitch_bend.bend)
      self.assertAlmostEqual(midi_pitch_bend.time, sequence_pitch_bend.time)
def CheckReadWriteMidi(self, filename)

Test writing to a MIDI file and comparing it to the original Sequence.

Expand source code
def CheckReadWriteMidi(self, filename):
  """Round-trips a MIDI file through a NoteSequence and a written MIDI file.

  The input is first rewritten by pretty_midi, converted to a NoteSequence,
  written back out as MIDI, reloaded, and compared with the NoteSequence.

  Args:
    filename: Path of the MIDI file to start from.
  """

  # TODO(deck): The input MIDI file is opened in pretty-midi and
  # re-written to a temp file, sanitizing the MIDI data (reordering
  # note ons, etc). Issue 85 in the pretty-midi GitHub
  # (http://github.com/craffel/pretty-midi/issues/85) requests that
  # this sanitization be available outside of the context of a file
  # write. If that is implemented, this rewrite code should be
  # modified or deleted.

  # pretty_midi is given the temp file object itself rather than its name;
  # going through the name causes a permission error on Windows.
  with tempfile.NamedTemporaryFile(prefix='MidiIoTest') as sanitized_file:
    pretty_midi.PrettyMIDI(filename).write(sanitized_file)
    # Rewind so the bytes just written can be read back from the start.
    sanitized_file.seek(0)
    sanitized_midi = pretty_midi.PrettyMIDI(sanitized_file)
    sequence_proto = midi_io.midi_to_sequence_proto(sanitized_midi)

  # Write the NoteSequence out as MIDI (by file name), then reload that file
  # through the open file object for comparison.
  with tempfile.NamedTemporaryFile(prefix='MidiIoTest') as output_file:
    midi_io.sequence_proto_to_midi_file(sequence_proto, output_file.name)
    reloaded_midi = pretty_midi.PrettyMIDI(output_file)

  self.CheckPrettyMidiAndSequence(reloaded_midi, sequence_proto)
def CheckSequenceToPrettyMidi(self, filename)

Test the translation from Sequence proto to PrettyMIDI.

Expand source code
def CheckSequenceToPrettyMidi(self, filename):
  """Converts a MIDI file to a proto and back, then cross-checks the result.

  Args:
    filename: Path of the MIDI file to load with pretty_midi.
  """
  sequence_proto = midi_io.midi_to_sequence_proto(
      pretty_midi.PrettyMIDI(filename))
  round_tripped_midi = midi_io.sequence_proto_to_pretty_midi(sequence_proto)
  self.CheckPrettyMidiAndSequence(round_tripped_midi, sequence_proto)
def setUp(self)

Hook method for setting up the test fixture before exercising it.

Expand source code
def setUp(self):
  """Resolves the paths of the example MIDI files used by the tests."""
  super().setUp()
  # Attribute name -> file name inside the test data directory.
  example_files = (
      ('midi_simple_filename', 'example.mid'),
      ('midi_complex_filename', 'example_complex.mid'),
      ('midi_is_drum_filename', 'example_is_drum.mid'),
      ('midi_event_order_filename', 'example_event_order.mid'),
  )
  for attribute, basename in example_files:
    setattr(self, attribute,
            os.path.join(testing_lib.get_testdata_dir(), basename))
def testComplexPrettyMidiToSequence(self)
Expand source code
def testComplexPrettyMidiToSequence(self):
  """Runs the MIDI -> proto check on the complex example file."""
  self.CheckMidiToSequence(filename=self.midi_complex_filename)
def testComplexReadWriteMidi(self)
Expand source code
def testComplexReadWriteMidi(self):
  """Runs the read/write round-trip check on the complex example file."""
  self.CheckReadWriteMidi(filename=self.midi_complex_filename)
def testComplexSequenceToPrettyMidi(self)
Expand source code
def testComplexSequenceToPrettyMidi(self):
  """Runs the proto -> PrettyMIDI check on the complex example file."""
  self.CheckSequenceToPrettyMidi(filename=self.midi_complex_filename)
def testEmptySequenceToPrettyMidi_DropEventsAfterLastNote(self)
Expand source code
def testEmptySequenceToPrettyMidi_DropEventsAfterLastNote(self):
  """Checks an empty NoteSequence yields one instrument with no notes.

  The conversion is exercised twice: with default arguments, and with
  drop_events_n_seconds_after_last_note=30.
  """
  empty_sequence = music_pb2.NoteSequence()

  for conversion_kwargs in (
      {},  # Translate without dropping.
      {'drop_events_n_seconds_after_last_note': 30},
  ):
    converted = midi_io.sequence_proto_to_pretty_midi(
        empty_sequence, **conversion_kwargs)
    self.assertLen(converted.instruments, 1)
    self.assertEmpty(converted.instruments[0].notes)
def testEventOrdering(self)
Expand source code
def testEventOrdering(self):
  """Runs the read/write round-trip check on the event-order example file."""
  self.CheckReadWriteMidi(filename=self.midi_event_order_filename)
def testInstrumentInfo_NoteSequenceToPrettyMidi(self)
Expand source code
def testInstrumentInfo_NoteSequenceToPrettyMidi(self):
  """Checks instrument names survive a NoteSequence -> MIDI round trip."""
  source_sequence = music_pb2.NoteSequence()
  # One note per instrument; the notes are added before any instrument_info.
  for instrument_id, (start, end) in enumerate([(0.0, 0.5), (0.5, 1.0)]):
    source_sequence.notes.add(
        pitch=60, start_time=start, end_time=end, velocity=80,
        instrument=instrument_id)
  # Attach a name to each of the two instruments.
  for instrument_id, name in enumerate(['inst_0', 'inst_1']):
    info = source_sequence.instrument_infos.add()
    info.name = name
    info.instrument = instrument_id

  # Convert to PrettyMIDI and straight back to a NoteSequence.
  round_tripped = midi_io.midi_to_note_sequence(
      midi_io.sequence_proto_to_pretty_midi(source_sequence))

  expected_infos = source_sequence.instrument_infos
  actual_infos = round_tripped.instrument_infos
  self.assertEqual(len(expected_infos), len(actual_infos))
  self.assertEqual(expected_infos[0].name, actual_infos[0].name)
  self.assertEqual(expected_infos[1].name, actual_infos[1].name)
def testIsDrumDetection(self)

Verify that is_drum instruments are properly tracked.

self.midi_is_drum_filename is a MIDI file containing two tracks set to channel 9 (is_drum == True). Each contains one NoteOn. This test is designed to catch a bug where the second track would lose is_drum, remapping the drum track to an instrument track.

Expand source code
def testIsDrumDetection(self):
  """Verify that drum tracks keep is_drum through a write to MIDI.

  self.midi_is_drum_filename is a MIDI file containing two tracks
  set to channel 9 (is_drum == True), each with one NoteOn. The file
  is converted to a NoteSequence and written back out; both the
  original and the rewritten file must contain two channel 9 NoteOns.
  This guards against a bug where the second track would lose
  is_drum and be remapped to an instrument track.
  """
  sequence_proto = midi_io.midi_file_to_sequence_proto(
      self.midi_is_drum_filename)
  with tempfile.NamedTemporaryFile(prefix='MidiDrumTest') as temp_file:
    midi_io.sequence_proto_to_midi_file(sequence_proto, temp_file.name)
    original_midi = mido.MidiFile(filename=self.midi_is_drum_filename)
    # Load the temp file through its open file object (not its name)
    # to avoid a permission error.
    rewritten_midi = mido.MidiFile(file=temp_file)

  def count_drum_note_ons(midi_data):
    """Counts note_on events with nonzero velocity on channel 9."""
    return sum(
        1 for event in midi_data
        if event.type == 'note_on' and event.velocity > 0
        and event.channel == 9)

  drum_note_on_counts = [
      count_drum_note_ons(midi_data)
      for midi_data in (original_midi, rewritten_midi)
  ]
  self.assertEqual(drum_note_on_counts, [2, 2])
def testNonEmptySequenceWithNoNotesToPrettyMidi_DropEventsAfterLastNote(self)
Expand source code
def testNonEmptySequenceWithNoNotesToPrettyMidi_DropEventsAfterLastNote(self):
  """Checks tempo handling for a sequence that has tempos but no notes.

  With no cutoff all three tempos must be present; with a cutoff of 30
  the tempo at time 40 is expected to be absent from the translated MIDI.
  """
  tempo_only_sequence = music_pb2.NoteSequence()
  for tempo_time, tempo_qpm in ((0, 120), (10, 160), (40, 240)):
    tempo_only_sequence.tempos.add(time=tempo_time, qpm=tempo_qpm)

  # Translate without dropping.
  untrimmed_midi = midi_io.sequence_proto_to_pretty_midi(tempo_only_sequence)
  self.CheckPrettyMidiAndSequence(untrimmed_midi, tempo_only_sequence)

  # Translate dropping anything after 30 seconds.
  trimmed_midi = midi_io.sequence_proto_to_pretty_midi(
      tempo_only_sequence, drop_events_n_seconds_after_last_note=30)
  # Remove the last tempo from the expectation before comparing.
  del tempo_only_sequence.tempos[-1]
  self.CheckPrettyMidiAndSequence(trimmed_midi, tempo_only_sequence)
def testSimplePrettyMidiToSequence(self)
Expand source code
def testSimplePrettyMidiToSequence(self):
  """Runs the MIDI -> proto check on the simple example file."""
  self.CheckMidiToSequence(filename=self.midi_simple_filename)
def testSimpleReadWriteMidi(self)
Expand source code
def testSimpleReadWriteMidi(self):
  """Runs the read/write round-trip check on the simple example file."""
  self.CheckReadWriteMidi(filename=self.midi_simple_filename)
def testSimpleSequenceToPrettyMidi(self)
Expand source code
def testSimpleSequenceToPrettyMidi(self):
  """Runs the proto -> PrettyMIDI check on the simple example file."""
  self.CheckSequenceToPrettyMidi(filename=self.midi_simple_filename)
def testSimpleSequenceToPrettyMidi_DefaultTicksAndTempo(self)
Expand source code
def testSimpleSequenceToPrettyMidi_DefaultTicksAndTempo(self):
  """Checks translation of a sequence lacking tempos and ticks_per_quarter.

  The translated MIDI is compared against a copy of the stripped sequence
  that carries a DEFAULT_QUARTERS_PER_MINUTE tempo and STANDARD_PPQ ticks.
  """
  stripped_sequence = midi_io.midi_to_sequence_proto(
      pretty_midi.PrettyMIDI(self.midi_simple_filename))
  # Remove every tempo and unset ticks_per_quarter.
  del stripped_sequence.tempos[:]
  stripped_sequence.ClearField('ticks_per_quarter')

  # Expectation: the stripped sequence plus explicit default values.
  expected_sequence = music_pb2.NoteSequence()
  expected_sequence.CopyFrom(stripped_sequence)
  expected_sequence.tempos.add(qpm=constants.DEFAULT_QUARTERS_PER_MINUTE)
  expected_sequence.ticks_per_quarter = constants.STANDARD_PPQ

  converted_midi = midi_io.sequence_proto_to_pretty_midi(stripped_sequence)

  self.CheckPrettyMidiAndSequence(converted_midi, expected_sequence)
def testSimpleSequenceToPrettyMidi_DropEventsAfterLastNote(self)
Expand source code
def testSimpleSequenceToPrettyMidi_DropEventsAfterLastNote(self):
  """Checks drop_events_n_seconds_after_last_note against trailing tempos.

  Covers three cases: no cutoff (a late tempo is kept), a cutoff of 0 (the
  late tempo is dropped), and a cutoff of 30 with a tempo 15 seconds after
  the last note (the tempo is kept).
  """
  sequence = midi_io.midi_to_sequence_proto(
      pretty_midi.PrettyMIDI(self.midi_simple_filename))
  # Add a final tempo long after the last note.
  sequence.tempos.add(time=600.0, qpm=120)

  # Translate without dropping.
  untrimmed_midi = midi_io.sequence_proto_to_pretty_midi(sequence)
  self.CheckPrettyMidiAndSequence(untrimmed_midi, sequence)

  # Translate dropping anything after the last note.
  trimmed_midi = midi_io.sequence_proto_to_pretty_midi(
      sequence, drop_events_n_seconds_after_last_note=0)
  # The added tempo should have been dropped, so drop it from the
  # expectation as well.
  del sequence.tempos[-1]
  self.CheckPrettyMidiAndSequence(trimmed_midi, sequence)

  # Add a final tempo 15 seconds after the last note.
  final_note_end = max(note.end_time for note in sequence.notes)
  sequence.tempos.add(time=final_note_end + 15, qpm=120)
  # Translate dropping anything 30 seconds after the last note, which should
  # preserve the added tempo.
  windowed_midi = midi_io.sequence_proto_to_pretty_midi(
      sequence, drop_events_n_seconds_after_last_note=30)
  self.CheckPrettyMidiAndSequence(windowed_midi, sequence)
def testSimpleSequenceToPrettyMidi_FirstTempoNotAtZero(self)
Expand source code
def testSimpleSequenceToPrettyMidi_FirstTempoNotAtZero(self):
  """Checks translation when the first tempo starts after time 0.

  The translated MIDI is compared against the same tempos preceded by a
  DEFAULT_QUARTERS_PER_MINUTE tempo at time 0.
  """
  later_tempos = ((1.0, 60), (2.0, 120))

  sequence = midi_io.midi_to_sequence_proto(
      pretty_midi.PrettyMIDI(self.midi_simple_filename))
  # Replace whatever tempos the file had with ones starting at 1.0 seconds.
  del sequence.tempos[:]
  for tempo_time, tempo_qpm in later_tempos:
    sequence.tempos.add(time=tempo_time, qpm=tempo_qpm)

  converted_midi = midi_io.sequence_proto_to_pretty_midi(sequence)

  # Translating to MIDI adds an implicit DEFAULT_QUARTERS_PER_MINUTE tempo
  # at time 0, so rebuild the expected tempo list with that entry first.
  del sequence.tempos[:]
  sequence.tempos.add(time=0.0, qpm=constants.DEFAULT_QUARTERS_PER_MINUTE)
  for tempo_time, tempo_qpm in later_tempos:
    sequence.tempos.add(time=tempo_time, qpm=tempo_qpm)

  self.CheckPrettyMidiAndSequence(converted_midi, sequence)
def testSimpleSequenceToPrettyMidi_MultipleTempos(self)
Expand source code
def testSimpleSequenceToPrettyMidi_MultipleTempos(self):
  """Checks translation of the simple file with two extra tempos appended."""
  sequence = midi_io.midi_to_sequence_proto(
      pretty_midi.PrettyMIDI(self.midi_simple_filename))
  # Append tempo changes at 1.0 and 2.0 seconds to the file's own tempos.
  for tempo_time, tempo_qpm in ((1.0, 60), (2.0, 120)):
    sequence.tempos.add(time=tempo_time, qpm=tempo_qpm)

  converted_midi = midi_io.sequence_proto_to_pretty_midi(sequence)

  self.CheckPrettyMidiAndSequence(converted_midi, sequence)