Module note_seq.encoder_decoder_test
Tests for encoder_decoder.
Expand source code
# Copyright 2021 The Magenta Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for encoder_decoder."""
from absl.testing import absltest
from note_seq import encoder_decoder
from note_seq import testing_lib
import numpy as np
class OneHotEventSequenceEncoderDecoderTest(absltest.TestCase):
    """Exercises OneHotEventSequenceEncoderDecoder with a three-class encoding."""

    def setUp(self):
        """Create the encoder/decoder shared by every test method."""
        super().setUp()
        one_hot_encoding = testing_lib.TrivialOneHotEncoding(
            3, num_steps=range(3))
        self.enc = encoder_decoder.OneHotEventSequenceEncoderDecoder(
            one_hot_encoding)

    def testInputSize(self):
        """The reported input size is expected to be 3."""
        expected_size = 3
        self.assertEqual(expected_size, self.enc.input_size)

    def testNumClasses(self):
        """The reported number of classes is expected to be 3."""
        expected_classes = 3
        self.assertEqual(expected_classes, self.enc.num_classes)

    def testEventsToInput(self):
        """Each position is expected to encode as a one-hot vector."""
        events = [0, 1, 0, 2, 0]
        expected_vectors = [
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0],
            [1.0, 0.0, 0.0],
        ]
        for position, expected in enumerate(expected_vectors):
            self.assertEqual(
                expected, self.enc.events_to_input(events, position))

    def testEventsToLabel(self):
        """Each position is expected to yield the listed label."""
        events = [0, 1, 0, 2, 0]
        expected_labels = [0, 1, 0, 2, 0]
        for position, expected in enumerate(expected_labels):
            self.assertEqual(
                expected, self.enc.events_to_label(events, position))

    def testClassIndexToEvent(self):
        """Class indices 0..2 are expected to decode to events 0..2."""
        events = [0, 1, 0, 2, 0]
        for class_index in range(3):
            self.assertEqual(
                class_index,
                self.enc.class_index_to_event(class_index, events))

    def testLabelsToNumSteps(self):
        """The listed labels are expected to add up to 3 steps."""
        step_count = self.enc.labels_to_num_steps([0, 1, 0, 2, 0])
        self.assertEqual(3, step_count)

    def testEncode(self):
        """encode() is expected to return four inputs and four labels."""
        inputs, labels = self.enc.encode([0, 1, 0, 2, 0])
        self.assertEqual(
            inputs,
            [
                [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
                [1.0, 0.0, 0.0],
                [0.0, 0.0, 1.0],
            ])
        self.assertEqual(labels, [1, 0, 2, 0])

    def testGetInputsBatch(self):
        """Checks full-length batches and the default last-event batches."""
        event_sequences = [[0, 1, 0, 2, 0], [0, 1, 2]]
        long_inputs = [
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0],
            [1.0, 0.0, 0.0],
        ]
        short_inputs = [
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
        ]
        # With the second argument set to True every position is returned.
        self.assertListEqual(
            [long_inputs, short_inputs],
            self.enc.get_inputs_batch(event_sequences, True))
        # Without it only the final position of each sequence is returned.
        self.assertListEqual(
            [[long_inputs[-1]], [short_inputs[-1]]],
            self.enc.get_inputs_batch(event_sequences))

    def testExtendEventSequences(self):
        """Each sequence is expected to grow by its softmax's certain class."""
        event_sequences = [[0], [0], [0]]
        softmax = [[[0.0, 0.0, 1.0]], [[1.0, 0.0, 0.0]], [[0.0, 1.0, 0.0]]]
        self.enc.extend_event_sequences(event_sequences, softmax)
        expected_sequences = [[0, 2], [0, 0], [0, 1]]
        for sequence, expected in zip(event_sequences, expected_sequences):
            self.assertListEqual(list(sequence), expected)

    def testEvaluateLogLikelihood(self):
        """Compares returned log-likelihoods against hand-summed logs."""
        event_sequences = [[0, 1, 0], [1, 2, 2]]
        softmax = [
            [[0.0, 0.5, 0.5], [0.3, 0.4, 0.3]],
            [[0.0, 0.6, 0.4], [0.0, 0.4, 0.6]],
        ]
        log_likelihoods = self.enc.evaluate_log_likelihood(
            event_sequences, softmax)
        expected = [
            np.log(0.5) + np.log(0.3),
            np.log(0.4) + np.log(0.6),
        ]
        self.assertListEqual(expected, log_likelihoods)
class OneHotIndexEventSequenceEncoderDecoderTest(absltest.TestCase):
    """Exercises OneHotIndexEventSequenceEncoderDecoder with three classes."""

    def setUp(self):
        """Create the encoder/decoder shared by every test method."""
        super().setUp()
        one_hot_encoding = testing_lib.TrivialOneHotEncoding(
            3, num_steps=range(3))
        self.enc = encoder_decoder.OneHotIndexEventSequenceEncoderDecoder(
            one_hot_encoding)

    def testInputSize(self):
        """The reported input size is expected to be 1."""
        expected_size = 1
        self.assertEqual(expected_size, self.enc.input_size)

    def testInputDepth(self):
        """The reported input depth is expected to be 3."""
        expected_depth = 3
        self.assertEqual(expected_depth, self.enc.input_depth)

    def testEventsToInput(self):
        """Each position is expected to encode as a single-element list."""
        events = [0, 1, 0, 2, 0]
        expected_inputs = [[0], [1], [0], [2], [0]]
        for position, expected in enumerate(expected_inputs):
            self.assertEqual(
                expected, self.enc.events_to_input(events, position))

    def testEncode(self):
        """encode() is expected to return four inputs and four labels."""
        inputs, labels = self.enc.encode([0, 1, 0, 2, 0])
        self.assertEqual(inputs, [[0], [1], [0], [2]])
        self.assertEqual(labels, [1, 0, 2, 0])

    def testGetInputsBatch(self):
        """Checks full-length batches and the default last-event batches."""
        event_sequences = [[0, 1, 0, 2, 0], [0, 1, 2]]
        long_inputs = [[0], [1], [0], [2], [0]]
        short_inputs = [[0], [1], [2]]
        # With the second argument set to True every position is returned.
        self.assertListEqual(
            [long_inputs, short_inputs],
            self.enc.get_inputs_batch(event_sequences, True))
        # Without it only the final position of each sequence is returned.
        self.assertListEqual(
            [[long_inputs[-1]], [short_inputs[-1]]],
            self.enc.get_inputs_batch(event_sequences))
class LookbackEventSequenceEncoderDecoderTest(absltest.TestCase):
    """Exercises LookbackEventSequenceEncoderDecoder, with and without lookbacks."""

    def setUp(self):
        """Create the shared encoder/decoder, built with the list [1, 2]."""
        super().setUp()
        one_hot_encoding = testing_lib.TrivialOneHotEncoding(
            3, num_steps=range(3))
        self.enc = encoder_decoder.LookbackEventSequenceEncoderDecoder(
            one_hot_encoding, [1, 2], 2)

    def testInputSize(self):
        """The reported input size is expected to be 13."""
        self.assertEqual(13, self.enc.input_size)

    def testNumClasses(self):
        """The reported number of classes is expected to be 5."""
        self.assertEqual(5, self.enc.num_classes)

    def testEventsToInput(self):
        """Each position is expected to encode as the listed 13-entry vector."""
        events = [0, 1, 0, 2, 0]
        expected_vectors = [
            [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0,
             1.0, -1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0,
             -1.0, 1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0,
             1.0, 1.0, 0.0, 1.0],
            [0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0,
             -1.0, -1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
             1.0, -1.0, 0.0, 1.0],
        ]
        for position, expected in enumerate(expected_vectors):
            self.assertEqual(
                expected, self.enc.events_to_input(events, position))

    def testEventsToLabel(self):
        """Each position is expected to yield the listed label."""
        events = [0, 1, 0, 2, 0]
        expected_labels = [4, 1, 4, 2, 4]
        for position, expected in enumerate(expected_labels):
            self.assertEqual(
                expected, self.enc.events_to_label(events, position))

    def testClassIndexToEvent(self):
        """Decodes class indices 0..4 against growing prefixes of the events."""
        events = [0, 1, 0, 2, 0]
        # Prefix length -> expected event for class indices 0, 1, 2, 3, 4.
        expected_by_prefix = {
            1: [0, 1, 2, 0, 0],
            2: [0, 1, 2, 1, 0],
            3: [0, 1, 2, 0, 1],
            4: [0, 1, 2, 2, 0],
            5: [0, 1, 2, 0, 2],
        }
        for prefix_length, expected_events in expected_by_prefix.items():
            for class_index, expected in enumerate(expected_events):
                self.assertEqual(
                    expected,
                    self.enc.class_index_to_event(
                        class_index, events[:prefix_length]))

    def testLabelsToNumSteps(self):
        """Checks the step totals expected for two label lists."""
        self.assertEqual(3, self.enc.labels_to_num_steps([0, 1, 0, 2, 0]))
        self.assertEqual(5, self.enc.labels_to_num_steps([0, 1, 3, 2, 4]))

    def testEmptyLookback(self):
        """Checks an encoder/decoder built with an empty lookback list."""
        enc = encoder_decoder.LookbackEventSequenceEncoderDecoder(
            testing_lib.TrivialOneHotEncoding(3), [], 2)
        self.assertEqual(5, enc.input_size)
        self.assertEqual(3, enc.num_classes)

        events = [0, 1, 0, 2, 0]
        expected_vectors = [
            [1.0, 0.0, 0.0, 1.0, -1.0],
            [0.0, 1.0, 0.0, -1.0, 1.0],
            [1.0, 0.0, 0.0, 1.0, 1.0],
            [0.0, 0.0, 1.0, -1.0, -1.0],
            [1.0, 0.0, 0.0, 1.0, -1.0],
        ]
        for position, expected in enumerate(expected_vectors):
            self.assertEqual(expected, enc.events_to_input(events, position))

        expected_labels = [0, 1, 0, 2, 0]
        for position, expected in enumerate(expected_labels):
            self.assertEqual(expected, enc.events_to_label(events, position))

        # Decode through the local `enc`, not the fixture `self.enc`, so the
        # empty-lookback encoder is the object actually under test here.
        for prefix_length in range(1, 6):
            for class_index in range(3):
                self.assertEqual(
                    class_index,
                    enc.class_index_to_event(
                        class_index, events[:prefix_length]))
class ConditionalEventSequenceEncoderDecoderTest(absltest.TestCase):
    """Exercises ConditionalEventSequenceEncoderDecoder on control/target pairs."""

    def setUp(self):
        """Wrap a two-class and a three-class one-hot encoder/decoder."""
        super().setUp()
        control_encoder = encoder_decoder.OneHotEventSequenceEncoderDecoder(
            testing_lib.TrivialOneHotEncoding(2))
        target_encoder = encoder_decoder.OneHotEventSequenceEncoderDecoder(
            testing_lib.TrivialOneHotEncoding(3))
        self.enc = encoder_decoder.ConditionalEventSequenceEncoderDecoder(
            control_encoder, target_encoder)

    def testInputSize(self):
        """The reported input size is expected to be 5."""
        expected_size = 5
        self.assertEqual(expected_size, self.enc.input_size)

    def testNumClasses(self):
        """The reported number of classes is expected to be 3."""
        expected_classes = 3
        self.assertEqual(expected_classes, self.enc.num_classes)

    def testEventsToInput(self):
        """Positions 0..3 are expected to encode as the listed vectors."""
        control_events = [1, 1, 1, 0, 0]
        target_events = [0, 1, 0, 2, 0]
        expected_vectors = [
            [0.0, 1.0, 1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 1.0, 0.0],
            [1.0, 0.0, 1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 0.0, 1.0],
        ]
        for position, expected in enumerate(expected_vectors):
            self.assertEqual(
                expected,
                self.enc.events_to_input(
                    control_events, target_events, position))

    def testEventsToLabel(self):
        """Each target position is expected to yield the listed label."""
        target_events = [0, 1, 0, 2, 0]
        expected_labels = [0, 1, 0, 2, 0]
        for position, expected in enumerate(expected_labels):
            self.assertEqual(
                expected, self.enc.events_to_label(target_events, position))

    def testClassIndexToEvent(self):
        """Class indices 0..2 are expected to decode to events 0..2."""
        target_events = [0, 1, 0, 2, 0]
        for class_index in range(3):
            self.assertEqual(
                class_index,
                self.enc.class_index_to_event(class_index, target_events))

    def testEncode(self):
        """encode() is expected to return four inputs and four labels."""
        inputs, labels = self.enc.encode([1, 1, 1, 0, 0], [0, 1, 0, 2, 0])
        self.assertEqual(
            inputs,
            [
                [0.0, 1.0, 1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0, 1.0, 0.0],
                [1.0, 0.0, 1.0, 0.0, 0.0],
                [1.0, 0.0, 0.0, 0.0, 1.0],
            ])
        self.assertEqual(labels, [1, 0, 2, 0])

    def testGetInputsBatch(self):
        """Checks full-length batches and the default last-event batches."""
        control_event_sequences = [[1, 1, 1, 0, 0], [1, 1, 1, 0, 0]]
        target_event_sequences = [[0, 1, 0, 2], [0, 1]]
        long_inputs = [
            [0.0, 1.0, 1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 1.0, 0.0],
            [1.0, 0.0, 1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 0.0, 1.0],
        ]
        short_inputs = [
            [0.0, 1.0, 1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 1.0, 0.0],
        ]
        # With the third argument set to True every position is returned.
        self.assertListEqual(
            [long_inputs, short_inputs],
            self.enc.get_inputs_batch(
                control_event_sequences, target_event_sequences, True))
        # Without it only the final position of each sequence is returned.
        self.assertListEqual(
            [[long_inputs[-1]], [short_inputs[-1]]],
            self.enc.get_inputs_batch(
                control_event_sequences, target_event_sequences))

    def testExtendEventSequences(self):
        """Each sequence is expected to grow by its softmax's certain class."""
        target_event_sequences = [[0], [0], [0]]
        softmax = np.array(
            [[[0.0, 0.0, 1.0]], [[1.0, 0.0, 0.0]], [[0.0, 1.0, 0.0]]])
        self.enc.extend_event_sequences(target_event_sequences, softmax)
        expected_sequences = [[0, 2], [0, 0], [0, 1]]
        for sequence, expected in zip(
                target_event_sequences, expected_sequences):
            self.assertListEqual(list(sequence), expected)

    def testEvaluateLogLikelihood(self):
        """Compares returned log-likelihoods against hand-summed logs."""
        target_event_sequences = [[0, 1, 0], [1, 2, 2]]
        softmax = [
            [[0.0, 0.5, 0.5], [0.3, 0.4, 0.3]],
            [[0.0, 0.6, 0.4], [0.0, 0.4, 0.6]],
        ]
        log_likelihoods = self.enc.evaluate_log_likelihood(
            target_event_sequences, softmax)
        expected = [
            np.log(0.5) + np.log(0.3),
            np.log(0.4) + np.log(0.6),
        ]
        self.assertListEqual(expected, log_likelihoods)
class OptionalEventSequenceEncoderTest(absltest.TestCase):
    """Exercises OptionalEventSequenceEncoder around a three-class encoder."""

    def setUp(self):
        """Wrap a three-class one-hot encoder/decoder in the optional encoder."""
        super().setUp()
        wrapped_encoder = encoder_decoder.OneHotEventSequenceEncoderDecoder(
            testing_lib.TrivialOneHotEncoding(3))
        self.enc = encoder_decoder.OptionalEventSequenceEncoder(
            wrapped_encoder)

    def testInputSize(self):
        """The reported input size is expected to be 4."""
        expected_size = 4
        self.assertEqual(expected_size, self.enc.input_size)

    def testEventsToInput(self):
        """Each (flag, event) pair is expected to encode as the listed vector."""
        events = [(False, 0), (False, 1), (False, 0), (True, 2), (True, 0)]
        expected_vectors = [
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 0.0],
        ]
        for position, expected in enumerate(expected_vectors):
            self.assertEqual(
                expected, self.enc.events_to_input(events, position))
class MultipleEventSequenceEncoderTest(absltest.TestCase):
    """Exercises MultipleEventSequenceEncoder over two one-hot encoders."""

    def setUp(self):
        """Combine a two-class and a three-class one-hot encoder/decoder."""
        super().setUp()
        first_encoder = encoder_decoder.OneHotEventSequenceEncoderDecoder(
            testing_lib.TrivialOneHotEncoding(2))
        second_encoder = encoder_decoder.OneHotEventSequenceEncoderDecoder(
            testing_lib.TrivialOneHotEncoding(3))
        self.enc = encoder_decoder.MultipleEventSequenceEncoder(
            [first_encoder, second_encoder])

    def testInputSize(self):
        """The reported input size is expected to be 5."""
        expected_size = 5
        self.assertEqual(expected_size, self.enc.input_size)

    def testEventsToInput(self):
        """Each event tuple is expected to encode as the listed vector."""
        events = [(1, 0), (1, 1), (1, 0), (0, 2), (0, 0)]
        expected_vectors = [
            [0.0, 1.0, 1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 1.0, 0.0],
            [0.0, 1.0, 1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 0.0, 1.0],
            [1.0, 0.0, 1.0, 0.0, 0.0],
        ]
        for position, expected in enumerate(expected_vectors):
            self.assertEqual(
                expected, self.enc.events_to_input(events, position))
# Run the tests through absl's runner when this module is executed directly.
if __name__ == '__main__':
    absltest.main()
Classes
class ConditionalEventSequenceEncoderDecoderTest (*args, **kwargs)
-
Extension of unittest.TestCase providing more power.
Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.
Expand source code
class ConditionalEventSequenceEncoderDecoderTest(absltest.TestCase):
    """Exercises ConditionalEventSequenceEncoderDecoder on control/target pairs."""

    def setUp(self):
        """Wrap a two-class and a three-class one-hot encoder/decoder."""
        super().setUp()
        control_encoder = encoder_decoder.OneHotEventSequenceEncoderDecoder(
            testing_lib.TrivialOneHotEncoding(2))
        target_encoder = encoder_decoder.OneHotEventSequenceEncoderDecoder(
            testing_lib.TrivialOneHotEncoding(3))
        self.enc = encoder_decoder.ConditionalEventSequenceEncoderDecoder(
            control_encoder, target_encoder)

    def testInputSize(self):
        """The reported input size is expected to be 5."""
        expected_size = 5
        self.assertEqual(expected_size, self.enc.input_size)

    def testNumClasses(self):
        """The reported number of classes is expected to be 3."""
        expected_classes = 3
        self.assertEqual(expected_classes, self.enc.num_classes)

    def testEventsToInput(self):
        """Positions 0..3 are expected to encode as the listed vectors."""
        control_events = [1, 1, 1, 0, 0]
        target_events = [0, 1, 0, 2, 0]
        expected_vectors = [
            [0.0, 1.0, 1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 1.0, 0.0],
            [1.0, 0.0, 1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 0.0, 1.0],
        ]
        for position, expected in enumerate(expected_vectors):
            self.assertEqual(
                expected,
                self.enc.events_to_input(
                    control_events, target_events, position))

    def testEventsToLabel(self):
        """Each target position is expected to yield the listed label."""
        target_events = [0, 1, 0, 2, 0]
        expected_labels = [0, 1, 0, 2, 0]
        for position, expected in enumerate(expected_labels):
            self.assertEqual(
                expected, self.enc.events_to_label(target_events, position))

    def testClassIndexToEvent(self):
        """Class indices 0..2 are expected to decode to events 0..2."""
        target_events = [0, 1, 0, 2, 0]
        for class_index in range(3):
            self.assertEqual(
                class_index,
                self.enc.class_index_to_event(class_index, target_events))

    def testEncode(self):
        """encode() is expected to return four inputs and four labels."""
        inputs, labels = self.enc.encode([1, 1, 1, 0, 0], [0, 1, 0, 2, 0])
        self.assertEqual(
            inputs,
            [
                [0.0, 1.0, 1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0, 1.0, 0.0],
                [1.0, 0.0, 1.0, 0.0, 0.0],
                [1.0, 0.0, 0.0, 0.0, 1.0],
            ])
        self.assertEqual(labels, [1, 0, 2, 0])

    def testGetInputsBatch(self):
        """Checks full-length batches and the default last-event batches."""
        control_event_sequences = [[1, 1, 1, 0, 0], [1, 1, 1, 0, 0]]
        target_event_sequences = [[0, 1, 0, 2], [0, 1]]
        long_inputs = [
            [0.0, 1.0, 1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 1.0, 0.0],
            [1.0, 0.0, 1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 0.0, 1.0],
        ]
        short_inputs = [
            [0.0, 1.0, 1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 1.0, 0.0],
        ]
        # With the third argument set to True every position is returned.
        self.assertListEqual(
            [long_inputs, short_inputs],
            self.enc.get_inputs_batch(
                control_event_sequences, target_event_sequences, True))
        # Without it only the final position of each sequence is returned.
        self.assertListEqual(
            [[long_inputs[-1]], [short_inputs[-1]]],
            self.enc.get_inputs_batch(
                control_event_sequences, target_event_sequences))

    def testExtendEventSequences(self):
        """Each sequence is expected to grow by its softmax's certain class."""
        target_event_sequences = [[0], [0], [0]]
        softmax = np.array(
            [[[0.0, 0.0, 1.0]], [[1.0, 0.0, 0.0]], [[0.0, 1.0, 0.0]]])
        self.enc.extend_event_sequences(target_event_sequences, softmax)
        expected_sequences = [[0, 2], [0, 0], [0, 1]]
        for sequence, expected in zip(
                target_event_sequences, expected_sequences):
            self.assertListEqual(list(sequence), expected)

    def testEvaluateLogLikelihood(self):
        """Compares returned log-likelihoods against hand-summed logs."""
        target_event_sequences = [[0, 1, 0], [1, 2, 2]]
        softmax = [
            [[0.0, 0.5, 0.5], [0.3, 0.4, 0.3]],
            [[0.0, 0.6, 0.4], [0.0, 0.4, 0.6]],
        ]
        log_likelihoods = self.enc.evaluate_log_likelihood(
            target_event_sequences, softmax)
        expected = [
            np.log(0.5) + np.log(0.3),
            np.log(0.4) + np.log(0.6),
        ]
        self.assertListEqual(expected, log_likelihoods)
Ancestors
- absl.testing.absltest.TestCase
- absl.third_party.unittest3_backport.case.TestCase
- unittest.case.TestCase
Methods
def setUp(self)
-
Hook method for setting up the test fixture before exercising it.
Expand source code
def setUp(self):
    """Wrap a two-class and a three-class one-hot encoder/decoder."""
    super().setUp()
    control_encoder = encoder_decoder.OneHotEventSequenceEncoderDecoder(
        testing_lib.TrivialOneHotEncoding(2))
    target_encoder = encoder_decoder.OneHotEventSequenceEncoderDecoder(
        testing_lib.TrivialOneHotEncoding(3))
    self.enc = encoder_decoder.ConditionalEventSequenceEncoderDecoder(
        control_encoder, target_encoder)
def testClassIndexToEvent(self)
-
Expand source code
def testClassIndexToEvent(self):
    """Class indices 0..2 are expected to decode to events 0..2."""
    target_events = [0, 1, 0, 2, 0]
    for class_index in range(3):
        self.assertEqual(
            class_index,
            self.enc.class_index_to_event(class_index, target_events))
def testEncode(self)
-
Expand source code
def testEncode(self):
    """encode() is expected to return four inputs and four labels."""
    inputs, labels = self.enc.encode([1, 1, 1, 0, 0], [0, 1, 0, 2, 0])
    self.assertEqual(
        inputs,
        [
            [0.0, 1.0, 1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 1.0, 0.0],
            [1.0, 0.0, 1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 0.0, 1.0],
        ])
    self.assertEqual(labels, [1, 0, 2, 0])
def testEvaluateLogLikelihood(self)
-
Expand source code
def testEvaluateLogLikelihood(self):
    """Compares returned log-likelihoods against hand-summed logs."""
    target_event_sequences = [[0, 1, 0], [1, 2, 2]]
    softmax = [
        [[0.0, 0.5, 0.5], [0.3, 0.4, 0.3]],
        [[0.0, 0.6, 0.4], [0.0, 0.4, 0.6]],
    ]
    log_likelihoods = self.enc.evaluate_log_likelihood(
        target_event_sequences, softmax)
    expected = [
        np.log(0.5) + np.log(0.3),
        np.log(0.4) + np.log(0.6),
    ]
    self.assertListEqual(expected, log_likelihoods)
def testEventsToInput(self)
-
Expand source code
def testEventsToInput(self):
    """Positions 0..3 are expected to encode as the listed vectors."""
    control_events = [1, 1, 1, 0, 0]
    target_events = [0, 1, 0, 2, 0]
    expected_vectors = [
        [0.0, 1.0, 1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 1.0, 0.0],
        [1.0, 0.0, 1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0, 0.0, 1.0],
    ]
    for position, expected in enumerate(expected_vectors):
        self.assertEqual(
            expected,
            self.enc.events_to_input(control_events, target_events, position))
def testEventsToLabel(self)
-
Expand source code
def testEventsToLabel(self):
    """Each target position is expected to yield the listed label."""
    target_events = [0, 1, 0, 2, 0]
    expected_labels = [0, 1, 0, 2, 0]
    for position, expected in enumerate(expected_labels):
        self.assertEqual(
            expected, self.enc.events_to_label(target_events, position))
def testExtendEventSequences(self)
-
Expand source code
def testExtendEventSequences(self):
    """Each sequence is expected to grow by its softmax's certain class."""
    target_event_sequences = [[0], [0], [0]]
    softmax = np.array(
        [[[0.0, 0.0, 1.0]], [[1.0, 0.0, 0.0]], [[0.0, 1.0, 0.0]]])
    self.enc.extend_event_sequences(target_event_sequences, softmax)
    expected_sequences = [[0, 2], [0, 0], [0, 1]]
    for sequence, expected in zip(target_event_sequences, expected_sequences):
        self.assertListEqual(list(sequence), expected)
def testGetInputsBatch(self)
-
Expand source code
def testGetInputsBatch(self):
    """Checks full-length batches and the default last-event batches."""
    control_event_sequences = [[1, 1, 1, 0, 0], [1, 1, 1, 0, 0]]
    target_event_sequences = [[0, 1, 0, 2], [0, 1]]
    long_inputs = [
        [0.0, 1.0, 1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 1.0, 0.0],
        [1.0, 0.0, 1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0, 0.0, 1.0],
    ]
    short_inputs = [
        [0.0, 1.0, 1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 1.0, 0.0],
    ]
    # With the third argument set to True every position is returned.
    self.assertListEqual(
        [long_inputs, short_inputs],
        self.enc.get_inputs_batch(
            control_event_sequences, target_event_sequences, True))
    # Without it only the final position of each sequence is returned.
    self.assertListEqual(
        [[long_inputs[-1]], [short_inputs[-1]]],
        self.enc.get_inputs_batch(
            control_event_sequences, target_event_sequences))
def testInputSize(self)
-
Expand source code
def testInputSize(self):
    """The reported input size is expected to be 5."""
    expected_size = 5
    self.assertEqual(expected_size, self.enc.input_size)
def testNumClasses(self)
-
Expand source code
def testNumClasses(self):
    """The reported number of classes is expected to be 3."""
    expected_classes = 3
    self.assertEqual(expected_classes, self.enc.num_classes)
class LookbackEventSequenceEncoderDecoderTest (*args, **kwargs)
-
Extension of unittest.TestCase providing more power.
Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.
Expand source code
class LookbackEventSequenceEncoderDecoderTest(absltest.TestCase):
    """Exercises LookbackEventSequenceEncoderDecoder, with and without lookbacks."""

    def setUp(self):
        """Create the shared encoder/decoder, built with the list [1, 2]."""
        super().setUp()
        one_hot_encoding = testing_lib.TrivialOneHotEncoding(
            3, num_steps=range(3))
        self.enc = encoder_decoder.LookbackEventSequenceEncoderDecoder(
            one_hot_encoding, [1, 2], 2)

    def testInputSize(self):
        """The reported input size is expected to be 13."""
        self.assertEqual(13, self.enc.input_size)

    def testNumClasses(self):
        """The reported number of classes is expected to be 5."""
        self.assertEqual(5, self.enc.num_classes)

    def testEventsToInput(self):
        """Each position is expected to encode as the listed 13-entry vector."""
        events = [0, 1, 0, 2, 0]
        expected_vectors = [
            [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0,
             1.0, -1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0,
             -1.0, 1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0,
             1.0, 1.0, 0.0, 1.0],
            [0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0,
             -1.0, -1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
             1.0, -1.0, 0.0, 1.0],
        ]
        for position, expected in enumerate(expected_vectors):
            self.assertEqual(
                expected, self.enc.events_to_input(events, position))

    def testEventsToLabel(self):
        """Each position is expected to yield the listed label."""
        events = [0, 1, 0, 2, 0]
        expected_labels = [4, 1, 4, 2, 4]
        for position, expected in enumerate(expected_labels):
            self.assertEqual(
                expected, self.enc.events_to_label(events, position))

    def testClassIndexToEvent(self):
        """Decodes class indices 0..4 against growing prefixes of the events."""
        events = [0, 1, 0, 2, 0]
        # Prefix length -> expected event for class indices 0, 1, 2, 3, 4.
        expected_by_prefix = {
            1: [0, 1, 2, 0, 0],
            2: [0, 1, 2, 1, 0],
            3: [0, 1, 2, 0, 1],
            4: [0, 1, 2, 2, 0],
            5: [0, 1, 2, 0, 2],
        }
        for prefix_length, expected_events in expected_by_prefix.items():
            for class_index, expected in enumerate(expected_events):
                self.assertEqual(
                    expected,
                    self.enc.class_index_to_event(
                        class_index, events[:prefix_length]))

    def testLabelsToNumSteps(self):
        """Checks the step totals expected for two label lists."""
        self.assertEqual(3, self.enc.labels_to_num_steps([0, 1, 0, 2, 0]))
        self.assertEqual(5, self.enc.labels_to_num_steps([0, 1, 3, 2, 4]))

    def testEmptyLookback(self):
        """Checks an encoder/decoder built with an empty lookback list."""
        enc = encoder_decoder.LookbackEventSequenceEncoderDecoder(
            testing_lib.TrivialOneHotEncoding(3), [], 2)
        self.assertEqual(5, enc.input_size)
        self.assertEqual(3, enc.num_classes)

        events = [0, 1, 0, 2, 0]
        expected_vectors = [
            [1.0, 0.0, 0.0, 1.0, -1.0],
            [0.0, 1.0, 0.0, -1.0, 1.0],
            [1.0, 0.0, 0.0, 1.0, 1.0],
            [0.0, 0.0, 1.0, -1.0, -1.0],
            [1.0, 0.0, 0.0, 1.0, -1.0],
        ]
        for position, expected in enumerate(expected_vectors):
            self.assertEqual(expected, enc.events_to_input(events, position))

        expected_labels = [0, 1, 0, 2, 0]
        for position, expected in enumerate(expected_labels):
            self.assertEqual(expected, enc.events_to_label(events, position))

        # Decode through the local `enc`, not the fixture `self.enc`, so the
        # empty-lookback encoder is the object actually under test here.
        for prefix_length in range(1, 6):
            for class_index in range(3):
                self.assertEqual(
                    class_index,
                    enc.class_index_to_event(
                        class_index, events[:prefix_length]))
Ancestors
- absl.testing.absltest.TestCase
- absl.third_party.unittest3_backport.case.TestCase
- unittest.case.TestCase
Methods
def setUp(self)
-
Hook method for setting up the test fixture before exercising it.
Expand source code
def setUp(self):
    """Create the shared encoder/decoder, built with the list [1, 2]."""
    super().setUp()
    one_hot_encoding = testing_lib.TrivialOneHotEncoding(
        3, num_steps=range(3))
    self.enc = encoder_decoder.LookbackEventSequenceEncoderDecoder(
        one_hot_encoding, [1, 2], 2)
def testClassIndexToEvent(self)
-
Expand source code
def testClassIndexToEvent(self):
    """Decodes class indices 0..4 against growing prefixes of the events."""
    events = [0, 1, 0, 2, 0]
    # Prefix length -> expected event for class indices 0, 1, 2, 3, 4.
    expected_by_prefix = {
        1: [0, 1, 2, 0, 0],
        2: [0, 1, 2, 1, 0],
        3: [0, 1, 2, 0, 1],
        4: [0, 1, 2, 2, 0],
        5: [0, 1, 2, 0, 2],
    }
    for prefix_length, expected_events in expected_by_prefix.items():
        for class_index, expected in enumerate(expected_events):
            self.assertEqual(
                expected,
                self.enc.class_index_to_event(
                    class_index, events[:prefix_length]))
def testEmptyLookback(self)
-
Expand source code
def testEmptyLookback(self):
    """Checks an encoder/decoder built with an empty lookback list."""
    enc = encoder_decoder.LookbackEventSequenceEncoderDecoder(
        testing_lib.TrivialOneHotEncoding(3), [], 2)
    self.assertEqual(5, enc.input_size)
    self.assertEqual(3, enc.num_classes)

    events = [0, 1, 0, 2, 0]
    expected_vectors = [
        [1.0, 0.0, 0.0, 1.0, -1.0],
        [0.0, 1.0, 0.0, -1.0, 1.0],
        [1.0, 0.0, 0.0, 1.0, 1.0],
        [0.0, 0.0, 1.0, -1.0, -1.0],
        [1.0, 0.0, 0.0, 1.0, -1.0],
    ]
    for position, expected in enumerate(expected_vectors):
        self.assertEqual(expected, enc.events_to_input(events, position))

    expected_labels = [0, 1, 0, 2, 0]
    for position, expected in enumerate(expected_labels):
        self.assertEqual(expected, enc.events_to_label(events, position))

    # Decode through the local `enc`, not the fixture `self.enc`, so the
    # empty-lookback encoder is the object actually under test here.
    for prefix_length in range(1, 6):
        for class_index in range(3):
            self.assertEqual(
                class_index,
                enc.class_index_to_event(class_index, events[:prefix_length]))
def testEventsToInput(self)
-
Expand source code
def testEventsToInput(self):
    """Compare the 13-element input vector at each position of the events."""
    events = [0, 1, 0, 2, 0]
    # One expected input vector per position in `events`.
    expected_vectors = [
        [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, -1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, -1.0, 1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0],
        [0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, -1.0, -1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, -1.0, 0.0, 1.0],
    ]
    for position, expected in enumerate(expected_vectors):
        self.assertEqual(expected, self.enc.events_to_input(events, position))
def testEventsToLabel(self)
-
Expand source code
def testEventsToLabel(self):
    """Compare the label produced at each position of the events."""
    events = [0, 1, 0, 2, 0]
    # One expected label per position in `events`.
    expected_labels = [4, 1, 4, 2, 4]
    for position, expected in enumerate(expected_labels):
        self.assertEqual(expected, self.enc.events_to_label(events, position))
def testInputSize(self)
-
Expand source code
def testInputSize(self):
    """The fixture encoder must report an input size of 13."""
    expected_size = 13
    self.assertEqual(expected_size, self.enc.input_size)
def testLabelsToNumSteps(self)
-
Expand source code
def testLabelsToNumSteps(self):
    """Check the step count reported for two label sequences."""
    # (labels, expected number of steps) pairs, checked in order.
    cases = [
        ([0, 1, 0, 2, 0], 3),
        ([0, 1, 3, 2, 4], 5),
    ]
    for labels, expected_steps in cases:
        self.assertEqual(expected_steps, self.enc.labels_to_num_steps(labels))
def testNumClasses(self)
-
Expand source code
def testNumClasses(self):
    """The fixture encoder must report five output classes."""
    expected_classes = 5
    self.assertEqual(expected_classes, self.enc.num_classes)
class MultipleEventSequenceEncoderTest (*args, **kwargs)
-
Extension of unittest.TestCase providing more power.
Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.
Expand source code
class MultipleEventSequenceEncoderTest(absltest.TestCase):
    """Tests for encoder_decoder.MultipleEventSequenceEncoder."""

    def setUp(self):
        """Combine a 2-class and a 3-class one-hot encoder into `self.enc`."""
        super().setUp()
        two_class = encoder_decoder.OneHotEventSequenceEncoderDecoder(
            testing_lib.TrivialOneHotEncoding(2))
        three_class = encoder_decoder.OneHotEventSequenceEncoderDecoder(
            testing_lib.TrivialOneHotEncoding(3))
        self.enc = encoder_decoder.MultipleEventSequenceEncoder(
            [two_class, three_class])

    def testInputSize(self):
        """The combined encoder must report an input size of 5."""
        expected_size = 5
        self.assertEqual(expected_size, self.enc.input_size)

    def testEventsToInput(self):
        """Compare the input vector at each position of tuple events."""
        events = [(1, 0), (1, 1), (1, 0), (0, 2), (0, 0)]
        # One expected input vector per position in `events`.
        expected_vectors = [
            [0.0, 1.0, 1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 1.0, 0.0],
            [0.0, 1.0, 1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 0.0, 1.0],
            [1.0, 0.0, 1.0, 0.0, 0.0],
        ]
        for position, expected in enumerate(expected_vectors):
            self.assertEqual(
                expected, self.enc.events_to_input(events, position))
Ancestors
- absl.testing.absltest.TestCase
- absl.third_party.unittest3_backport.case.TestCase
- unittest.case.TestCase
Methods
def setUp(self)
-
Hook method for setting up the test fixture before exercising it.
Expand source code
def setUp(self):
    """Combine a 2-class and a 3-class one-hot encoder into `self.enc`."""
    super().setUp()
    two_class = encoder_decoder.OneHotEventSequenceEncoderDecoder(
        testing_lib.TrivialOneHotEncoding(2))
    three_class = encoder_decoder.OneHotEventSequenceEncoderDecoder(
        testing_lib.TrivialOneHotEncoding(3))
    self.enc = encoder_decoder.MultipleEventSequenceEncoder(
        [two_class, three_class])
def testEventsToInput(self)
-
Expand source code
def testEventsToInput(self):
    """Compare the input vector at each position of tuple events."""
    events = [(1, 0), (1, 1), (1, 0), (0, 2), (0, 0)]
    # One expected input vector per position in `events`.
    expected_vectors = [
        [0.0, 1.0, 1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 1.0, 0.0],
        [0.0, 1.0, 1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0, 0.0, 1.0],
        [1.0, 0.0, 1.0, 0.0, 0.0],
    ]
    for position, expected in enumerate(expected_vectors):
        self.assertEqual(expected, self.enc.events_to_input(events, position))
def testInputSize(self)
-
Expand source code
def testInputSize(self):
    """The combined encoder must report an input size of 5."""
    expected_size = 5
    self.assertEqual(expected_size, self.enc.input_size)
class OneHotEventSequenceEncoderDecoderTest (*args, **kwargs)
-
Extension of unittest.TestCase providing more power.
Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.
Expand source code
class OneHotEventSequenceEncoderDecoderTest(absltest.TestCase):
    """Tests for encoder_decoder.OneHotEventSequenceEncoderDecoder."""

    def setUp(self):
        """Build `self.enc` around a 3-class trivial one-hot encoding."""
        super().setUp()
        one_hot_encoding = testing_lib.TrivialOneHotEncoding(
            3, num_steps=range(3))
        self.enc = encoder_decoder.OneHotEventSequenceEncoderDecoder(
            one_hot_encoding)

    def testInputSize(self):
        """The encoder must report an input size of 3."""
        expected_size = 3
        self.assertEqual(expected_size, self.enc.input_size)

    def testNumClasses(self):
        """The encoder must report three output classes."""
        expected_classes = 3
        self.assertEqual(expected_classes, self.enc.num_classes)

    def testEventsToInput(self):
        """Compare the input vector at each position of the events."""
        events = [0, 1, 0, 2, 0]
        # One expected input vector per position in `events`.
        expected_vectors = [
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0],
            [1.0, 0.0, 0.0],
        ]
        for position, expected in enumerate(expected_vectors):
            self.assertEqual(
                expected, self.enc.events_to_input(events, position))

    def testEventsToLabel(self):
        """Compare the label produced at each position of the events."""
        events = [0, 1, 0, 2, 0]
        # One expected label per position in `events`.
        expected_labels = [0, 1, 0, 2, 0]
        for position, expected in enumerate(expected_labels):
            self.assertEqual(
                expected, self.enc.events_to_label(events, position))

    def testClassIndexToEvent(self):
        """Class indices 0, 1 and 2 must decode to events 0, 1 and 2."""
        events = [0, 1, 0, 2, 0]
        for class_index in range(3):
            self.assertEqual(
                class_index,
                self.enc.class_index_to_event(class_index, events))

    def testLabelsToNumSteps(self):
        """The label sequence [0, 1, 0, 2, 0] must map to three steps."""
        labels = [0, 1, 0, 2, 0]
        expected_steps = 3
        self.assertEqual(expected_steps, self.enc.labels_to_num_steps(labels))

    def testEncode(self):
        """encode() must return the expected (inputs, labels) pair."""
        events = [0, 1, 0, 2, 0]
        expected_inputs = [
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0],
        ]
        expected_labels = [1, 0, 2, 0]

        actual_inputs, actual_labels = self.enc.encode(events)

        self.assertEqual(actual_inputs, expected_inputs)
        self.assertEqual(actual_labels, expected_labels)

    def testGetInputsBatch(self):
        """Check batched inputs with and without the positional True flag."""
        event_sequences = [[0, 1, 0, 2, 0], [0, 1, 2]]
        # Expected per-position inputs for each of the two sequences.
        expected_full_length_inputs_batch = [
            [[1.0, 0.0, 0.0],
             [0.0, 1.0, 0.0],
             [1.0, 0.0, 0.0],
             [0.0, 0.0, 1.0],
             [1.0, 0.0, 0.0]],
            [[1.0, 0.0, 0.0],
             [0.0, 1.0, 0.0],
             [0.0, 0.0, 1.0]],
        ]
        # Without the flag only the final input of each sequence is expected.
        expected_last_event_inputs_batch = [
            rows[-1:] for rows in expected_full_length_inputs_batch
        ]

        self.assertListEqual(
            expected_full_length_inputs_batch,
            self.enc.get_inputs_batch(event_sequences, True))
        self.assertListEqual(
            expected_last_event_inputs_batch,
            self.enc.get_inputs_batch(event_sequences))

    def testExtendEventSequences(self):
        """Each sequence must gain the event selected by its softmax row."""
        first, second, third = [0], [0], [0]
        softmax = [[[0.0, 0.0, 1.0]], [[1.0, 0.0, 0.0]], [[0.0, 1.0, 0.0]]]

        self.enc.extend_event_sequences([first, second, third], softmax)

        self.assertListEqual(list(first), [0, 2])
        self.assertListEqual(list(second), [0, 0])
        self.assertListEqual(list(third), [0, 1])

    def testEvaluateLogLikelihood(self):
        """Check per-sequence log-likelihoods computed from softmax rows.

        The expected values are sums of logarithms, so they are compared
        with a floating-point tolerance instead of exact equality.
        """
        event_sequences = [[0, 1, 0], [1, 2, 2]]
        softmax = [[[0.0, 0.5, 0.5], [0.3, 0.4, 0.3]],
                   [[0.0, 0.6, 0.4], [0.0, 0.4, 0.6]]]
        expected = [np.log(0.5) + np.log(0.3), np.log(0.4) + np.log(0.6)]

        p = self.enc.evaluate_log_likelihood(event_sequences, softmax)

        # assert_allclose broadcasts its arguments, so pin the length first.
        self.assertEqual(len(expected), len(p))
        np.testing.assert_allclose(p, expected)
Ancestors
- absl.testing.absltest.TestCase
- absl.third_party.unittest3_backport.case.TestCase
- unittest.case.TestCase
Methods
def setUp(self)
-
Hook method for setting up the test fixture before exercising it.
Expand source code
def setUp(self):
    """Build `self.enc` around a 3-class trivial one-hot encoding."""
    super().setUp()
    one_hot_encoding = testing_lib.TrivialOneHotEncoding(
        3, num_steps=range(3))
    self.enc = encoder_decoder.OneHotEventSequenceEncoderDecoder(
        one_hot_encoding)
def testClassIndexToEvent(self)
-
Expand source code
def testClassIndexToEvent(self):
    """Class indices 0, 1 and 2 must decode to events 0, 1 and 2."""
    events = [0, 1, 0, 2, 0]
    for class_index in range(3):
        self.assertEqual(
            class_index, self.enc.class_index_to_event(class_index, events))
def testEncode(self)
-
Expand source code
def testEncode(self):
    """encode() must return the expected (inputs, labels) pair."""
    events = [0, 1, 0, 2, 0]
    expected_inputs = [
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.0, 0.0, 1.0],
    ]
    expected_labels = [1, 0, 2, 0]

    actual_inputs, actual_labels = self.enc.encode(events)

    self.assertEqual(actual_inputs, expected_inputs)
    self.assertEqual(actual_labels, expected_labels)
def testEvaluateLogLikelihood(self)
-
Expand source code
def testEvaluateLogLikelihood(self):
    """Check per-sequence log-likelihoods computed from softmax rows.

    The expected values are sums of logarithms, so they are compared with a
    floating-point tolerance instead of exact equality.
    """
    events1 = [0, 1, 0]
    events2 = [1, 2, 2]
    event_sequences = [events1, events2]
    softmax = [[[0.0, 0.5, 0.5], [0.3, 0.4, 0.3]],
               [[0.0, 0.6, 0.4], [0.0, 0.4, 0.6]]]
    expected = [np.log(0.5) + np.log(0.3), np.log(0.4) + np.log(0.6)]

    p = self.enc.evaluate_log_likelihood(event_sequences, softmax)

    # assert_allclose broadcasts its arguments, so pin the length first.
    self.assertEqual(len(expected), len(p))
    np.testing.assert_allclose(p, expected)
def testEventsToInput(self)
-
Expand source code
def testEventsToInput(self):
    """Compare the input vector at each position of the events."""
    events = [0, 1, 0, 2, 0]
    # One expected input vector per position in `events`.
    expected_vectors = [
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.0, 0.0, 1.0],
        [1.0, 0.0, 0.0],
    ]
    for position, expected in enumerate(expected_vectors):
        self.assertEqual(expected, self.enc.events_to_input(events, position))
def testEventsToLabel(self)
-
Expand source code
def testEventsToLabel(self):
    """Compare the label produced at each position of the events."""
    events = [0, 1, 0, 2, 0]
    # One expected label per position in `events`.
    expected_labels = [0, 1, 0, 2, 0]
    for position, expected in enumerate(expected_labels):
        self.assertEqual(expected, self.enc.events_to_label(events, position))
def testExtendEventSequences(self)
-
Expand source code
def testExtendEventSequences(self):
    """Each sequence must gain the event selected by its softmax row."""
    first, second, third = [0], [0], [0]
    softmax = [[[0.0, 0.0, 1.0]], [[1.0, 0.0, 0.0]], [[0.0, 1.0, 0.0]]]

    self.enc.extend_event_sequences([first, second, third], softmax)

    self.assertListEqual(list(first), [0, 2])
    self.assertListEqual(list(second), [0, 0])
    self.assertListEqual(list(third), [0, 1])
def testGetInputsBatch(self)
-
Expand source code
def testGetInputsBatch(self):
    """Check batched inputs with and without the positional True flag."""
    event_sequences = [[0, 1, 0, 2, 0], [0, 1, 2]]
    # Expected per-position inputs for each of the two sequences.
    expected_full_length_inputs_batch = [
        [[1.0, 0.0, 0.0],
         [0.0, 1.0, 0.0],
         [1.0, 0.0, 0.0],
         [0.0, 0.0, 1.0],
         [1.0, 0.0, 0.0]],
        [[1.0, 0.0, 0.0],
         [0.0, 1.0, 0.0],
         [0.0, 0.0, 1.0]],
    ]
    # Without the flag only the final input of each sequence is expected.
    expected_last_event_inputs_batch = [
        rows[-1:] for rows in expected_full_length_inputs_batch
    ]

    self.assertListEqual(
        expected_full_length_inputs_batch,
        self.enc.get_inputs_batch(event_sequences, True))
    self.assertListEqual(
        expected_last_event_inputs_batch,
        self.enc.get_inputs_batch(event_sequences))
def testInputSize(self)
-
Expand source code
def testInputSize(self):
    """The encoder must report an input size of 3."""
    expected_size = 3
    self.assertEqual(expected_size, self.enc.input_size)
def testLabelsToNumSteps(self)
-
Expand source code
def testLabelsToNumSteps(self):
    """The label sequence [0, 1, 0, 2, 0] must map to three steps."""
    labels = [0, 1, 0, 2, 0]
    expected_steps = 3
    self.assertEqual(expected_steps, self.enc.labels_to_num_steps(labels))
def testNumClasses(self)
-
Expand source code
def testNumClasses(self):
    """The encoder must report three output classes."""
    expected_classes = 3
    self.assertEqual(expected_classes, self.enc.num_classes)
class OneHotIndexEventSequenceEncoderDecoderTest (*args, **kwargs)
-
Extension of unittest.TestCase providing more power.
Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.
Expand source code
class OneHotIndexEventSequenceEncoderDecoderTest(absltest.TestCase):
    """Tests for encoder_decoder.OneHotIndexEventSequenceEncoderDecoder."""

    def setUp(self):
        """Build `self.enc` around a 3-class trivial one-hot encoding."""
        super().setUp()
        one_hot_encoding = testing_lib.TrivialOneHotEncoding(
            3, num_steps=range(3))
        self.enc = encoder_decoder.OneHotIndexEventSequenceEncoderDecoder(
            one_hot_encoding)

    def testInputSize(self):
        """The encoder must report an input size of 1."""
        expected_size = 1
        self.assertEqual(expected_size, self.enc.input_size)

    def testInputDepth(self):
        """The encoder must report an input depth of 3."""
        expected_depth = 3
        self.assertEqual(expected_depth, self.enc.input_depth)

    def testEventsToInput(self):
        """Compare the single-element input at each position of the events."""
        events = [0, 1, 0, 2, 0]
        # One expected single-element input per position in `events`.
        expected_inputs = [[0], [1], [0], [2], [0]]
        for position, expected in enumerate(expected_inputs):
            self.assertEqual(
                expected, self.enc.events_to_input(events, position))

    def testEncode(self):
        """encode() must return the expected (inputs, labels) pair."""
        events = [0, 1, 0, 2, 0]
        expected_inputs = [[0], [1], [0], [2]]
        expected_labels = [1, 0, 2, 0]

        actual_inputs, actual_labels = self.enc.encode(events)

        self.assertEqual(actual_inputs, expected_inputs)
        self.assertEqual(actual_labels, expected_labels)

    def testGetInputsBatch(self):
        """Check batched inputs with and without the positional True flag."""
        event_sequences = [[0, 1, 0, 2, 0], [0, 1, 2]]
        # Expected per-position inputs for each of the two sequences.
        expected_full_length_inputs_batch = [
            [[0], [1], [0], [2], [0]],
            [[0], [1], [2]],
        ]
        # Without the flag only the final input of each sequence is expected.
        expected_last_event_inputs_batch = [
            rows[-1:] for rows in expected_full_length_inputs_batch
        ]

        self.assertListEqual(
            expected_full_length_inputs_batch,
            self.enc.get_inputs_batch(event_sequences, True))
        self.assertListEqual(
            expected_last_event_inputs_batch,
            self.enc.get_inputs_batch(event_sequences))
Ancestors
- absl.testing.absltest.TestCase
- absl.third_party.unittest3_backport.case.TestCase
- unittest.case.TestCase
Methods
def setUp(self)
-
Hook method for setting up the test fixture before exercising it.
Expand source code
def setUp(self):
    """Build `self.enc` around a 3-class trivial one-hot encoding."""
    super().setUp()
    one_hot_encoding = testing_lib.TrivialOneHotEncoding(
        3, num_steps=range(3))
    self.enc = encoder_decoder.OneHotIndexEventSequenceEncoderDecoder(
        one_hot_encoding)
def testEncode(self)
-
Expand source code
def testEncode(self):
    """encode() must return the expected (inputs, labels) pair."""
    events = [0, 1, 0, 2, 0]
    expected_inputs = [[0], [1], [0], [2]]
    expected_labels = [1, 0, 2, 0]

    actual_inputs, actual_labels = self.enc.encode(events)

    self.assertEqual(actual_inputs, expected_inputs)
    self.assertEqual(actual_labels, expected_labels)
def testEventsToInput(self)
-
Expand source code
def testEventsToInput(self):
    """Compare the single-element input at each position of the events."""
    events = [0, 1, 0, 2, 0]
    # One expected single-element input per position in `events`.
    expected_inputs = [[0], [1], [0], [2], [0]]
    for position, expected in enumerate(expected_inputs):
        self.assertEqual(expected, self.enc.events_to_input(events, position))
def testGetInputsBatch(self)
-
Expand source code
def testGetInputsBatch(self):
    """Check batched inputs with and without the positional True flag."""
    event_sequences = [[0, 1, 0, 2, 0], [0, 1, 2]]
    # Expected per-position inputs for each of the two sequences.
    expected_full_length_inputs_batch = [
        [[0], [1], [0], [2], [0]],
        [[0], [1], [2]],
    ]
    # Without the flag only the final input of each sequence is expected.
    expected_last_event_inputs_batch = [
        rows[-1:] for rows in expected_full_length_inputs_batch
    ]

    self.assertListEqual(
        expected_full_length_inputs_batch,
        self.enc.get_inputs_batch(event_sequences, True))
    self.assertListEqual(
        expected_last_event_inputs_batch,
        self.enc.get_inputs_batch(event_sequences))
def testInputDepth(self)
-
Expand source code
def testInputDepth(self):
    """The encoder must report an input depth of 3."""
    expected_depth = 3
    self.assertEqual(expected_depth, self.enc.input_depth)
def testInputSize(self)
-
Expand source code
def testInputSize(self):
    """The encoder must report an input size of 1."""
    expected_size = 1
    self.assertEqual(expected_size, self.enc.input_size)
class OptionalEventSequenceEncoderTest (*args, **kwargs)
-
Extension of unittest.TestCase providing more power.
Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.
Expand source code
class OptionalEventSequenceEncoderTest(absltest.TestCase):
    """Tests for encoder_decoder.OptionalEventSequenceEncoder."""

    def setUp(self):
        """Wrap a 3-class one-hot encoder in an optional encoder."""
        super().setUp()
        wrapped = encoder_decoder.OneHotEventSequenceEncoderDecoder(
            testing_lib.TrivialOneHotEncoding(3))
        self.enc = encoder_decoder.OptionalEventSequenceEncoder(wrapped)

    def testInputSize(self):
        """The optional encoder must report an input size of 4."""
        expected_size = 4
        self.assertEqual(expected_size, self.enc.input_size)

    def testEventsToInput(self):
        """Compare the input vector at each position of (bool, int) events."""
        events = [(False, 0), (False, 1), (False, 0), (True, 2), (True, 0)]
        # One expected input vector per position in `events`.
        expected_vectors = [
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 0.0],
        ]
        for position, expected in enumerate(expected_vectors):
            self.assertEqual(
                expected, self.enc.events_to_input(events, position))
Ancestors
- absl.testing.absltest.TestCase
- absl.third_party.unittest3_backport.case.TestCase
- unittest.case.TestCase
Methods
def setUp(self)
-
Hook method for setting up the test fixture before exercising it.
Expand source code
def setUp(self):
    """Wrap a 3-class one-hot encoder in an optional encoder."""
    super().setUp()
    wrapped = encoder_decoder.OneHotEventSequenceEncoderDecoder(
        testing_lib.TrivialOneHotEncoding(3))
    self.enc = encoder_decoder.OptionalEventSequenceEncoder(wrapped)
def testEventsToInput(self)
-
Expand source code
def testEventsToInput(self):
    """Compare the input vector at each position of (bool, int) events."""
    events = [(False, 0), (False, 1), (False, 0), (True, 2), (True, 0)]
    # One expected input vector per position in `events`.
    expected_vectors = [
        [0.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, 1.0, 0.0],
        [0.0, 1.0, 0.0, 0.0],
        [1.0, 0.0, 0.0, 0.0],
        [1.0, 0.0, 0.0, 0.0],
    ]
    for position, expected in enumerate(expected_vectors):
        self.assertEqual(expected, self.enc.events_to_input(events, position))
def testInputSize(self)
-
Expand source code
def testInputSize(self):
    """The optional encoder must report an input size of 4."""
    expected_size = 4
    self.assertEqual(expected_size, self.enc.input_size)