Module note_seq.drums_encoder_decoder

Classes for converting between drum tracks and models inputs/outputs.

Expand source code
# Copyright 2021 The Magenta Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Classes for converting between drum tracks and models inputs/outputs."""

from note_seq import encoder_decoder

# Default list of 9 drum types, where each type is represented by a list of
# MIDI pitches for drum sounds belonging to that type. This default list
# attempts to map all GM1 and GM2 drums onto a much smaller standard drum kit
# based on drum sound and function.
DEFAULT_DRUM_TYPE_PITCHES = [
    # kick drum
    [36, 35],

    # snare drum
    [38, 27, 28, 31, 32, 33, 34, 37, 39, 40, 56, 65, 66, 75, 85],

    # closed hi-hat
    [42, 44, 54, 68, 69, 70, 71, 73, 78, 80, 22],

    # open hi-hat
    [46, 67, 72, 74, 79, 81, 26],

    # low tom
    [45, 29, 41, 43, 61, 64, 84],

    # mid tom
    [48, 47, 60, 63, 77, 86, 87],

    # high tom
    [50, 30, 62, 76, 83],

    # crash cymbal
    [49, 52, 55, 57, 58],

    # ride cymbal
    [51, 53, 59, 82]
]


class DrumsEncodingError(Exception):
  pass


class MultiDrumOneHotEncoding(encoder_decoder.OneHotEncoding):
  """Encodes drum events as binary where each bit is a different drum type.

  Each event consists of multiple simultaneous drum "pitches". This encoding
  converts each pitch to a drum type, e.g. bass drum, hi-hat, etc. Each drum
  type is mapped to a single bit of a binary integer representation, where the
  bit has value 0 if the drum type is not present, and 1 if it is present.

  If multiple "pitches" corresponding to the same drum type (e.g. two different
  ride cymbals) are present, the encoding is the same as if only one of them
  were present.
  """

  def __init__(self, drum_type_pitches=None, ignore_unknown_drums=True):
    """Initializes the MultiDrumOneHotEncoding.

    Args:
      drum_type_pitches: A Python list of the MIDI pitch values for each drum
          type. If None, `DEFAULT_DRUM_TYPE_PITCHES` will be used.
      ignore_unknown_drums: If True, unknown drum pitches will not be encoded.
          If False, a DrumsEncodingError will be raised when unknown drum
          pitches are encountered.
    """
    if drum_type_pitches is None:
      drum_type_pitches = DEFAULT_DRUM_TYPE_PITCHES
    self._drum_map = dict(enumerate(drum_type_pitches))
    self._inverse_drum_map = dict((pitch, index)
                                  for index, pitches in self._drum_map.items()
                                  for pitch in pitches)
    self._ignore_unknown_drums = ignore_unknown_drums

  @property
  def num_classes(self):
    return 2 ** len(self._drum_map)

  @property
  def default_event(self):
    return frozenset()

  def encode_event(self, event):
    drum_type_indices = set()
    for pitch in event:
      if pitch in self._inverse_drum_map:
        drum_type_indices.add(self._inverse_drum_map[pitch])
      elif not self._ignore_unknown_drums:
        raise DrumsEncodingError('unknown drum pitch: %d' % pitch)
    return sum(2 ** i for i in drum_type_indices)

  def decode_event(self, index):
    bits = reversed(str(bin(index)))
    # Use the first "pitch" for each drum type.
    return frozenset(self._drum_map[i][0]
                     for i, b in enumerate(bits)
                     if b == '1')

Classes

class DrumsEncodingError (*args, **kwargs)

Common base class for all non-exit exceptions.

Expand source code
class DrumsEncodingError(Exception):
  pass

Ancestors

  • builtins.Exception
  • builtins.BaseException
class MultiDrumOneHotEncoding (drum_type_pitches=None, ignore_unknown_drums=True)

Encodes drum events as binary where each bit is a different drum type.

Each event consists of multiple simultaneous drum "pitches". This encoding converts each pitch to a drum type, e.g. bass drum, hi-hat, etc. Each drum type is mapped to a single bit of a binary integer representation, where the bit has value 0 if the drum type is not present, and 1 if it is present.

If multiple "pitches" corresponding to the same drum type (e.g. two different ride cymbals) are present, the encoding is the same as if only one of them were present.

Initializes the MultiDrumOneHotEncoding.

Args

drum_type_pitches
A Python list of the MIDI pitch values for each drum type. If None, DEFAULT_DRUM_TYPE_PITCHES will be used.
ignore_unknown_drums
If True, unknown drum pitches will not be encoded. If False, a DrumsEncodingError will be raised when unknown drum pitches are encountered.
Expand source code
class MultiDrumOneHotEncoding(encoder_decoder.OneHotEncoding):
  """Encodes drum events as binary where each bit is a different drum type.

  Each event consists of multiple simultaneous drum "pitches". This encoding
  converts each pitch to a drum type, e.g. bass drum, hi-hat, etc. Each drum
  type is mapped to a single bit of a binary integer representation, where the
  bit has value 0 if the drum type is not present, and 1 if it is present.

  If multiple "pitches" corresponding to the same drum type (e.g. two different
  ride cymbals) are present, the encoding is the same as if only one of them
  were present.
  """

  def __init__(self, drum_type_pitches=None, ignore_unknown_drums=True):
    """Initializes the MultiDrumOneHotEncoding.

    Args:
      drum_type_pitches: A Python list of the MIDI pitch values for each drum
          type. If None, `DEFAULT_DRUM_TYPE_PITCHES` will be used.
      ignore_unknown_drums: If True, unknown drum pitches will not be encoded.
          If False, a DrumsEncodingError will be raised when unknown drum
          pitches are encountered.
    """
    if drum_type_pitches is None:
      drum_type_pitches = DEFAULT_DRUM_TYPE_PITCHES
    self._drum_map = dict(enumerate(drum_type_pitches))
    self._inverse_drum_map = dict((pitch, index)
                                  for index, pitches in self._drum_map.items()
                                  for pitch in pitches)
    self._ignore_unknown_drums = ignore_unknown_drums

  @property
  def num_classes(self):
    return 2 ** len(self._drum_map)

  @property
  def default_event(self):
    return frozenset()

  def encode_event(self, event):
    drum_type_indices = set()
    for pitch in event:
      if pitch in self._inverse_drum_map:
        drum_type_indices.add(self._inverse_drum_map[pitch])
      elif not self._ignore_unknown_drums:
        raise DrumsEncodingError('unknown drum pitch: %d' % pitch)
    return sum(2 ** i for i in drum_type_indices)

  def decode_event(self, index):
    bits = reversed(str(bin(index)))
    # Use the first "pitch" for each drum type.
    return frozenset(self._drum_map[i][0]
                     for i, b in enumerate(bits)
                     if b == '1')

Ancestors

Inherited members