Module note_seq.encoder_decoder_test

Tests for encoder_decoder.

Expand source code
# Copyright 2021 The Magenta Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for encoder_decoder."""

from absl.testing import absltest
from note_seq import encoder_decoder
from note_seq import testing_lib
import numpy as np


class OneHotEventSequenceEncoderDecoderTest(absltest.TestCase):
  """Tests OneHotEventSequenceEncoderDecoder over a trivial 3-class encoding."""

  def setUp(self):
    super().setUp()
    trivial_encoding = testing_lib.TrivialOneHotEncoding(3, num_steps=range(3))
    self.enc = encoder_decoder.OneHotEventSequenceEncoderDecoder(
        trivial_encoding)

  def testInputSize(self):
    # One input slot per event class.
    self.assertEqual(3, self.enc.input_size)

  def testNumClasses(self):
    self.assertEqual(3, self.enc.num_classes)

  def testEventsToInput(self):
    events = [0, 1, 0, 2, 0]
    one_hot = {0: [1.0, 0.0, 0.0], 1: [0.0, 1.0, 0.0], 2: [0.0, 0.0, 1.0]}
    # The input at each position is the one-hot vector of the event there.
    for position, event in enumerate(events):
      self.assertEqual(one_hot[event],
                       self.enc.events_to_input(events, position))

  def testEventsToLabel(self):
    events = [0, 1, 0, 2, 0]
    # With a trivial encoding the label is the event value itself.
    for position, event in enumerate(events):
      self.assertEqual(event, self.enc.events_to_label(events, position))

  def testClassIndexToEvent(self):
    events = [0, 1, 0, 2, 0]
    for class_index in range(3):
      decoded = self.enc.class_index_to_event(class_index, events)
      self.assertEqual(class_index, decoded)

  def testLabelsToNumSteps(self):
    labels = [0, 1, 0, 2, 0]
    self.assertEqual(3, self.enc.labels_to_num_steps(labels))

  def testEncode(self):
    events = [0, 1, 0, 2, 0]
    inputs, labels = self.enc.encode(events)
    # Inputs cover every event except the last; labels are the events shifted
    # one step ahead.
    self.assertEqual([[1.0, 0.0, 0.0],
                      [0.0, 1.0, 0.0],
                      [1.0, 0.0, 0.0],
                      [0.0, 0.0, 1.0]], inputs)
    self.assertEqual([1, 0, 2, 0], labels)

  def testGetInputsBatch(self):
    event_sequences = [[0, 1, 0, 2, 0], [0, 1, 2]]
    one_hot = {0: [1.0, 0.0, 0.0], 1: [0.0, 1.0, 0.0], 2: [0.0, 0.0, 1.0]}
    per_sequence = [[one_hot[event] for event in events]
                    for events in event_sequences]

    # With the flag set, every position of every sequence is encoded.
    full_length = self.enc.get_inputs_batch(event_sequences, True)
    self.assertListEqual(per_sequence, full_length)

    # By default only the final position of each sequence is encoded.
    last_event_only = self.enc.get_inputs_batch(event_sequences)
    self.assertListEqual([inputs[-1:] for inputs in per_sequence],
                         last_event_only)

  def testExtendEventSequences(self):
    events1, events2, events3 = [0], [0], [0]
    # Each softmax row puts all of its mass on one class, so the appended
    # event is fully determined.
    softmax = [[[0.0, 0.0, 1.0]], [[1.0, 0.0, 0.0]], [[0.0, 1.0, 0.0]]]
    self.enc.extend_event_sequences([events1, events2, events3], softmax)
    for extended, expected in ((events1, [0, 2]),
                               (events2, [0, 0]),
                               (events3, [0, 1])):
      self.assertListEqual(list(extended), expected)

  def testEvaluateLogLikelihood(self):
    event_sequences = [[0, 1, 0], [1, 2, 2]]
    softmax = [[[0.0, 0.5, 0.5], [0.3, 0.4, 0.3]],
               [[0.0, 0.6, 0.4], [0.0, 0.4, 0.6]]]
    log_likelihoods = self.enc.evaluate_log_likelihood(event_sequences,
                                                       softmax)
    # Each expected value sums the log-probabilities that the softmax rows
    # assign to the sequence's last two events.
    expected = [np.log(0.5) + np.log(0.3), np.log(0.4) + np.log(0.6)]
    self.assertListEqual(expected, log_likelihoods)


class OneHotIndexEventSequenceEncoderDecoderTest(absltest.TestCase):
  """Tests OneHotIndexEventSequenceEncoderDecoder over a trivial encoding."""

  def setUp(self):
    super().setUp()
    trivial_encoding = testing_lib.TrivialOneHotEncoding(3, num_steps=range(3))
    self.enc = encoder_decoder.OneHotIndexEventSequenceEncoderDecoder(
        trivial_encoding)

  def testInputSize(self):
    # Each position is represented by a single index, not a one-hot vector.
    self.assertEqual(1, self.enc.input_size)

  def testInputDepth(self):
    self.assertEqual(3, self.enc.input_depth)

  def testEventsToInput(self):
    events = [0, 1, 0, 2, 0]
    for position, event in enumerate(events):
      self.assertEqual([event], self.enc.events_to_input(events, position))

  def testEncode(self):
    events = [0, 1, 0, 2, 0]
    inputs, labels = self.enc.encode(events)
    # Inputs wrap every event but the last; labels are the next events.
    self.assertEqual([[event] for event in events[:-1]], inputs)
    self.assertEqual(events[1:], labels)

  def testGetInputsBatch(self):
    event_sequences = [[0, 1, 0, 2, 0], [0, 1, 2]]
    per_sequence = [[[event] for event in events]
                    for events in event_sequences]

    full_length = self.enc.get_inputs_batch(event_sequences, True)
    self.assertListEqual(per_sequence, full_length)

    last_event_only = self.enc.get_inputs_batch(event_sequences)
    self.assertListEqual([inputs[-1:] for inputs in per_sequence],
                         last_event_only)


class LookbackEventSequenceEncoderDecoderTest(absltest.TestCase):
  """Tests LookbackEventSequenceEncoderDecoder.

  The shared fixture wraps a trivial 3-class encoding with lookback distances
  [1, 2] and a final constructor argument of 2. Judging by the trailing +/-1
  input entries checked below, that argument presumably sets the width of a
  binary position counter (TODO confirm against the encoder).
  """

  def setUp(self):
    super().setUp()
    self.enc = encoder_decoder.LookbackEventSequenceEncoderDecoder(
        testing_lib.TrivialOneHotEncoding(3, num_steps=range(3)), [1, 2], 2)

  def testInputSize(self):
    self.assertEqual(13, self.enc.input_size)

  def testNumClasses(self):
    # 3 event classes plus one repeat class per lookback distance.
    self.assertEqual(5, self.enc.num_classes)

  def testEventsToInput(self):
    events = [0, 1, 0, 2, 0]
    self.assertEqual([1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0,
                      1.0, -1.0, 0.0, 0.0],
                     self.enc.events_to_input(events, 0))
    self.assertEqual([0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0,
                      -1.0, 1.0, 0.0, 0.0],
                     self.enc.events_to_input(events, 1))
    self.assertEqual([1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0,
                      1.0, 1.0, 0.0, 1.0],
                     self.enc.events_to_input(events, 2))
    self.assertEqual([0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0,
                      -1.0, -1.0, 0.0, 0.0],
                     self.enc.events_to_input(events, 3))
    self.assertEqual([1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
                      1.0, -1.0, 0.0, 1.0],
                     self.enc.events_to_input(events, 4))

  def testEventsToLabel(self):
    events = [0, 1, 0, 2, 0]
    # Class 4 (the lookback-2 repeat class) is expected at positions 0, 2, 4.
    self.assertEqual(4, self.enc.events_to_label(events, 0))
    self.assertEqual(1, self.enc.events_to_label(events, 1))
    self.assertEqual(4, self.enc.events_to_label(events, 2))
    self.assertEqual(2, self.enc.events_to_label(events, 3))
    self.assertEqual(4, self.enc.events_to_label(events, 4))

  def testClassIndexToEvent(self):
    events = [0, 1, 0, 2, 0]
    # Classes 0-2 decode to themselves. Per the expectations below, class 3
    # repeats the last event, class 4 repeats the event two positions back,
    # and both fall back to 0 when the history is too short.
    self.assertEqual(0, self.enc.class_index_to_event(0, events[:1]))
    self.assertEqual(1, self.enc.class_index_to_event(1, events[:1]))
    self.assertEqual(2, self.enc.class_index_to_event(2, events[:1]))
    self.assertEqual(0, self.enc.class_index_to_event(3, events[:1]))
    self.assertEqual(0, self.enc.class_index_to_event(4, events[:1]))
    self.assertEqual(0, self.enc.class_index_to_event(0, events[:2]))
    self.assertEqual(1, self.enc.class_index_to_event(1, events[:2]))
    self.assertEqual(2, self.enc.class_index_to_event(2, events[:2]))
    self.assertEqual(1, self.enc.class_index_to_event(3, events[:2]))
    self.assertEqual(0, self.enc.class_index_to_event(4, events[:2]))
    self.assertEqual(0, self.enc.class_index_to_event(0, events[:3]))
    self.assertEqual(1, self.enc.class_index_to_event(1, events[:3]))
    self.assertEqual(2, self.enc.class_index_to_event(2, events[:3]))
    self.assertEqual(0, self.enc.class_index_to_event(3, events[:3]))
    self.assertEqual(1, self.enc.class_index_to_event(4, events[:3]))
    self.assertEqual(0, self.enc.class_index_to_event(0, events[:4]))
    self.assertEqual(1, self.enc.class_index_to_event(1, events[:4]))
    self.assertEqual(2, self.enc.class_index_to_event(2, events[:4]))
    self.assertEqual(2, self.enc.class_index_to_event(3, events[:4]))
    self.assertEqual(0, self.enc.class_index_to_event(4, events[:4]))
    self.assertEqual(0, self.enc.class_index_to_event(0, events[:5]))
    self.assertEqual(1, self.enc.class_index_to_event(1, events[:5]))
    self.assertEqual(2, self.enc.class_index_to_event(2, events[:5]))
    self.assertEqual(0, self.enc.class_index_to_event(3, events[:5]))
    self.assertEqual(2, self.enc.class_index_to_event(4, events[:5]))

  def testLabelsToNumSteps(self):
    labels = [0, 1, 0, 2, 0]
    self.assertEqual(3, self.enc.labels_to_num_steps(labels))

    labels = [0, 1, 3, 2, 4]
    self.assertEqual(5, self.enc.labels_to_num_steps(labels))

  def testEmptyLookback(self):
    """An encoder with no lookback distances has no repeat classes."""
    enc = encoder_decoder.LookbackEventSequenceEncoderDecoder(
        testing_lib.TrivialOneHotEncoding(3), [], 2)
    # 3 one-hot slots plus the 2 trailing +/-1 entries checked below.
    self.assertEqual(5, enc.input_size)
    self.assertEqual(3, enc.num_classes)

    events = [0, 1, 0, 2, 0]

    self.assertEqual([1.0, 0.0, 0.0, 1.0, -1.0],
                     enc.events_to_input(events, 0))
    self.assertEqual([0.0, 1.0, 0.0, -1.0, 1.0],
                     enc.events_to_input(events, 1))
    self.assertEqual([1.0, 0.0, 0.0, 1.0, 1.0],
                     enc.events_to_input(events, 2))
    self.assertEqual([0.0, 0.0, 1.0, -1.0, -1.0],
                     enc.events_to_input(events, 3))
    self.assertEqual([1.0, 0.0, 0.0, 1.0, -1.0],
                     enc.events_to_input(events, 4))

    self.assertEqual(0, enc.events_to_label(events, 0))
    self.assertEqual(1, enc.events_to_label(events, 1))
    self.assertEqual(0, enc.events_to_label(events, 2))
    self.assertEqual(2, enc.events_to_label(events, 3))
    self.assertEqual(0, enc.events_to_label(events, 4))

    # Decode with the empty-lookback encoder under test (`enc`), not the
    # fixture's `self.enc`, whose lookback distances are [1, 2]. Every class
    # index should decode to itself for every history length.
    for length in range(1, len(events) + 1):
      for class_index in range(3):
        decoded = enc.class_index_to_event(class_index, events[:length])
        self.assertEqual(class_index, decoded)


class ConditionalEventSequenceEncoderDecoderTest(absltest.TestCase):
  """Tests ConditionalEventSequenceEncoderDecoder.

  The fixture pairs a 2-class one-hot encoder (first argument, control events)
  with a 3-class one-hot encoder (second argument, target events).
  """

  # Expected inputs for control_events=[1, 1, 1, 0, 0] and
  # target_events=[0, 1, 0, 2, 0] at positions 0-3. Each row is the one-hot
  # of control_events[position + 1] followed by the one-hot of
  # target_events[position].
  _EXPECTED_INPUTS = [[0.0, 1.0, 1.0, 0.0, 0.0],
                      [0.0, 1.0, 0.0, 1.0, 0.0],
                      [1.0, 0.0, 1.0, 0.0, 0.0],
                      [1.0, 0.0, 0.0, 0.0, 1.0]]

  def setUp(self):
    super().setUp()
    control_encoder = encoder_decoder.OneHotEventSequenceEncoderDecoder(
        testing_lib.TrivialOneHotEncoding(2))
    target_encoder = encoder_decoder.OneHotEventSequenceEncoderDecoder(
        testing_lib.TrivialOneHotEncoding(3))
    self.enc = encoder_decoder.ConditionalEventSequenceEncoderDecoder(
        control_encoder, target_encoder)

  def testInputSize(self):
    # Control slots plus target slots.
    self.assertEqual(2 + 3, self.enc.input_size)

  def testNumClasses(self):
    # Only target events are predicted.
    self.assertEqual(3, self.enc.num_classes)

  def testEventsToInput(self):
    control_events = [1, 1, 1, 0, 0]
    target_events = [0, 1, 0, 2, 0]
    for position, expected in enumerate(self._EXPECTED_INPUTS):
      actual = self.enc.events_to_input(
          control_events, target_events, position)
      self.assertEqual(expected, actual)

  def testEventsToLabel(self):
    target_events = [0, 1, 0, 2, 0]
    for position, event in enumerate(target_events):
      self.assertEqual(event, self.enc.events_to_label(target_events, position))

  def testClassIndexToEvent(self):
    target_events = [0, 1, 0, 2, 0]
    for class_index in range(3):
      decoded = self.enc.class_index_to_event(class_index, target_events)
      self.assertEqual(class_index, decoded)

  def testEncode(self):
    control_events = [1, 1, 1, 0, 0]
    target_events = [0, 1, 0, 2, 0]
    inputs, labels = self.enc.encode(control_events, target_events)
    self.assertEqual(self._EXPECTED_INPUTS, inputs)
    # Labels are the target events shifted one step ahead.
    self.assertEqual(target_events[1:], labels)

  def testGetInputsBatch(self):
    control_event_sequences = [[1, 1, 1, 0, 0], [1, 1, 1, 0, 0]]
    target_event_sequences = [[0, 1, 0, 2], [0, 1]]
    # The shorter target sequence only produces the first two rows.
    per_sequence = [self._EXPECTED_INPUTS, self._EXPECTED_INPUTS[:2]]

    full_length = self.enc.get_inputs_batch(
        control_event_sequences, target_event_sequences, True)
    self.assertListEqual(per_sequence, full_length)

    last_event_only = self.enc.get_inputs_batch(
        control_event_sequences, target_event_sequences)
    self.assertListEqual([inputs[-1:] for inputs in per_sequence],
                         last_event_only)

  def testExtendEventSequences(self):
    target_events_1, target_events_2, target_events_3 = [0], [0], [0]
    # One-hot softmax rows make the appended event fully determined.
    softmax = np.array(
        [[[0.0, 0.0, 1.0]], [[1.0, 0.0, 0.0]], [[0.0, 1.0, 0.0]]])
    self.enc.extend_event_sequences(
        [target_events_1, target_events_2, target_events_3], softmax)
    for extended, expected in ((target_events_1, [0, 2]),
                               (target_events_2, [0, 0]),
                               (target_events_3, [0, 1])):
      self.assertListEqual(list(extended), expected)

  def testEvaluateLogLikelihood(self):
    target_event_sequences = [[0, 1, 0], [1, 2, 2]]
    softmax = [[[0.0, 0.5, 0.5], [0.3, 0.4, 0.3]],
               [[0.0, 0.6, 0.4], [0.0, 0.4, 0.6]]]
    log_likelihoods = self.enc.evaluate_log_likelihood(
        target_event_sequences, softmax)
    expected = [np.log(0.5) + np.log(0.3), np.log(0.4) + np.log(0.6)]
    self.assertListEqual(expected, log_likelihoods)


class OptionalEventSequenceEncoderTest(absltest.TestCase):
  """Tests OptionalEventSequenceEncoder wrapping a 3-class one-hot encoder."""

  def setUp(self):
    super().setUp()
    wrapped_encoder = encoder_decoder.OneHotEventSequenceEncoderDecoder(
        testing_lib.TrivialOneHotEncoding(3))
    self.enc = encoder_decoder.OptionalEventSequenceEncoder(wrapped_encoder)

  def testInputSize(self):
    # One leading flag slot plus the wrapped encoder's 3 slots.
    self.assertEqual(1 + 3, self.enc.input_size)

  def testEventsToInput(self):
    events = [(False, 0), (False, 1), (False, 0), (True, 2), (True, 0)]
    # A True flag sets only the leading slot. Otherwise the leading slot is 0
    # and the wrapped event is one-hot encoded after it.
    expected_inputs = [
        [0.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, 1.0, 0.0],
        [0.0, 1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0, 0.0],
        [1.0, 0.0, 0.0, 0.0],
    ]
    for position, expected in enumerate(expected_inputs):
      self.assertEqual(expected, self.enc.events_to_input(events, position))


class MultipleEventSequenceEncoderTest(absltest.TestCase):
  """Tests MultipleEventSequenceEncoder over 2-class and 3-class encoders."""

  def setUp(self):
    super().setUp()
    sub_encoders = [
        encoder_decoder.OneHotEventSequenceEncoderDecoder(
            testing_lib.TrivialOneHotEncoding(num_classes))
        for num_classes in (2, 3)
    ]
    self.enc = encoder_decoder.MultipleEventSequenceEncoder(sub_encoders)

  def testInputSize(self):
    # Sub-encoder inputs are concatenated.
    self.assertEqual(2 + 3, self.enc.input_size)

  def testEventsToInput(self):
    events = [(1, 0), (1, 1), (1, 0), (0, 2), (0, 0)]
    # Each row is the 2-class one-hot of the first tuple element followed by
    # the 3-class one-hot of the second.
    expected_inputs = [
        [0.0, 1.0, 1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 1.0, 0.0],
        [0.0, 1.0, 1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0, 0.0, 1.0],
        [1.0, 0.0, 1.0, 0.0, 0.0],
    ]
    for position, expected in enumerate(expected_inputs):
      self.assertEqual(expected, self.enc.events_to_input(events, position))


if __name__ == '__main__':
  # Running this file directly hands control to absltest, which runs the
  # TestCase classes defined in this module.
  absltest.main()

Classes

class ConditionalEventSequenceEncoderDecoderTest (*args, **kwargs)

Extension of unittest.TestCase providing more power.

Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.

Expand source code
class ConditionalEventSequenceEncoderDecoderTest(absltest.TestCase):
  """Tests ConditionalEventSequenceEncoderDecoder.

  The fixture pairs a 2-class one-hot encoder (first argument, control events)
  with a 3-class one-hot encoder (second argument, target events).
  """

  # Expected inputs for control_events=[1, 1, 1, 0, 0] and
  # target_events=[0, 1, 0, 2, 0] at positions 0-3. Each row is the one-hot
  # of control_events[position + 1] followed by the one-hot of
  # target_events[position].
  _EXPECTED_INPUTS = [[0.0, 1.0, 1.0, 0.0, 0.0],
                      [0.0, 1.0, 0.0, 1.0, 0.0],
                      [1.0, 0.0, 1.0, 0.0, 0.0],
                      [1.0, 0.0, 0.0, 0.0, 1.0]]

  def setUp(self):
    super().setUp()
    control_encoder = encoder_decoder.OneHotEventSequenceEncoderDecoder(
        testing_lib.TrivialOneHotEncoding(2))
    target_encoder = encoder_decoder.OneHotEventSequenceEncoderDecoder(
        testing_lib.TrivialOneHotEncoding(3))
    self.enc = encoder_decoder.ConditionalEventSequenceEncoderDecoder(
        control_encoder, target_encoder)

  def testInputSize(self):
    # Control slots plus target slots.
    self.assertEqual(2 + 3, self.enc.input_size)

  def testNumClasses(self):
    # Only target events are predicted.
    self.assertEqual(3, self.enc.num_classes)

  def testEventsToInput(self):
    control_events = [1, 1, 1, 0, 0]
    target_events = [0, 1, 0, 2, 0]
    for position, expected in enumerate(self._EXPECTED_INPUTS):
      actual = self.enc.events_to_input(
          control_events, target_events, position)
      self.assertEqual(expected, actual)

  def testEventsToLabel(self):
    target_events = [0, 1, 0, 2, 0]
    for position, event in enumerate(target_events):
      self.assertEqual(event, self.enc.events_to_label(target_events, position))

  def testClassIndexToEvent(self):
    target_events = [0, 1, 0, 2, 0]
    for class_index in range(3):
      decoded = self.enc.class_index_to_event(class_index, target_events)
      self.assertEqual(class_index, decoded)

  def testEncode(self):
    control_events = [1, 1, 1, 0, 0]
    target_events = [0, 1, 0, 2, 0]
    inputs, labels = self.enc.encode(control_events, target_events)
    self.assertEqual(self._EXPECTED_INPUTS, inputs)
    # Labels are the target events shifted one step ahead.
    self.assertEqual(target_events[1:], labels)

  def testGetInputsBatch(self):
    control_event_sequences = [[1, 1, 1, 0, 0], [1, 1, 1, 0, 0]]
    target_event_sequences = [[0, 1, 0, 2], [0, 1]]
    # The shorter target sequence only produces the first two rows.
    per_sequence = [self._EXPECTED_INPUTS, self._EXPECTED_INPUTS[:2]]

    full_length = self.enc.get_inputs_batch(
        control_event_sequences, target_event_sequences, True)
    self.assertListEqual(per_sequence, full_length)

    last_event_only = self.enc.get_inputs_batch(
        control_event_sequences, target_event_sequences)
    self.assertListEqual([inputs[-1:] for inputs in per_sequence],
                         last_event_only)

  def testExtendEventSequences(self):
    target_events_1, target_events_2, target_events_3 = [0], [0], [0]
    # One-hot softmax rows make the appended event fully determined.
    softmax = np.array(
        [[[0.0, 0.0, 1.0]], [[1.0, 0.0, 0.0]], [[0.0, 1.0, 0.0]]])
    self.enc.extend_event_sequences(
        [target_events_1, target_events_2, target_events_3], softmax)
    for extended, expected in ((target_events_1, [0, 2]),
                               (target_events_2, [0, 0]),
                               (target_events_3, [0, 1])):
      self.assertListEqual(list(extended), expected)

  def testEvaluateLogLikelihood(self):
    target_event_sequences = [[0, 1, 0], [1, 2, 2]]
    softmax = [[[0.0, 0.5, 0.5], [0.3, 0.4, 0.3]],
               [[0.0, 0.6, 0.4], [0.0, 0.4, 0.6]]]
    log_likelihoods = self.enc.evaluate_log_likelihood(
        target_event_sequences, softmax)
    expected = [np.log(0.5) + np.log(0.3), np.log(0.4) + np.log(0.6)]
    self.assertListEqual(expected, log_likelihoods)

Ancestors

  • absl.testing.absltest.TestCase
  • absl.third_party.unittest3_backport.case.TestCase
  • unittest.case.TestCase

Methods

def setUp(self)

Hook method for setting up the test fixture before exercising it.

Expand source code
def setUp(self):
  """Builds a conditional encoder from 2-class control, 3-class target parts."""
  super().setUp()
  control_encoder = encoder_decoder.OneHotEventSequenceEncoderDecoder(
      testing_lib.TrivialOneHotEncoding(2))
  target_encoder = encoder_decoder.OneHotEventSequenceEncoderDecoder(
      testing_lib.TrivialOneHotEncoding(3))
  self.enc = encoder_decoder.ConditionalEventSequenceEncoderDecoder(
      control_encoder, target_encoder)
def testClassIndexToEvent(self)
Expand source code
def testClassIndexToEvent(self):
  """Each target class index decodes to the event with the same value."""
  target_events = [0, 1, 0, 2, 0]
  for class_index in range(3):
    decoded = self.enc.class_index_to_event(class_index, target_events)
    self.assertEqual(class_index, decoded)
def testEncode(self)
Expand source code
def testEncode(self):
  """encode() pairs each input row with the following target event."""
  control_events = [1, 1, 1, 0, 0]
  target_events = [0, 1, 0, 2, 0]
  inputs, labels = self.enc.encode(control_events, target_events)
  self.assertEqual([[0.0, 1.0, 1.0, 0.0, 0.0],
                    [0.0, 1.0, 0.0, 1.0, 0.0],
                    [1.0, 0.0, 1.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0, 0.0, 1.0]], inputs)
  # Labels are the target events shifted one step ahead.
  self.assertEqual(target_events[1:], labels)
def testEvaluateLogLikelihood(self)
Expand source code
def testEvaluateLogLikelihood(self):
  """Checks summed log-probabilities for each target sequence."""
  sequences = [[0, 1, 0], [1, 2, 2]]
  softmax = [[[0.0, 0.5, 0.5], [0.3, 0.4, 0.3]],
             [[0.0, 0.6, 0.4], [0.0, 0.4, 0.6]]]
  log_likelihoods = self.enc.evaluate_log_likelihood(sequences, softmax)
  # Each expected value sums the log-probabilities the softmax rows assign to
  # the sequence's last two events.
  expected = [np.log(0.5) + np.log(0.3), np.log(0.4) + np.log(0.6)]
  self.assertListEqual(expected, log_likelihoods)
def testEventsToInput(self)
Expand source code
def testEventsToInput(self):
  """Rows pair control_events[position + 1] with target_events[position]."""
  control_events = [1, 1, 1, 0, 0]
  target_events = [0, 1, 0, 2, 0]
  expected_rows = [[0.0, 1.0, 1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0, 1.0, 0.0],
                   [1.0, 0.0, 1.0, 0.0, 0.0],
                   [1.0, 0.0, 0.0, 0.0, 1.0]]
  for position, expected in enumerate(expected_rows):
    self.assertEqual(
        expected,
        self.enc.events_to_input(control_events, target_events, position))
def testEventsToLabel(self)
Expand source code
def testEventsToLabel(self):
  """Labels equal the target event values under a trivial encoding."""
  target_events = [0, 1, 0, 2, 0]
  for position, expected_label in enumerate(target_events):
    self.assertEqual(expected_label,
                     self.enc.events_to_label(target_events, position))
def testExtendEventSequences(self)
Expand source code
def testExtendEventSequences(self):
  """Each sequence grows by the single class its one-hot softmax selects."""
  target_events_1, target_events_2, target_events_3 = [0], [0], [0]
  softmax = np.array(
      [[[0.0, 0.0, 1.0]], [[1.0, 0.0, 0.0]], [[0.0, 1.0, 0.0]]])
  self.enc.extend_event_sequences(
      [target_events_1, target_events_2, target_events_3], softmax)
  for extended, expected in ((target_events_1, [0, 2]),
                             (target_events_2, [0, 0]),
                             (target_events_3, [0, 1])):
    self.assertListEqual(list(extended), expected)
def testGetInputsBatch(self)
Expand source code
def testGetInputsBatch(self):
  """Checks full-length and last-position-only batched inputs."""
  control_event_sequences = [[1, 1, 1, 0, 0], [1, 1, 1, 0, 0]]
  target_event_sequences = [[0, 1, 0, 2], [0, 1]]
  rows = [[0.0, 1.0, 1.0, 0.0, 0.0],
          [0.0, 1.0, 0.0, 1.0, 0.0],
          [1.0, 0.0, 1.0, 0.0, 0.0],
          [1.0, 0.0, 0.0, 0.0, 1.0]]
  # The shorter target sequence only produces the first two rows.
  per_sequence = [rows, rows[:2]]

  full_length = self.enc.get_inputs_batch(
      control_event_sequences, target_event_sequences, True)
  self.assertListEqual(per_sequence, full_length)

  last_event_only = self.enc.get_inputs_batch(
      control_event_sequences, target_event_sequences)
  self.assertListEqual([inputs[-1:] for inputs in per_sequence],
                       last_event_only)
def testInputSize(self)
Expand source code
def testInputSize(self):
  """Input concatenates 2 control slots and 3 target slots."""
  expected_input_size = 2 + 3
  self.assertEqual(expected_input_size, self.enc.input_size)
def testNumClasses(self)
Expand source code
def testNumClasses(self):
  """Only the 3 target classes are predicted."""
  num_target_classes = 3
  self.assertEqual(num_target_classes, self.enc.num_classes)
class LookbackEventSequenceEncoderDecoderTest (*args, **kwargs)

Extension of unittest.TestCase providing more power.

Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.

Expand source code
class LookbackEventSequenceEncoderDecoderTest(absltest.TestCase):
  """Tests LookbackEventSequenceEncoderDecoder.

  The shared fixture wraps a trivial 3-class encoding with lookback distances
  [1, 2] and a final constructor argument of 2. Judging by the trailing +/-1
  input entries checked below, that argument presumably sets the width of a
  binary position counter (TODO confirm against the encoder).
  """

  def setUp(self):
    super().setUp()
    self.enc = encoder_decoder.LookbackEventSequenceEncoderDecoder(
        testing_lib.TrivialOneHotEncoding(3, num_steps=range(3)), [1, 2], 2)

  def testInputSize(self):
    self.assertEqual(13, self.enc.input_size)

  def testNumClasses(self):
    # 3 event classes plus one repeat class per lookback distance.
    self.assertEqual(5, self.enc.num_classes)

  def testEventsToInput(self):
    events = [0, 1, 0, 2, 0]
    self.assertEqual([1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0,
                      1.0, -1.0, 0.0, 0.0],
                     self.enc.events_to_input(events, 0))
    self.assertEqual([0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0,
                      -1.0, 1.0, 0.0, 0.0],
                     self.enc.events_to_input(events, 1))
    self.assertEqual([1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0,
                      1.0, 1.0, 0.0, 1.0],
                     self.enc.events_to_input(events, 2))
    self.assertEqual([0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0,
                      -1.0, -1.0, 0.0, 0.0],
                     self.enc.events_to_input(events, 3))
    self.assertEqual([1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
                      1.0, -1.0, 0.0, 1.0],
                     self.enc.events_to_input(events, 4))

  def testEventsToLabel(self):
    events = [0, 1, 0, 2, 0]
    # Class 4 (the lookback-2 repeat class) is expected at positions 0, 2, 4.
    self.assertEqual(4, self.enc.events_to_label(events, 0))
    self.assertEqual(1, self.enc.events_to_label(events, 1))
    self.assertEqual(4, self.enc.events_to_label(events, 2))
    self.assertEqual(2, self.enc.events_to_label(events, 3))
    self.assertEqual(4, self.enc.events_to_label(events, 4))

  def testClassIndexToEvent(self):
    events = [0, 1, 0, 2, 0]
    # Classes 0-2 decode to themselves. Per the expectations below, class 3
    # repeats the last event, class 4 repeats the event two positions back,
    # and both fall back to 0 when the history is too short.
    self.assertEqual(0, self.enc.class_index_to_event(0, events[:1]))
    self.assertEqual(1, self.enc.class_index_to_event(1, events[:1]))
    self.assertEqual(2, self.enc.class_index_to_event(2, events[:1]))
    self.assertEqual(0, self.enc.class_index_to_event(3, events[:1]))
    self.assertEqual(0, self.enc.class_index_to_event(4, events[:1]))
    self.assertEqual(0, self.enc.class_index_to_event(0, events[:2]))
    self.assertEqual(1, self.enc.class_index_to_event(1, events[:2]))
    self.assertEqual(2, self.enc.class_index_to_event(2, events[:2]))
    self.assertEqual(1, self.enc.class_index_to_event(3, events[:2]))
    self.assertEqual(0, self.enc.class_index_to_event(4, events[:2]))
    self.assertEqual(0, self.enc.class_index_to_event(0, events[:3]))
    self.assertEqual(1, self.enc.class_index_to_event(1, events[:3]))
    self.assertEqual(2, self.enc.class_index_to_event(2, events[:3]))
    self.assertEqual(0, self.enc.class_index_to_event(3, events[:3]))
    self.assertEqual(1, self.enc.class_index_to_event(4, events[:3]))
    self.assertEqual(0, self.enc.class_index_to_event(0, events[:4]))
    self.assertEqual(1, self.enc.class_index_to_event(1, events[:4]))
    self.assertEqual(2, self.enc.class_index_to_event(2, events[:4]))
    self.assertEqual(2, self.enc.class_index_to_event(3, events[:4]))
    self.assertEqual(0, self.enc.class_index_to_event(4, events[:4]))
    self.assertEqual(0, self.enc.class_index_to_event(0, events[:5]))
    self.assertEqual(1, self.enc.class_index_to_event(1, events[:5]))
    self.assertEqual(2, self.enc.class_index_to_event(2, events[:5]))
    self.assertEqual(0, self.enc.class_index_to_event(3, events[:5]))
    self.assertEqual(2, self.enc.class_index_to_event(4, events[:5]))

  def testLabelsToNumSteps(self):
    labels = [0, 1, 0, 2, 0]
    self.assertEqual(3, self.enc.labels_to_num_steps(labels))

    labels = [0, 1, 3, 2, 4]
    self.assertEqual(5, self.enc.labels_to_num_steps(labels))

  def testEmptyLookback(self):
    """An encoder with no lookback distances has no repeat classes."""
    enc = encoder_decoder.LookbackEventSequenceEncoderDecoder(
        testing_lib.TrivialOneHotEncoding(3), [], 2)
    # 3 one-hot slots plus the 2 trailing +/-1 entries checked below.
    self.assertEqual(5, enc.input_size)
    self.assertEqual(3, enc.num_classes)

    events = [0, 1, 0, 2, 0]

    self.assertEqual([1.0, 0.0, 0.0, 1.0, -1.0],
                     enc.events_to_input(events, 0))
    self.assertEqual([0.0, 1.0, 0.0, -1.0, 1.0],
                     enc.events_to_input(events, 1))
    self.assertEqual([1.0, 0.0, 0.0, 1.0, 1.0],
                     enc.events_to_input(events, 2))
    self.assertEqual([0.0, 0.0, 1.0, -1.0, -1.0],
                     enc.events_to_input(events, 3))
    self.assertEqual([1.0, 0.0, 0.0, 1.0, -1.0],
                     enc.events_to_input(events, 4))

    self.assertEqual(0, enc.events_to_label(events, 0))
    self.assertEqual(1, enc.events_to_label(events, 1))
    self.assertEqual(0, enc.events_to_label(events, 2))
    self.assertEqual(2, enc.events_to_label(events, 3))
    self.assertEqual(0, enc.events_to_label(events, 4))

    # Decode with the empty-lookback encoder under test (`enc`), not the
    # fixture's `self.enc`, whose lookback distances are [1, 2]. Every class
    # index should decode to itself for every history length.
    for length in range(1, len(events) + 1):
      for class_index in range(3):
        decoded = enc.class_index_to_event(class_index, events[:length])
        self.assertEqual(class_index, decoded)

Ancestors

  • absl.testing.absltest.TestCase
  • absl.third_party.unittest3_backport.case.TestCase
  • unittest.case.TestCase

Methods

def setUp(self)

Hook method for setting up the test fixture before exercising it.

Expand source code
def setUp(self):
  """Builds a lookback encoder over a 3-event trivial one-hot encoding."""
  super().setUp()
  base_encoding = testing_lib.TrivialOneHotEncoding(3, num_steps=range(3))
  # Lookback distances [1, 2]; the trailing 2 is passed positionally to the
  # LookbackEventSequenceEncoderDecoder constructor.
  self.enc = encoder_decoder.LookbackEventSequenceEncoderDecoder(
      base_encoding, [1, 2], 2)
def testClassIndexToEvent(self)
Expand source code
def testClassIndexToEvent(self):
  """Checks decoding of every class index for every prefix of `events`.

  Indices 0-2 are the plain one-hot classes; indices 3 and 4 resolve to
  events that depend on the history prefix passed in.
  """
  events = [0, 1, 0, 2, 0]
  # Maps prefix length -> expected decoded event for class indices 0..4.
  expected_by_prefix_len = {
      1: [0, 1, 2, 0, 0],
      2: [0, 1, 2, 1, 0],
      3: [0, 1, 2, 0, 1],
      4: [0, 1, 2, 2, 0],
      5: [0, 1, 2, 0, 2],
  }
  for prefix_len, expected_events in expected_by_prefix_len.items():
    history = events[:prefix_len]
    for class_index, expected_event in enumerate(expected_events):
      self.assertEqual(expected_event,
                       self.enc.class_index_to_event(class_index, history))
def testEmptyLookback(self)
Expand source code
def testEmptyLookback(self):
  """Tests a lookback encoder constructed with an empty lookback list.

  Builds a local encoder (no lookback distances, final constructor argument
  2) and checks its sizes, inputs, labels and class-index decoding.
  """
  enc = encoder_decoder.LookbackEventSequenceEncoderDecoder(
      testing_lib.TrivialOneHotEncoding(3), [], 2)
  self.assertEqual(5, enc.input_size)
  # Without lookback distances no extra classes are added beyond the base 3.
  self.assertEqual(3, enc.num_classes)

  events = [0, 1, 0, 2, 0]

  # Each row is the one-hot encoding of events[position] followed by two
  # +/-1 entries that vary with position (presumably the binary counter
  # bits; NOTE(review): confirm against the encoder implementation).
  expected_inputs = [
      [1.0, 0.0, 0.0, 1.0, -1.0],
      [0.0, 1.0, 0.0, -1.0, 1.0],
      [1.0, 0.0, 0.0, 1.0, 1.0],
      [0.0, 0.0, 1.0, -1.0, -1.0],
      [1.0, 0.0, 0.0, 1.0, -1.0],
  ]
  for position, expected in enumerate(expected_inputs):
    self.assertEqual(expected, enc.events_to_input(events, position))

  expected_labels = [0, 1, 0, 2, 0]
  for position, expected in enumerate(expected_labels):
    self.assertEqual(expected, enc.events_to_label(events, position))

  # Decode with the local `enc`, not `self.enc` (the setUp fixture, which has
  # lookback distances [1, 2]). Otherwise the empty-lookback encoder's
  # class_index_to_event would never be exercised.
  for prefix_len in range(1, len(events) + 1):
    for class_index in range(3):
      self.assertEqual(
          class_index,
          enc.class_index_to_event(class_index, events[:prefix_len]))
def testEventsToInput(self)
Expand source code
def testEventsToInput(self):
  """Checks the full 13-wide lookback input vector at each position."""
  events = [0, 1, 0, 2, 0]
  expected_rows = [
      [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0,
       1.0, -1.0, 0.0, 0.0],
      [0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0,
       -1.0, 1.0, 0.0, 0.0],
      [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0,
       1.0, 1.0, 0.0, 1.0],
      [0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0,
       -1.0, -1.0, 0.0, 0.0],
      [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
       1.0, -1.0, 0.0, 1.0],
  ]
  for position, expected in enumerate(expected_rows):
    self.assertEqual(expected, self.enc.events_to_input(events, position))
def testEventsToLabel(self)
Expand source code
def testEventsToLabel(self):
  """Checks labels, where repeated events map to a lookback class (4)."""
  events = [0, 1, 0, 2, 0]
  expected_labels = [4, 1, 4, 2, 4]
  for position, label in enumerate(expected_labels):
    self.assertEqual(label, self.enc.events_to_label(events, position))
def testInputSize(self)
Expand source code
def testInputSize(self):
  """The fixture encoder reports an input width of 13."""
  expected_width = 13
  self.assertEqual(expected_width, self.enc.input_size)
def testLabelsToNumSteps(self)
Expand source code
def testLabelsToNumSteps(self):
  """Checks step counting for plain labels and for lookback labels."""
  cases = (
      ([0, 1, 0, 2, 0], 3),
      ([0, 1, 3, 2, 4], 5),
  )
  for labels, expected_steps in cases:
    self.assertEqual(expected_steps, self.enc.labels_to_num_steps(labels))
def testNumClasses(self)
Expand source code
def testNumClasses(self):
  """Three one-hot classes plus two lookback classes."""
  expected_classes = 5
  self.assertEqual(expected_classes, self.enc.num_classes)
class MultipleEventSequenceEncoderTest (*args, **kwargs)

Extension of unittest.TestCase providing more power.

Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.

Expand source code
class MultipleEventSequenceEncoderTest(absltest.TestCase):
  """Tests MultipleEventSequenceEncoder built from two one-hot encoders."""

  def setUp(self):
    super().setUp()
    # Sub-encoders of depth 2 and 3, in that order.
    sub_encoders = [
        encoder_decoder.OneHotEventSequenceEncoderDecoder(
            testing_lib.TrivialOneHotEncoding(depth))
        for depth in (2, 3)
    ]
    self.enc = encoder_decoder.MultipleEventSequenceEncoder(sub_encoders)

  def testInputSize(self):
    # Widths of the two sub-encoders: 2 + 3.
    self.assertEqual(5, self.enc.input_size)

  def testEventsToInput(self):
    events = [(1, 0), (1, 1), (1, 0), (0, 2), (0, 0)]
    # First two entries encode the tuple's first element; the last three
    # encode its second element.
    expected_inputs = [
        [0.0, 1.0, 1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 1.0, 0.0],
        [0.0, 1.0, 1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0, 0.0, 1.0],
        [1.0, 0.0, 1.0, 0.0, 0.0],
    ]
    for position, expected in enumerate(expected_inputs):
      self.assertEqual(expected, self.enc.events_to_input(events, position))

Ancestors

  • absl.testing.absltest.TestCase
  • absl.third_party.unittest3_backport.case.TestCase
  • unittest.case.TestCase

Methods

def setUp(self)

Hook method for setting up the test fixture before exercising it.

Expand source code
def setUp(self):
  """Combines one-hot encoders of depth 2 and 3 into a single encoder."""
  super().setUp()
  sub_encoders = [
      encoder_decoder.OneHotEventSequenceEncoderDecoder(
          testing_lib.TrivialOneHotEncoding(depth))
      for depth in (2, 3)
  ]
  self.enc = encoder_decoder.MultipleEventSequenceEncoder(sub_encoders)
def testEventsToInput(self)
Expand source code
def testEventsToInput(self):
  """Each input concatenates the two sub-encoders' one-hot vectors."""
  events = [(1, 0), (1, 1), (1, 0), (0, 2), (0, 0)]
  expected_inputs = [
      [0.0, 1.0, 1.0, 0.0, 0.0],
      [0.0, 1.0, 0.0, 1.0, 0.0],
      [0.0, 1.0, 1.0, 0.0, 0.0],
      [1.0, 0.0, 0.0, 0.0, 1.0],
      [1.0, 0.0, 1.0, 0.0, 0.0],
  ]
  for position, expected in enumerate(expected_inputs):
    self.assertEqual(expected, self.enc.events_to_input(events, position))
def testInputSize(self)
Expand source code
def testInputSize(self):
  """Input width is the sum of the sub-encoder widths (2 + 3)."""
  expected_width = 5
  self.assertEqual(expected_width, self.enc.input_size)
class OneHotEventSequenceEncoderDecoderTest (*args, **kwargs)

Extension of unittest.TestCase providing more power.

Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.

Expand source code
class OneHotEventSequenceEncoderDecoderTest(absltest.TestCase):
  """Tests OneHotEventSequenceEncoderDecoder over a 3-event encoding."""

  def setUp(self):
    super().setUp()
    base_encoding = testing_lib.TrivialOneHotEncoding(3, num_steps=range(3))
    self.enc = encoder_decoder.OneHotEventSequenceEncoderDecoder(base_encoding)

  def testInputSize(self):
    expected_width = 3
    self.assertEqual(expected_width, self.enc.input_size)

  def testNumClasses(self):
    expected_classes = 3
    self.assertEqual(expected_classes, self.enc.num_classes)

  def testEventsToInput(self):
    events = [0, 1, 0, 2, 0]
    one_hot_rows = {
        0: [1.0, 0.0, 0.0],
        1: [0.0, 1.0, 0.0],
        2: [0.0, 0.0, 1.0],
    }
    for position, event in enumerate(events):
      self.assertEqual(one_hot_rows[event],
                       self.enc.events_to_input(events, position))

  def testEventsToLabel(self):
    events = [0, 1, 0, 2, 0]
    for position, expected_label in enumerate([0, 1, 0, 2, 0]):
      self.assertEqual(expected_label,
                       self.enc.events_to_label(events, position))

  def testClassIndexToEvent(self):
    events = [0, 1, 0, 2, 0]
    for class_index in range(3):
      self.assertEqual(class_index,
                       self.enc.class_index_to_event(class_index, events))

  def testLabelsToNumSteps(self):
    self.assertEqual(3, self.enc.labels_to_num_steps([0, 1, 0, 2, 0]))

  def testEncode(self):
    inputs, labels = self.enc.encode([0, 1, 0, 2, 0])
    # The final event has no following label, so only 4 inputs are produced.
    self.assertEqual(inputs, [[1.0, 0.0, 0.0],
                              [0.0, 1.0, 0.0],
                              [1.0, 0.0, 0.0],
                              [0.0, 0.0, 1.0]])
    self.assertEqual(labels, [1, 0, 2, 0])

  def testGetInputsBatch(self):
    event_sequences = [[0, 1, 0, 2, 0], [0, 1, 2]]
    first_rows = [[1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0],
                  [1.0, 0.0, 0.0],
                  [0.0, 0.0, 1.0],
                  [1.0, 0.0, 0.0]]
    second_rows = [[1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0],
                   [0.0, 0.0, 1.0]]
    full_batch = [first_rows, second_rows]
    last_only_batch = [rows[-1:] for rows in full_batch]
    self.assertListEqual(
        full_batch, self.enc.get_inputs_batch(event_sequences, True))
    self.assertListEqual(
        last_only_batch, self.enc.get_inputs_batch(event_sequences))

  def testExtendEventSequences(self):
    seq_a, seq_b, seq_c = [0], [0], [0]
    softmax = [[[0.0, 0.0, 1.0]], [[1.0, 0.0, 0.0]], [[0.0, 1.0, 0.0]]]
    self.enc.extend_event_sequences([seq_a, seq_b, seq_c], softmax)
    # Each sequence is extended in place with its softmax's argmax event.
    for seq, expected in ((seq_a, [0, 2]), (seq_b, [0, 0]), (seq_c, [0, 1])):
      self.assertListEqual(list(seq), expected)

  def testEvaluateLogLikelihood(self):
    event_sequences = [[0, 1, 0], [1, 2, 2]]
    softmax = [[[0.0, 0.5, 0.5], [0.3, 0.4, 0.3]],
               [[0.0, 0.6, 0.4], [0.0, 0.4, 0.6]]]
    log_likelihoods = self.enc.evaluate_log_likelihood(
        event_sequences, softmax)
    expected = [np.log(0.5) + np.log(0.3), np.log(0.4) + np.log(0.6)]
    self.assertListEqual(expected, log_likelihoods)

Ancestors

  • absl.testing.absltest.TestCase
  • absl.third_party.unittest3_backport.case.TestCase
  • unittest.case.TestCase

Methods

def setUp(self)

Hook method for setting up the test fixture before exercising it.

Expand source code
def setUp(self):
  """Builds a one-hot encoder/decoder over a 3-event trivial encoding."""
  super().setUp()
  base_encoding = testing_lib.TrivialOneHotEncoding(3, num_steps=range(3))
  self.enc = encoder_decoder.OneHotEventSequenceEncoderDecoder(base_encoding)
def testClassIndexToEvent(self)
Expand source code
def testClassIndexToEvent(self):
  """Each class index decodes back to the event of the same value."""
  events = [0, 1, 0, 2, 0]
  for class_index in range(3):
    self.assertEqual(class_index,
                     self.enc.class_index_to_event(class_index, events))
def testEncode(self)
Expand source code
def testEncode(self):
  """encode() pairs each event's one-hot input with the next event's label."""
  inputs, labels = self.enc.encode([0, 1, 0, 2, 0])
  self.assertEqual(inputs, [[1.0, 0.0, 0.0],
                            [0.0, 1.0, 0.0],
                            [1.0, 0.0, 0.0],
                            [0.0, 0.0, 1.0]])
  self.assertEqual(labels, [1, 0, 2, 0])
def testEvaluateLogLikelihood(self)
Expand source code
def testEvaluateLogLikelihood(self):
  """Log-likelihood sums log-probabilities of each sequence's next events."""
  event_sequences = [[0, 1, 0], [1, 2, 2]]
  softmax = [[[0.0, 0.5, 0.5], [0.3, 0.4, 0.3]],
             [[0.0, 0.6, 0.4], [0.0, 0.4, 0.6]]]
  log_likelihoods = self.enc.evaluate_log_likelihood(event_sequences, softmax)
  expected = [np.log(0.5) + np.log(0.3), np.log(0.4) + np.log(0.6)]
  self.assertListEqual(expected, log_likelihoods)
def testEventsToInput(self)
Expand source code
def testEventsToInput(self):
  """Each position's input is the one-hot vector of its event."""
  events = [0, 1, 0, 2, 0]
  one_hot_rows = {
      0: [1.0, 0.0, 0.0],
      1: [0.0, 1.0, 0.0],
      2: [0.0, 0.0, 1.0],
  }
  for position, event in enumerate(events):
    self.assertEqual(one_hot_rows[event],
                     self.enc.events_to_input(events, position))
def testEventsToLabel(self)
Expand source code
def testEventsToLabel(self):
  """Labels equal the event values for the trivial encoding."""
  events = [0, 1, 0, 2, 0]
  for position, expected_label in enumerate([0, 1, 0, 2, 0]):
    self.assertEqual(expected_label,
                     self.enc.events_to_label(events, position))
def testExtendEventSequences(self)
Expand source code
def testExtendEventSequences(self):
  """Each sequence is extended with the event its softmax row selects."""
  seq_a, seq_b, seq_c = [0], [0], [0]
  softmax = [[[0.0, 0.0, 1.0]], [[1.0, 0.0, 0.0]], [[0.0, 1.0, 0.0]]]
  self.enc.extend_event_sequences([seq_a, seq_b, seq_c], softmax)
  for seq, expected in ((seq_a, [0, 2]), (seq_b, [0, 0]), (seq_c, [0, 1])):
    self.assertListEqual(list(seq), expected)
def testGetInputsBatch(self)
Expand source code
def testGetInputsBatch(self):
  """Checks full-length batches and the default last-event-only batches."""
  event_sequences = [[0, 1, 0, 2, 0], [0, 1, 2]]
  first_rows = [[1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
                [1.0, 0.0, 0.0],
                [0.0, 0.0, 1.0],
                [1.0, 0.0, 0.0]]
  second_rows = [[1.0, 0.0, 0.0],
                 [0.0, 1.0, 0.0],
                 [0.0, 0.0, 1.0]]
  full_batch = [first_rows, second_rows]
  last_only_batch = [rows[-1:] for rows in full_batch]
  self.assertListEqual(
      full_batch, self.enc.get_inputs_batch(event_sequences, True))
  self.assertListEqual(
      last_only_batch, self.enc.get_inputs_batch(event_sequences))
def testInputSize(self)
Expand source code
def testInputSize(self):
  """Input width matches the 3-event one-hot encoding."""
  expected_width = 3
  self.assertEqual(expected_width, self.enc.input_size)
def testLabelsToNumSteps(self)
Expand source code
def testLabelsToNumSteps(self):
  """Step count for a five-label sequence under the fixture encoding."""
  self.assertEqual(3, self.enc.labels_to_num_steps([0, 1, 0, 2, 0]))
def testNumClasses(self)
Expand source code
def testNumClasses(self):
  """One class per event of the 3-event encoding."""
  expected_classes = 3
  self.assertEqual(expected_classes, self.enc.num_classes)
class OneHotIndexEventSequenceEncoderDecoderTest (*args, **kwargs)

Extension of unittest.TestCase providing more power.

Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.

Expand source code
class OneHotIndexEventSequenceEncoderDecoderTest(absltest.TestCase):
  """Tests OneHotIndexEventSequenceEncoderDecoder (index-valued inputs)."""

  def setUp(self):
    super().setUp()
    base_encoding = testing_lib.TrivialOneHotEncoding(3, num_steps=range(3))
    self.enc = encoder_decoder.OneHotIndexEventSequenceEncoderDecoder(
        base_encoding)

  def testInputSize(self):
    # Each input is a single index rather than a one-hot vector.
    self.assertEqual(1, self.enc.input_size)

  def testInputDepth(self):
    self.assertEqual(3, self.enc.input_depth)

  def testEventsToInput(self):
    events = [0, 1, 0, 2, 0]
    for position, expected_index in enumerate([0, 1, 0, 2, 0]):
      self.assertEqual([expected_index],
                       self.enc.events_to_input(events, position))

  def testEncode(self):
    inputs, labels = self.enc.encode([0, 1, 0, 2, 0])
    self.assertEqual(inputs, [[0], [1], [0], [2]])
    self.assertEqual(labels, [1, 0, 2, 0])

  def testGetInputsBatch(self):
    event_sequences = [[0, 1, 0, 2, 0], [0, 1, 2]]
    first_rows = [[0], [1], [0], [2], [0]]
    second_rows = [[0], [1], [2]]
    full_batch = [first_rows, second_rows]
    last_only_batch = [rows[-1:] for rows in full_batch]
    self.assertListEqual(
        full_batch, self.enc.get_inputs_batch(event_sequences, True))
    self.assertListEqual(
        last_only_batch, self.enc.get_inputs_batch(event_sequences))

Ancestors

  • absl.testing.absltest.TestCase
  • absl.third_party.unittest3_backport.case.TestCase
  • unittest.case.TestCase

Methods

def setUp(self)

Hook method for setting up the test fixture before exercising it.

Expand source code
def setUp(self):
  """Builds an index-input encoder over a 3-event trivial encoding."""
  super().setUp()
  base_encoding = testing_lib.TrivialOneHotEncoding(3, num_steps=range(3))
  self.enc = encoder_decoder.OneHotIndexEventSequenceEncoderDecoder(
      base_encoding)
def testEncode(self)
Expand source code
def testEncode(self):
  """encode() yields index inputs and next-event labels."""
  inputs, labels = self.enc.encode([0, 1, 0, 2, 0])
  self.assertEqual(inputs, [[0], [1], [0], [2]])
  self.assertEqual(labels, [1, 0, 2, 0])
def testEventsToInput(self)
Expand source code
def testEventsToInput(self):
  """Each input is a one-element list holding the event's index."""
  events = [0, 1, 0, 2, 0]
  for position, expected_index in enumerate([0, 1, 0, 2, 0]):
    self.assertEqual([expected_index],
                     self.enc.events_to_input(events, position))
def testGetInputsBatch(self)
Expand source code
def testGetInputsBatch(self):
  """Checks full-length and last-event-only index input batches."""
  event_sequences = [[0, 1, 0, 2, 0], [0, 1, 2]]
  first_rows = [[0], [1], [0], [2], [0]]
  second_rows = [[0], [1], [2]]
  full_batch = [first_rows, second_rows]
  last_only_batch = [rows[-1:] for rows in full_batch]
  self.assertListEqual(
      full_batch, self.enc.get_inputs_batch(event_sequences, True))
  self.assertListEqual(
      last_only_batch, self.enc.get_inputs_batch(event_sequences))
def testInputDepth(self)
Expand source code
def testInputDepth(self):
  """Input depth equals the number of events in the encoding."""
  expected_depth = 3
  self.assertEqual(expected_depth, self.enc.input_depth)
def testInputSize(self)
Expand source code
def testInputSize(self):
  """A single index per step gives an input width of 1."""
  expected_width = 1
  self.assertEqual(expected_width, self.enc.input_size)
class OptionalEventSequenceEncoderTest (*args, **kwargs)

Extension of unittest.TestCase providing more power.

Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.

Expand source code
class OptionalEventSequenceEncoderTest(absltest.TestCase):
  """Tests OptionalEventSequenceEncoder wrapping a 3-event one-hot encoder."""

  def setUp(self):
    super().setUp()
    wrapped = encoder_decoder.OneHotEventSequenceEncoderDecoder(
        testing_lib.TrivialOneHotEncoding(3))
    self.enc = encoder_decoder.OptionalEventSequenceEncoder(wrapped)

  def testInputSize(self):
    # One flag entry plus the wrapped encoder's 3 entries.
    self.assertEqual(4, self.enc.input_size)

  def testEventsToInput(self):
    events = [(False, 0), (False, 1), (False, 0), (True, 2), (True, 0)]
    # Leading entry is 1.0 for events flagged True, in which case the
    # remaining entries are all zero regardless of the wrapped event.
    expected_inputs = [
        [0.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, 1.0, 0.0],
        [0.0, 1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0, 0.0],
        [1.0, 0.0, 0.0, 0.0],
    ]
    for position, expected in enumerate(expected_inputs):
      self.assertEqual(expected, self.enc.events_to_input(events, position))

Ancestors

  • absl.testing.absltest.TestCase
  • absl.third_party.unittest3_backport.case.TestCase
  • unittest.case.TestCase

Methods

def setUp(self)

Hook method for setting up the test fixture before exercising it.

Expand source code
def setUp(self):
  """Wraps a 3-event one-hot encoder in an OptionalEventSequenceEncoder."""
  super().setUp()
  wrapped = encoder_decoder.OneHotEventSequenceEncoderDecoder(
      testing_lib.TrivialOneHotEncoding(3))
  self.enc = encoder_decoder.OptionalEventSequenceEncoder(wrapped)
def testEventsToInput(self)
Expand source code
def testEventsToInput(self):
  """Flagged events encode as [1, 0, 0, 0]; others get a 0 flag + one-hot."""
  events = [(False, 0), (False, 1), (False, 0), (True, 2), (True, 0)]
  expected_inputs = [
      [0.0, 1.0, 0.0, 0.0],
      [0.0, 0.0, 1.0, 0.0],
      [0.0, 1.0, 0.0, 0.0],
      [1.0, 0.0, 0.0, 0.0],
      [1.0, 0.0, 0.0, 0.0],
  ]
  for position, expected in enumerate(expected_inputs):
    self.assertEqual(expected, self.enc.events_to_input(events, position))
def testInputSize(self)
Expand source code
def testInputSize(self):
  """One flag entry plus the wrapped encoder's three entries."""
  expected_width = 4
  self.assertEqual(expected_width, self.enc.input_size)