"""Unit tests for ``TranscriptChunk`` from ``transcribe.strategy``."""

import unittest

from transcribe.strategy import TranscriptChunk, TranscriptToken, SplitMode


class TestTranscriptChunk(unittest.TestCase):
    """Exercise splitting, joining, and inspection helpers of ``TranscriptChunk``."""

    def setUp(self):
        """Build a four-token chunk ("Hello", ",", "world", ".") joined by spaces."""
        self.tokens = [
            TranscriptToken(text="Hello", t0=0, t1=100),
            TranscriptToken(text=",", t0=100, t1=200),
            TranscriptToken(text="world", t0=200, t1=300),
            TranscriptToken(text=".", t0=300, t1=400),
        ]
        self.chunk = TranscriptChunk(items=self.tokens, separator=" ")

    def test_split_by_punctuation(self):
        """Splitting on punctuation is expected to yield three chunks."""
        chunks = self.chunk.split_by(SplitMode.PUNCTUATION)
        self.assertEqual(len(chunks), 3)
        self.assertEqual(chunks[0].join(), "Hello ,")
        self.assertEqual(chunks[1].join(), "world .")
        # The final "." is the last token, so the trailing chunk is expected
        # to be empty.
        self.assertEqual(chunks[2].join(), "")

    def test_get_split_first_rest(self):
        """The first split chunk is returned separately from the remaining ones."""
        first, rest = self.chunk.get_split_first_rest(SplitMode.PUNCTUATION)
        self.assertEqual(first.join(), "Hello ,")
        self.assertEqual(len(rest), 2)
        self.assertEqual(rest[0].join(), "world .")
        self.assertEqual(rest[1].join(), "")

    def test_punctuation_numbers(self):
        """The fixture contains two punctuation tokens ("," and ".")."""
        # NOTE(review): "puncation_numbers" is the spelling this test calls on
        # TranscriptChunk; it is kept as-is because the class is defined
        # elsewhere.
        self.assertEqual(self.chunk.puncation_numbers(), 2)

    def test_length(self):
        """``length()`` is expected to report the four fixture tokens."""
        self.assertEqual(self.chunk.length(), 4)

    def test_join(self):
        """``join()`` concatenates token texts with the chunk separator."""
        self.assertEqual(self.chunk.join(), "Hello , world .")

    def test_compare(self):
        """Comparing with a partially matching chunk gives a score strictly in (0, 1)."""
        other_chunk = TranscriptChunk(
            items=[
                TranscriptToken(text="Hello", t0=0, t1=100),
                TranscriptToken(text="!", t0=100, t1=200),
            ],
            separator=" ",
        )
        similarity = self.chunk.compare(other_chunk)
        # Two separate assertions so a failure reports the actual score.
        self.assertGreater(similarity, 0)
        self.assertLess(similarity, 1)

    def test_has_punctuation(self):
        """The fixture chunk is expected to report that it has punctuation."""
        self.assertTrue(self.chunk.has_punctuation())

    def test_get_buffer_index(self):
        """The buffer index is derived from the last token's ``t1``."""
        # t1 = 400 -> index = 400 / 100 * 16000 = 64000
        # NOTE(review): presumably t1 is in 1/100 s and the buffer is sampled
        # at 16 kHz -- confirm against TranscriptChunk.get_buffer_index.
        self.assertEqual(self.chunk.get_buffer_index(), 64000)

    def test_is_end_sentence(self):
        """A chunk whose last token is "." is expected to count as a sentence end."""
        self.assertTrue(self.chunk.is_end_sentence())


if __name__ == '__main__':
    unittest.main()