Module note_seq.drums_encoder_decoder_test
Tests for drums_encoder_decoder.
Expand source code
# Copyright 2021 The Magenta Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for drums_encoder_decoder."""
from absl.testing import absltest
from note_seq import drums_encoder_decoder
def DRUMS(*args):  # pylint: disable=invalid-name
  """Build a drum event from the given values.

  Args:
    *args: The drum values making up the event (the tests in this file pass
      integers such as 35 or 44). Duplicates collapse, since the result is a
      set.

  Returns:
    A frozenset containing exactly the given arguments.
  """
  return frozenset(args)


# Drum event containing no drums at all.
NO_DRUMS = frozenset()
def _index_to_binary(index):
  """Return `index` as a zero-padded string of binary digits.

  The result is padded to
  len(drums_encoder_decoder.DEFAULT_DRUM_TYPE_PITCHES) characters, so callers
  can count the active bits with `.count('1')`.

  Args:
    index: Non-negative integer to render in base 2.

  Returns:
    The binary representation of `index`, left-padded with '0' characters.
  """
  width = len(drums_encoder_decoder.DEFAULT_DRUM_TYPE_PITCHES)
  # Format straight to base 2 instead of re-reading bin()'s digits as a
  # decimal number and zero-padding that.
  return format(index, '0%db' % width)
class MultiDrumOneHotEncodingTest(absltest.TestCase):
  """Tests encode_event and decode_event of MultiDrumOneHotEncoding."""

  def setUp(self):
    # Let the base TestCase prepare its own fixture state before the encoder
    # under test is created.
    super().setUp()
    self.enc = drums_encoder_decoder.MultiDrumOneHotEncoding()

  def testEncode(self):
    """Checks how many bits are active in the index of various drum events."""
    # No drums should encode to zero.
    self.assertEqual(0, self.enc.encode_event(NO_DRUMS))

    # Single drum should encode to single bit active, different for different
    # drum types.
    first_index = self.enc.encode_event(DRUMS(35))
    second_index = self.enc.encode_event(DRUMS(44))
    for single_index in (first_index, second_index):
      self.assertEqual(1, _index_to_binary(single_index).count('1'))
    self.assertNotEqual(first_index, second_index)

    # Multiple drums should encode to multiple bits active, one for each drum
    # type. The three-value event is still expected to set only two bits --
    # presumably two of its values share a drum type; verify against
    # DEFAULT_DRUM_TYPE_PITCHES.
    for drums in (DRUMS(40, 44), DRUMS(35, 51, 59)):
      multi_index = self.enc.encode_event(drums)
      self.assertEqual(2, _index_to_binary(multi_index).count('1'))

  def testDecode(self):
    """Checks the type and size of the events decoded from various indices."""
    # Zero should decode to no drums.
    self.assertEqual(NO_DRUMS, self.enc.decode_event(0))

    # Single bit active should decode to single drum, different for different
    # bits: the lowest bit and a bit in the middle of the index range.
    middle_bit = 2 ** (len(drums_encoder_decoder.DEFAULT_DRUM_TYPE_PITCHES) // 2)
    low_event = self.enc.decode_event(1)
    middle_event = self.enc.decode_event(middle_bit)
    for single_event in (low_event, middle_event):
      self.assertEqual(frozenset, type(single_event))
    for single_event in (low_event, middle_event):
      self.assertLen(single_event, 1)
    self.assertNotEqual(low_event, middle_event)

    # Multiple bits active should decode to multiple drums (7 == 0b111).
    multi_event = self.enc.decode_event(7)
    self.assertEqual(frozenset, type(multi_event))
    self.assertLen(multi_event, 3)
# Hand control to absltest's runner when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
Functions
def DRUMS(*args)
-
Expand source code
def DRUMS(*args):  # pylint: disable=invalid-name
  """Build a drum event: a frozenset containing exactly the given arguments."""
  return frozenset(args)
Classes
class MultiDrumOneHotEncodingTest (*args, **kwargs)
-
Extension of unittest.TestCase providing more power.
Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.
Expand source code
class MultiDrumOneHotEncodingTest(absltest.TestCase):
  """Tests encode_event and decode_event of MultiDrumOneHotEncoding."""

  def setUp(self):
    # Let the base TestCase prepare its own fixture state before the encoder
    # under test is created.
    super().setUp()
    self.enc = drums_encoder_decoder.MultiDrumOneHotEncoding()

  def testEncode(self):
    """Checks how many bits are active in the index of various drum events."""
    # No drums should encode to zero.
    self.assertEqual(0, self.enc.encode_event(NO_DRUMS))

    # Single drum should encode to single bit active, different for different
    # drum types.
    first_index = self.enc.encode_event(DRUMS(35))
    second_index = self.enc.encode_event(DRUMS(44))
    for single_index in (first_index, second_index):
      self.assertEqual(1, _index_to_binary(single_index).count('1'))
    self.assertNotEqual(first_index, second_index)

    # Multiple drums should encode to multiple bits active, one for each drum
    # type. The three-value event is still expected to set only two bits --
    # presumably two of its values share a drum type; verify against
    # DEFAULT_DRUM_TYPE_PITCHES.
    for drums in (DRUMS(40, 44), DRUMS(35, 51, 59)):
      multi_index = self.enc.encode_event(drums)
      self.assertEqual(2, _index_to_binary(multi_index).count('1'))

  def testDecode(self):
    """Checks the type and size of the events decoded from various indices."""
    # Zero should decode to no drums.
    self.assertEqual(NO_DRUMS, self.enc.decode_event(0))

    # Single bit active should decode to single drum, different for different
    # bits: the lowest bit and a bit in the middle of the index range.
    middle_bit = 2 ** (len(drums_encoder_decoder.DEFAULT_DRUM_TYPE_PITCHES) // 2)
    low_event = self.enc.decode_event(1)
    middle_event = self.enc.decode_event(middle_bit)
    for single_event in (low_event, middle_event):
      self.assertEqual(frozenset, type(single_event))
    for single_event in (low_event, middle_event):
      self.assertLen(single_event, 1)
    self.assertNotEqual(low_event, middle_event)

    # Multiple bits active should decode to multiple drums (7 == 0b111).
    multi_event = self.enc.decode_event(7)
    self.assertEqual(frozenset, type(multi_event))
    self.assertLen(multi_event, 3)
Ancestors
- absl.testing.absltest.TestCase
- absl.third_party.unittest3_backport.case.TestCase
- unittest.case.TestCase
Methods
def setUp(self)
-
Hook method for setting up the test fixture before exercising it.
Expand source code
def setUp(self):
  # Let the base TestCase prepare its own fixture state before the encoder
  # under test is created.
  super().setUp()
  self.enc = drums_encoder_decoder.MultiDrumOneHotEncoding()
def testDecode(self)
-
Expand source code
def testDecode(self):
  """Checks the type and size of the events decoded from various indices."""
  # Zero should decode to no drums.
  self.assertEqual(NO_DRUMS, self.enc.decode_event(0))

  # Single bit active should decode to single drum, different for different
  # bits: the lowest bit and a bit in the middle of the index range.
  middle_bit = 2 ** (len(drums_encoder_decoder.DEFAULT_DRUM_TYPE_PITCHES) // 2)
  low_event = self.enc.decode_event(1)
  middle_event = self.enc.decode_event(middle_bit)
  for single_event in (low_event, middle_event):
    self.assertEqual(frozenset, type(single_event))
  for single_event in (low_event, middle_event):
    self.assertLen(single_event, 1)
  self.assertNotEqual(low_event, middle_event)

  # Multiple bits active should decode to multiple drums (7 == 0b111).
  multi_event = self.enc.decode_event(7)
  self.assertEqual(frozenset, type(multi_event))
  self.assertLen(multi_event, 3)
def testEncode(self)
-
Expand source code
def testEncode(self):
  """Checks how many bits are active in the index of various drum events."""
  # No drums should encode to zero.
  self.assertEqual(0, self.enc.encode_event(NO_DRUMS))

  # Single drum should encode to single bit active, different for different
  # drum types.
  first_index = self.enc.encode_event(DRUMS(35))
  second_index = self.enc.encode_event(DRUMS(44))
  for single_index in (first_index, second_index):
    self.assertEqual(1, _index_to_binary(single_index).count('1'))
  self.assertNotEqual(first_index, second_index)

  # Multiple drums should encode to multiple bits active, one for each drum
  # type. The three-value event is still expected to set only two bits --
  # presumably two of its values share a drum type; verify against
  # DEFAULT_DRUM_TYPE_PITCHES.
  for drums in (DRUMS(40, 44), DRUMS(35, 51, 59)):
    multi_index = self.enc.encode_event(drums)
    self.assertEqual(2, _index_to_binary(multi_index).count('1'))