Module note_seq.pianoroll_encoder_decoder_test
Tests for pianoroll_encoder_decoder.
Expand source code
# Copyright 2021 The Magenta Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for pianoroll_encoder_decoder."""
from absl.testing import absltest
from note_seq import pianoroll_encoder_decoder
import numpy as np
class PianorollEncodingTest(absltest.TestCase):
  """Tests for PianorollEncoderDecoder constructed with input size 5."""

  def setUp(self):
    """Create a fresh encoder/decoder for each test."""
    self.enc = pianoroll_encoder_decoder.PianorollEncoderDecoder(5)

  def testProperties(self):
    """Check input size, class count and default label."""
    self.assertEqual(5, self.enc.input_size)
    # 32 == 2 ** 5: one class per subset of the 5 input positions.
    self.assertEqual(32, self.enc.num_classes)
    self.assertEqual(0, self.enc.default_event_label)

  def testEncodeInput(self):
    """Check events_to_input yields a multi-hot vector for one position."""
    events = [(), (1, 2), (2,)]
    # Use the builtin `bool` dtype: the `np.bool` alias was deprecated in
    # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
    self.assertTrue(np.array_equal(
        np.zeros(5, bool), self.enc.events_to_input(events, 0)))
    self.assertTrue(np.array_equal(
        [0, 1, 1, 0, 0], self.enc.events_to_input(events, 1)))
    self.assertTrue(np.array_equal(
        [0, 0, 1, 0, 0], self.enc.events_to_input(events, 2)))

  def testEncodeLabel(self):
    """Check events_to_label; expected labels equal sum(2 ** p for p)."""
    events = [[], [1, 2], [2]]
    self.assertEqual(0, self.enc.events_to_label(events, 0))
    self.assertEqual(6, self.enc.events_to_label(events, 1))
    self.assertEqual(4, self.enc.events_to_label(events, 2))

  def testDecodeLabel(self):
    """Check class_index_to_event inverts the labels used above."""
    self.assertEqual((), self.enc.class_index_to_event(0, None))
    self.assertEqual((1, 2), self.enc.class_index_to_event(6, None))
    self.assertEqual((2,), self.enc.class_index_to_event(4, None))

  def testExtendEventSequences(self):
    """Check each sequence is extended in place with its decoded sample."""
    seqs = ([(0,), (1, 2)], [(), ()])
    samples = ([0, 0, 0, 0, 0], [1, 1, 0, 0, 1])
    self.enc.extend_event_sequences(seqs, samples)
    self.assertEqual(([(0,), (1, 2), ()], [(), (), (0, 1, 4)]), seqs)
if __name__ == '__main__':
  # Script entry point: hand control to the absltest runner.
  absltest.main()
Classes
class PianorollEncodingTest (*args, **kwargs)
-
Extension of unittest.TestCase providing more power.
Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.
Expand source code
class PianorollEncodingTest(absltest.TestCase):
  """Tests for PianorollEncoderDecoder constructed with input size 5."""

  def setUp(self):
    """Create a fresh encoder/decoder for each test."""
    self.enc = pianoroll_encoder_decoder.PianorollEncoderDecoder(5)

  def testProperties(self):
    """Check input size, class count and default label."""
    self.assertEqual(5, self.enc.input_size)
    # 32 == 2 ** 5: one class per subset of the 5 input positions.
    self.assertEqual(32, self.enc.num_classes)
    self.assertEqual(0, self.enc.default_event_label)

  def testEncodeInput(self):
    """Check events_to_input yields a multi-hot vector for one position."""
    events = [(), (1, 2), (2,)]
    # Use the builtin `bool` dtype: the `np.bool` alias was deprecated in
    # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
    self.assertTrue(np.array_equal(
        np.zeros(5, bool), self.enc.events_to_input(events, 0)))
    self.assertTrue(np.array_equal(
        [0, 1, 1, 0, 0], self.enc.events_to_input(events, 1)))
    self.assertTrue(np.array_equal(
        [0, 0, 1, 0, 0], self.enc.events_to_input(events, 2)))

  def testEncodeLabel(self):
    """Check events_to_label; expected labels equal sum(2 ** p for p)."""
    events = [[], [1, 2], [2]]
    self.assertEqual(0, self.enc.events_to_label(events, 0))
    self.assertEqual(6, self.enc.events_to_label(events, 1))
    self.assertEqual(4, self.enc.events_to_label(events, 2))

  def testDecodeLabel(self):
    """Check class_index_to_event inverts the labels used above."""
    self.assertEqual((), self.enc.class_index_to_event(0, None))
    self.assertEqual((1, 2), self.enc.class_index_to_event(6, None))
    self.assertEqual((2,), self.enc.class_index_to_event(4, None))

  def testExtendEventSequences(self):
    """Check each sequence is extended in place with its decoded sample."""
    seqs = ([(0,), (1, 2)], [(), ()])
    samples = ([0, 0, 0, 0, 0], [1, 1, 0, 0, 1])
    self.enc.extend_event_sequences(seqs, samples)
    self.assertEqual(([(0,), (1, 2), ()], [(), (), (0, 1, 4)]), seqs)
Ancestors
- absl.testing.absltest.TestCase
- absl.third_party.unittest3_backport.case.TestCase
- unittest.case.TestCase
Methods
def setUp(self)
-
Hook method for setting up the test fixture before exercising it.
Expand source code
def setUp(self):
  """Give every test its own PianorollEncoderDecoder instance."""
  # testProperties expects this value back as `input_size`.
  input_size = 5
  self.enc = pianoroll_encoder_decoder.PianorollEncoderDecoder(input_size)
def testDecodeLabel(self)
-
Expand source code
def testDecodeLabel(self):
  """Decode class indices back into events and compare."""
  # (class index, expected event); the second argument is always None here.
  cases = [(0, ()), (6, (1, 2)), (4, (2,))]
  for index, expected in cases:
    self.assertEqual(expected, self.enc.class_index_to_event(index, None))
def testEncodeInput(self)
-
Expand source code
def testEncodeInput(self):
  """Check events_to_input yields a multi-hot vector for one position."""
  events = [(), (1, 2), (2,)]
  # Use the builtin `bool` dtype: the `np.bool` alias was deprecated in
  # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
  self.assertTrue(np.array_equal(
      np.zeros(5, bool), self.enc.events_to_input(events, 0)))
  self.assertTrue(np.array_equal(
      [0, 1, 1, 0, 0], self.enc.events_to_input(events, 1)))
  self.assertTrue(np.array_equal(
      [0, 0, 1, 0, 0], self.enc.events_to_input(events, 2)))
def testEncodeLabel(self)
-
Expand source code
def testEncodeLabel(self):
  """Encode the event at each position into a class label."""
  events = [[], [1, 2], [2]]
  # Expected labels equal sum(2 ** p for p in event) for each position.
  expected_labels = [0, 6, 4]
  for position, label in enumerate(expected_labels):
    self.assertEqual(label, self.enc.events_to_label(events, position))
def testExtendEventSequences(self)
-
Expand source code
def testExtendEventSequences(self):
  """Each sequence list gains the event decoded from its sample vector."""
  sequences = ([(0,), (1, 2)], [(), ()])
  sampled = ([0, 0, 0, 0, 0], [1, 1, 0, 0, 1])
  self.enc.extend_event_sequences(sequences, sampled)
  # The inner lists are mutated in place; the outer tuple is unchanged.
  want = ([(0,), (1, 2), ()], [(), (), (0, 1, 4)])
  self.assertEqual(want, sequences)
def testProperties(self)
-
Expand source code
def testProperties(self):
  """Verify the encoder's size-related attributes."""
  expected = (
      ('input_size', 5),
      ('num_classes', 32),
      ('default_event_label', 0),
  )
  for attr_name, value in expected:
    self.assertEqual(value, getattr(self.enc, attr_name))