Module note_seq.testing_lib

Testing support code.

Expand source code
# Copyright 2021 The Magenta Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Testing support code."""

import os

from absl.testing import absltest

from note_seq import encoder_decoder
from note_seq.protobuf import compare
from note_seq.protobuf import music_pb2

from google.protobuf import descriptor_pool
from google.protobuf import text_format

# Module-level aliases for the NoteSequence.TextAnnotation annotation types
# that the helpers in this module assign.
BEAT, CHORD_SYMBOL = (
    music_pb2.NoteSequence.TextAnnotation.BEAT,
    music_pb2.NoteSequence.TextAnnotation.CHORD_SYMBOL,
)


def add_track_to_sequence(note_sequence,
                          instrument,
                          notes,
                          is_drum=False,
                          program=0):
  """Appends one instrument track of notes to a NoteSequence.

  Every entry of `notes` becomes a new element of `note_sequence.notes` that
  carries the shared `instrument`, `is_drum` and `program` values.
  `total_time` is raised to the latest note end time seen, never lowered.

  Args:
    note_sequence: NoteSequence proto to append to (modified in place).
    instrument: Instrument number assigned to every added note.
    notes: Iterable of (pitch, velocity, start_time, end_time) tuples.
    is_drum: Value stored in each added note's `is_drum` field.
    program: Value stored in each added note's `program` field.
  """
  for note_pitch, note_velocity, onset, offset in notes:
    new_note = note_sequence.notes.add()
    new_note.pitch = note_pitch
    new_note.velocity = note_velocity
    new_note.start_time = onset
    new_note.end_time = offset
    new_note.instrument = instrument
    new_note.is_drum = is_drum
    new_note.program = program
    # Grow the sequence length when this note ends later than anything so far.
    if offset > note_sequence.total_time:
      note_sequence.total_time = offset


def add_chords_to_sequence(note_sequence, chords):
  """Appends chord-symbol text annotations to a NoteSequence.

  Args:
    note_sequence: NoteSequence proto to annotate (modified in place).
    chords: Iterable of (figure, time) pairs; `figure` becomes the text.
  """
  for chord_figure, chord_time in chords:
    text_annotation = note_sequence.text_annotations.add()
    text_annotation.annotation_type = CHORD_SYMBOL
    text_annotation.time = chord_time
    text_annotation.text = chord_figure


def add_key_signatures_to_sequence(note_sequence, keys):
  """Appends key signature events to a NoteSequence.

  Args:
    note_sequence: NoteSequence proto to modify in place.
    keys: Iterable of (key, time) pairs.
  """
  for key_value, key_time in keys:
    signature = note_sequence.key_signatures.add()
    signature.key = key_value
    signature.time = key_time


def add_beats_to_sequence(note_sequence, beats, beats_per_bar=None):
  """Appends BEAT text annotations to a NoteSequence.

  Args:
    note_sequence: NoteSequence proto to annotate (modified in place).
    beats: Iterable of beat times.
    beats_per_bar: Optional bar length in beats. When truthy, each
      annotation's text is its 1-based position within the bar, cycling every
      `beats_per_bar` beats; otherwise the text is left unset.
  """
  for beat_index, beat_time in enumerate(beats):
    beat = note_sequence.text_annotations.add()
    beat.time = beat_time
    beat.annotation_type = BEAT
    if not beats_per_bar:
      continue
    beat.text = str(beat_index % beats_per_bar + 1)


def add_control_changes_to_sequence(note_sequence, instrument, control_changes):
  """Appends control change events for one instrument to a NoteSequence.

  Args:
    note_sequence: NoteSequence proto to modify in place.
    instrument: Instrument number stored on every added control change.
    control_changes: Iterable of (time, control_number, control_value) tuples.
  """
  for cc_time, cc_number, cc_value in control_changes:
    cc = note_sequence.control_changes.add()
    cc.instrument = instrument
    cc.time = cc_time
    cc.control_number = cc_number
    cc.control_value = cc_value


def add_pitch_bends_to_sequence(
    note_sequence, instrument, program, pitch_bends):
  """Appends non-drum pitch bend events to a NoteSequence.

  Args:
    note_sequence: NoteSequence proto to modify in place.
    instrument: Instrument number stored on every added pitch bend.
    program: Program number stored on every added pitch bend.
    pitch_bends: Iterable of (time, bend) pairs.
  """
  for bend_time, bend_amount in pitch_bends:
    event = note_sequence.pitch_bends.add()
    event.time = bend_time
    event.bend = bend_amount
    event.instrument = instrument
    event.program = program
    # This helper only builds non-drum bends.
    event.is_drum = False


def add_quantized_steps_to_sequence(sequence, quantized_steps):
  """Sets quantized start/end steps on each note of a sequence, in order.

  `total_quantized_steps` is raised to the largest end step seen, never
  lowered.

  Args:
    sequence: NoteSequence proto whose notes are updated in place.
    quantized_steps: One indexable (start_step, end_step) item per note; only
      the first two elements of each item are read.

  Raises:
    AssertionError: If the number of notes and of step items differ (only
      when assertions are enabled).
  """
  assert len(sequence.notes) == len(quantized_steps)

  for note, step_pair in zip(sequence.notes, quantized_steps):
    start_step, end_step = step_pair[0], step_pair[1]
    note.quantized_start_step = start_step
    note.quantized_end_step = end_step
    if end_step > sequence.total_quantized_steps:
      sequence.total_quantized_steps = end_step


def add_quantized_chord_steps_to_sequence(sequence, quantized_steps):
  """Assigns quantized steps to the chord-symbol annotations of a sequence.

  Only text annotations whose type is CHORD_SYMBOL are updated, in order of
  appearance; all other annotations are left untouched.

  Args:
    sequence: NoteSequence proto whose annotations are updated in place.
    quantized_steps: One step per CHORD_SYMBOL annotation, in order.

  Raises:
    AssertionError: If the counts differ (only when assertions are enabled).
  """
  chord_symbols = [
      text_annotation for text_annotation in sequence.text_annotations
      if text_annotation.annotation_type == CHORD_SYMBOL
  ]
  assert len(chord_symbols) == len(quantized_steps)
  for chord_symbol, step in zip(chord_symbols, quantized_steps):
    chord_symbol.quantized_step = step


def add_quantized_control_steps_to_sequence(sequence, quantized_steps):
  """Assigns a quantized step to each control change of a sequence, in order.

  Args:
    sequence: NoteSequence proto whose control changes are updated in place.
    quantized_steps: One step per entry of `sequence.control_changes`.

  Raises:
    AssertionError: If the counts differ (only when assertions are enabled).
  """
  assert len(sequence.control_changes) == len(quantized_steps)
  for control_change, step in zip(sequence.control_changes, quantized_steps):
    control_change.quantized_step = step


class TrivialOneHotEncoding(encoder_decoder.OneHotEncoding):
  """One-hot encoding in which every event is its own class index.

  Args:
    num_classes: Number of classes the encoding exposes.
    num_steps: Optional sequence indexed by event giving how many steps that
      event spans. When omitted, every event spans a single step.

  Raises:
    ValueError: If `num_steps` is given and its length differs from
      `num_classes`.
  """

  def __init__(self, num_classes, num_steps=None):
    length_mismatch = (
        num_steps is not None and len(num_steps) != num_classes)
    if length_mismatch:
      raise ValueError('num_steps must have length num_classes')
    self._num_classes = num_classes
    self._num_steps = num_steps

  @property
  def num_classes(self):
    """Number of classes, as passed to the constructor."""
    return self._num_classes

  @property
  def default_event(self):
    """The default event, which is always 0."""
    return 0

  def encode_event(self, event):
    """Returns `event` unchanged; events already are class indices."""
    return event

  def decode_event(self, index):
    """Returns `index` unchanged; class indices already are events."""
    return index

  def event_to_num_steps(self, event):
    """Returns the step count of `event` (1 when no table was supplied)."""
    if self._num_steps is None:
      return 1
    return self._num_steps[event]


def parse_test_proto(proto_type, proto_string):
  """Builds a `proto_type` message from its text-format representation.

  Args:
    proto_type: Protocol buffer message class to instantiate.
    proto_string: Text-format (ASCII) serialization of the message.

  Returns:
    A new `proto_type` instance populated from `proto_string`.
  """
  # text_format.Parse fills the given message and returns that same object.
  return text_format.Parse(proto_string, proto_type())


def get_testdata_dir():
  """Returns the path of the note_seq test data directory.

  The path is 'note_seq/testdata' joined onto absltest's default test source
  directory.
  """
  srcdir = absltest.get_default_test_srcdir()
  return os.path.join(srcdir, 'note_seq/testdata')


class ProtoTestCase(absltest.TestCase):
  """Test case base class that adds proto-comparison assertions.

  Provides assertProtoEquals, modeled on the method of the same name in
  tf.test.TestCase, and builds a `note_sequence` fixture (4/4 time at 60 qpm)
  in setUp.
  """

  def setUp(self):
    # Show complete diffs when comparisons of large protos fail.
    self.maxDiff = None  # pylint:disable=invalid-name
    self.steps_per_quarter = 4
    self.note_sequence = parse_test_proto(
        music_pb2.NoteSequence, """
        time_signatures: {
          numerator: 4
          denominator: 4
        }
        tempos: {
          qpm: 60
        }
        """)
    super().setUp()

  def _AssertProtoEquals(self, a, b, msg=None):  # pylint:disable=invalid-name
    """Asserts that a and b are the same proto.

    Uses ProtoEq() first, as it returns correct results
    for floating point attributes, and then uses assertProtoEqual()
    in case of failure as it provides good error messages.

    Args:
      a: a proto.
      b: another proto.
      msg: Optional message to report on failure.
    """
    if not compare.ProtoEq(a, b):
      compare.assertProtoEqual(self, a, b, normalize_numbers=True, msg=msg)

  def assertProtoEquals(self, expected_message_maybe_ascii, message, msg=None):
    """Asserts that `message` equals the expected proto.

    If `expected_message_maybe_ascii` is already a message of the same type as
    `message`, it is used as-is. If it is a string, a new message of that type
    is created and the string is parsed into it as text format. The two
    messages are then compared with self._AssertProtoEquals().

    Args:
      expected_message_maybe_ascii: proto message in original or ascii form.
      message: the message to validate.
      msg: Optional message to report on failure.

    Raises:
      AssertionError: (self.failureException) if
        `expected_message_maybe_ascii` is neither a message of the same type
        as `message` nor a string. A mismatch between the messages is reported
        through compare.assertProtoEqual.
    """
    if isinstance(expected_message_maybe_ascii, type(message)):
      expected_message = expected_message_maybe_ascii
    elif isinstance(expected_message_maybe_ascii, str):
      expected_message = type(message)()
      text_format.Parse(
          expected_message_maybe_ascii,
          expected_message,
          descriptor_pool=descriptor_pool.Default())
    else:
      # self.fail() always raises; a bare `assert` is stripped under -O.
      self.fail("Can't compare protos of type %s and %s. %s" %
                (type(expected_message_maybe_ascii), type(message), msg or ''))
    # Forward the caller's `msg` unchanged for both kinds of expected value.
    self._AssertProtoEquals(expected_message, message, msg=msg)

Functions

def add_beats_to_sequence(note_sequence, beats, beats_per_bar=None)
Expand source code
def add_beats_to_sequence(note_sequence, beats, beats_per_bar=None):
  """Appends BEAT text annotations to a NoteSequence.

  Args:
    note_sequence: NoteSequence proto to annotate (modified in place).
    beats: Iterable of beat times.
    beats_per_bar: Optional bar length in beats. When truthy, each
      annotation's text is its 1-based position within the bar, cycling every
      `beats_per_bar` beats; otherwise the text is left unset.
  """
  for beat_index, beat_time in enumerate(beats):
    beat = note_sequence.text_annotations.add()
    beat.time = beat_time
    beat.annotation_type = BEAT
    if not beats_per_bar:
      continue
    beat.text = str(beat_index % beats_per_bar + 1)
def add_chords_to_sequence(note_sequence, chords)
Expand source code
def add_chords_to_sequence(note_sequence, chords):
  """Appends chord-symbol text annotations to a NoteSequence.

  Args:
    note_sequence: NoteSequence proto to annotate (modified in place).
    chords: Iterable of (figure, time) pairs; `figure` becomes the text.
  """
  for chord_figure, chord_time in chords:
    text_annotation = note_sequence.text_annotations.add()
    text_annotation.annotation_type = CHORD_SYMBOL
    text_annotation.time = chord_time
    text_annotation.text = chord_figure
def add_control_changes_to_sequence(note_sequence, instrument, control_changes)
Expand source code
def add_control_changes_to_sequence(note_sequence, instrument, control_changes):
  """Appends control change events for one instrument to a NoteSequence.

  Args:
    note_sequence: NoteSequence proto to modify in place.
    instrument: Instrument number stored on every added control change.
    control_changes: Iterable of (time, control_number, control_value) tuples.
  """
  for cc_time, cc_number, cc_value in control_changes:
    cc = note_sequence.control_changes.add()
    cc.instrument = instrument
    cc.time = cc_time
    cc.control_number = cc_number
    cc.control_value = cc_value
def add_key_signatures_to_sequence(note_sequence, keys)
Expand source code
def add_key_signatures_to_sequence(note_sequence, keys):
  """Appends key signature events to a NoteSequence.

  Args:
    note_sequence: NoteSequence proto to modify in place.
    keys: Iterable of (key, time) pairs.
  """
  for key_value, key_time in keys:
    signature = note_sequence.key_signatures.add()
    signature.key = key_value
    signature.time = key_time
def add_pitch_bends_to_sequence(note_sequence, instrument, program, pitch_bends)
Expand source code
def add_pitch_bends_to_sequence(
    note_sequence, instrument, program, pitch_bends):
  """Appends non-drum pitch bend events to a NoteSequence.

  Args:
    note_sequence: NoteSequence proto to modify in place.
    instrument: Instrument number stored on every added pitch bend.
    program: Program number stored on every added pitch bend.
    pitch_bends: Iterable of (time, bend) pairs.
  """
  for bend_time, bend_amount in pitch_bends:
    event = note_sequence.pitch_bends.add()
    event.time = bend_time
    event.bend = bend_amount
    event.instrument = instrument
    event.program = program
    # This helper only builds non-drum bends.
    event.is_drum = False
def add_quantized_chord_steps_to_sequence(sequence, quantized_steps)
Expand source code
def add_quantized_chord_steps_to_sequence(sequence, quantized_steps):
  """Assigns quantized steps to the chord-symbol annotations of a sequence.

  Only text annotations whose type is CHORD_SYMBOL are updated, in order of
  appearance; all other annotations are left untouched.

  Args:
    sequence: NoteSequence proto whose annotations are updated in place.
    quantized_steps: One step per CHORD_SYMBOL annotation, in order.

  Raises:
    AssertionError: If the counts differ (only when assertions are enabled).
  """
  chord_symbols = [
      text_annotation for text_annotation in sequence.text_annotations
      if text_annotation.annotation_type == CHORD_SYMBOL
  ]
  assert len(chord_symbols) == len(quantized_steps)
  for chord_symbol, step in zip(chord_symbols, quantized_steps):
    chord_symbol.quantized_step = step
def add_quantized_control_steps_to_sequence(sequence, quantized_steps)
Expand source code
def add_quantized_control_steps_to_sequence(sequence, quantized_steps):
  """Assigns a quantized step to each control change of a sequence, in order.

  Args:
    sequence: NoteSequence proto whose control changes are updated in place.
    quantized_steps: One step per entry of `sequence.control_changes`.

  Raises:
    AssertionError: If the counts differ (only when assertions are enabled).
  """
  assert len(sequence.control_changes) == len(quantized_steps)
  for control_change, step in zip(sequence.control_changes, quantized_steps):
    control_change.quantized_step = step
def add_quantized_steps_to_sequence(sequence, quantized_steps)
Expand source code
def add_quantized_steps_to_sequence(sequence, quantized_steps):
  """Sets quantized start/end steps on each note of a sequence, in order.

  `total_quantized_steps` is raised to the largest end step seen, never
  lowered.

  Args:
    sequence: NoteSequence proto whose notes are updated in place.
    quantized_steps: One indexable (start_step, end_step) item per note; only
      the first two elements of each item are read.

  Raises:
    AssertionError: If the number of notes and of step items differ (only
      when assertions are enabled).
  """
  assert len(sequence.notes) == len(quantized_steps)

  for note, step_pair in zip(sequence.notes, quantized_steps):
    start_step, end_step = step_pair[0], step_pair[1]
    note.quantized_start_step = start_step
    note.quantized_end_step = end_step
    if end_step > sequence.total_quantized_steps:
      sequence.total_quantized_steps = end_step
def add_track_to_sequence(note_sequence, instrument, notes, is_drum=False, program=0)

Adds instrument track to NoteSequence.

Expand source code
def add_track_to_sequence(note_sequence,
                          instrument,
                          notes,
                          is_drum=False,
                          program=0):
  """Appends one instrument track of notes to a NoteSequence.

  Every entry of `notes` becomes a new element of `note_sequence.notes` that
  carries the shared `instrument`, `is_drum` and `program` values.
  `total_time` is raised to the latest note end time seen, never lowered.

  Args:
    note_sequence: NoteSequence proto to append to (modified in place).
    instrument: Instrument number assigned to every added note.
    notes: Iterable of (pitch, velocity, start_time, end_time) tuples.
    is_drum: Value stored in each added note's `is_drum` field.
    program: Value stored in each added note's `program` field.
  """
  for note_pitch, note_velocity, onset, offset in notes:
    new_note = note_sequence.notes.add()
    new_note.pitch = note_pitch
    new_note.velocity = note_velocity
    new_note.start_time = onset
    new_note.end_time = offset
    new_note.instrument = instrument
    new_note.is_drum = is_drum
    new_note.program = program
    # Grow the sequence length when this note ends later than anything so far.
    if offset > note_sequence.total_time:
      note_sequence.total_time = offset
def get_testdata_dir()
Expand source code
def get_testdata_dir():
  """Returns the path of the note_seq test data directory.

  The path is 'note_seq/testdata' joined onto absltest's default test source
  directory.
  """
  srcdir = absltest.get_default_test_srcdir()
  return os.path.join(srcdir, 'note_seq/testdata')
def parse_test_proto(proto_type, proto_string)
Expand source code
def parse_test_proto(proto_type, proto_string):
  """Builds a `proto_type` message from its text-format representation.

  Args:
    proto_type: Protocol buffer message class to instantiate.
    proto_string: Text-format (ASCII) serialization of the message.

  Returns:
    A new `proto_type` instance populated from `proto_string`.
  """
  # text_format.Parse fills the given message and returns that same object.
  return text_format.Parse(proto_string, proto_type())

Classes

class ProtoTestCase (*args, **kwargs)

Adds assertProtoEquals from tf.test.TestCase.

Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.

Expand source code
class ProtoTestCase(absltest.TestCase):
  """Test case base class that adds proto-comparison assertions.

  Provides assertProtoEquals, modeled on the method of the same name in
  tf.test.TestCase, and builds a `note_sequence` fixture (4/4 time at 60 qpm)
  in setUp.
  """

  def setUp(self):
    # Show complete diffs when comparisons of large protos fail.
    self.maxDiff = None  # pylint:disable=invalid-name
    self.steps_per_quarter = 4
    self.note_sequence = parse_test_proto(
        music_pb2.NoteSequence, """
        time_signatures: {
          numerator: 4
          denominator: 4
        }
        tempos: {
          qpm: 60
        }
        """)
    super().setUp()

  def _AssertProtoEquals(self, a, b, msg=None):  # pylint:disable=invalid-name
    """Asserts that a and b are the same proto.

    Uses ProtoEq() first, as it returns correct results
    for floating point attributes, and then uses assertProtoEqual()
    in case of failure as it provides good error messages.

    Args:
      a: a proto.
      b: another proto.
      msg: Optional message to report on failure.
    """
    if not compare.ProtoEq(a, b):
      compare.assertProtoEqual(self, a, b, normalize_numbers=True, msg=msg)

  def assertProtoEquals(self, expected_message_maybe_ascii, message, msg=None):
    """Asserts that `message` equals the expected proto.

    If `expected_message_maybe_ascii` is already a message of the same type as
    `message`, it is used as-is. If it is a string, a new message of that type
    is created and the string is parsed into it as text format. The two
    messages are then compared with self._AssertProtoEquals().

    Args:
      expected_message_maybe_ascii: proto message in original or ascii form.
      message: the message to validate.
      msg: Optional message to report on failure.

    Raises:
      AssertionError: (self.failureException) if
        `expected_message_maybe_ascii` is neither a message of the same type
        as `message` nor a string. A mismatch between the messages is reported
        through compare.assertProtoEqual.
    """
    if isinstance(expected_message_maybe_ascii, type(message)):
      expected_message = expected_message_maybe_ascii
    elif isinstance(expected_message_maybe_ascii, str):
      expected_message = type(message)()
      text_format.Parse(
          expected_message_maybe_ascii,
          expected_message,
          descriptor_pool=descriptor_pool.Default())
    else:
      # self.fail() always raises; a bare `assert` is stripped under -O.
      self.fail("Can't compare protos of type %s and %s. %s" %
                (type(expected_message_maybe_ascii), type(message), msg or ''))
    # Forward the caller's `msg` unchanged for both kinds of expected value.
    self._AssertProtoEquals(expected_message, message, msg=msg)

Ancestors

  • absl.testing.absltest.TestCase
  • absl.third_party.unittest3_backport.case.TestCase
  • unittest.case.TestCase

Subclasses

Methods

def assertProtoEquals(self, expected_message_maybe_ascii, message, msg=None)

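Asserts that message is the same as expected_message_maybe_ascii, which may be either a message of the same type or its text-format (ASCII) representation.

If the expected value is a string, creates a new message of the same type as message, parses the string into it, and then compares the two using self._AssertProtoEquals().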

Args

expected_message_maybe_ascii
proto message in original or ascii form.
message
the message to validate.
msg
Optional message to report on failure.
Expand source code
def assertProtoEquals(self, expected_message_maybe_ascii, message, msg=None):
  """Asserts that `message` equals the expected proto.

  If `expected_message_maybe_ascii` is already a message of the same type as
  `message`, it is used as-is. If it is a string, a new message of that type
  is created and the string is parsed into it as text format. The two
  messages are then compared with self._AssertProtoEquals().

  Args:
    expected_message_maybe_ascii: proto message in original or ascii form.
    message: the message to validate.
    msg: Optional message to report on failure.

  Raises:
    AssertionError: (self.failureException) if
      `expected_message_maybe_ascii` is neither a message of the same type as
      `message` nor a string. A mismatch between the messages is reported
      through compare.assertProtoEqual.
  """
  if isinstance(expected_message_maybe_ascii, type(message)):
    expected_message = expected_message_maybe_ascii
  elif isinstance(expected_message_maybe_ascii, str):
    expected_message = type(message)()
    text_format.Parse(
        expected_message_maybe_ascii,
        expected_message,
        descriptor_pool=descriptor_pool.Default())
  else:
    # self.fail() always raises; a bare `assert` is stripped under -O.
    self.fail("Can't compare protos of type %s and %s. %s" %
              (type(expected_message_maybe_ascii), type(message), msg or ''))
  # Forward the caller's `msg` unchanged for both kinds of expected value.
  self._AssertProtoEquals(expected_message, message, msg=msg)
def setUp(self)

Hook method for setting up the test fixture before exercising it.

Expand source code
def setUp(self):
  """Builds the default fixtures, then runs the base class setUp."""
  # Show complete diffs when comparisons of large protos fail.
  self.maxDiff = None  # pylint:disable=invalid-name
  self.steps_per_quarter = 4
  # Base fixture: one 4/4 time signature and one 60 qpm tempo, no notes.
  fixture_text = """
      time_signatures: {
        numerator: 4
        denominator: 4
      }
      tempos: {
        qpm: 60
      }
      """
  self.note_sequence = parse_test_proto(music_pb2.NoteSequence, fixture_text)
  super().setUp()
class TrivialOneHotEncoding (num_classes, num_steps=None)

One-hot encoding that uses the identity encoding.

Expand source code
class TrivialOneHotEncoding(encoder_decoder.OneHotEncoding):
  """One-hot encoding in which every event is its own class index.

  Args:
    num_classes: Number of classes the encoding exposes.
    num_steps: Optional sequence indexed by event giving how many steps that
      event spans. When omitted, every event spans a single step.

  Raises:
    ValueError: If `num_steps` is given and its length differs from
      `num_classes`.
  """

  def __init__(self, num_classes, num_steps=None):
    length_mismatch = (
        num_steps is not None and len(num_steps) != num_classes)
    if length_mismatch:
      raise ValueError('num_steps must have length num_classes')
    self._num_classes = num_classes
    self._num_steps = num_steps

  @property
  def num_classes(self):
    """Number of classes, as passed to the constructor."""
    return self._num_classes

  @property
  def default_event(self):
    """The default event, which is always 0."""
    return 0

  def encode_event(self, event):
    """Returns `event` unchanged; events already are class indices."""
    return event

  def decode_event(self, index):
    """Returns `index` unchanged; class indices already are events."""
    return index

  def event_to_num_steps(self, event):
    """Returns the step count of `event` (1 when no table was supplied)."""
    if self._num_steps is None:
      return 1
    return self._num_steps[event]

Ancestors

Inherited members