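"""Unit tests for the Andromeda model exposed by Andromeda.model."""
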
import torch
import unittest
from Andromeda.model import Andromeda


class TestAndromeda(unittest.TestCase):
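    """Sanity checks for an Andromeda model built with its default configuration."""
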
    def setUp(self):
        self.model = Andromeda()

    def test_initialization(self):
        self.assertIsNotNone(self.model.Andromeda, "Transformer is not initialized.")
        self.assertIsNotNone(self.model.decoder, "AutoregressiveWrapper is not initialized.")

    def test_forward_pass(self):
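        # 50432 and 8192 match the default num_tokens and max_seq_len asserted in test_model_parameters.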
        input_tokens = torch.randint(0, 50432, (1, 8192))
        output = self.model(input_tokens)
        self.assertIsInstance(output, torch.Tensor, "Output is not a PyTorch tensor.")
        self.assertEqual(output.shape[0], input_tokens.shape[0], "Output batch size does not match input.")

    def test_error_handling(self):
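        # A forward pass on invalid input (None) is expected to raise rather than fail silently.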
        with self.assertRaises(Exception):
            self.model.forward(None)

    def test_model_parameters(self):
        self.assertEqual(self.model.Andromeda.num_tokens, 50432, "Number of tokens is not correctly set.")
        self.assertEqual(self.model.Andromeda.max_seq_len, 8192, "Max sequence length is not correctly set.")

    def test_model_output(self):
        input_tokens = torch.randint(0, 50432, (1, 8192))
        # eval() disables dropout and other train-time stochasticity so the two passes can be compared.
        self.model.eval()
        with torch.no_grad():
            output1 = self.model(input_tokens)
            output2 = self.model(input_tokens)
        self.assertTrue(torch.allclose(output1, output2), "Model does not produce consistent output.")


class TestAndromedaExtended(unittest.TestCase):
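    """Wider coverage: variable sequence lengths, batch sizes, token-range edge cases,
    and propagation of constructor arguments to the underlying Transformer layers."""
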
    def setUp(self):
        self.model = Andromeda()

    def test_input_size(self):
        for seq_len in [512, 1024, 2048, 4096]:
            input_tokens = torch.randint(0, 50432, (1, seq_len))
            output = self.model(input_tokens)
            self.assertEqual(output.shape[1], seq_len, f"Output sequence length does not match input for seq_len={seq_len}.")

    def test_batch_size(self):
        for batch_size in [2, 4, 8, 16]:
            input_tokens = torch.randint(0, 50432, (batch_size, 8192))
            output = self.model(input_tokens)
            self.assertEqual(output.shape[0], batch_size, f"Output batch size does not match input for batch_size={batch_size}.")

    def test_token_range(self):
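        # Probe the boundary token ids: 0 and 50431 (num_tokens - 1).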
        for token in [0, 50431]:
            input_tokens = torch.full((1, 8192), fill_value=token)
            output = self.model(input_tokens)
            self.assertIsInstance(output, torch.Tensor, f"Output is not a PyTorch tensor for token={token}.")

    def test_model_depth(self):
        for depth in [16, 32, 64]:
            model = Andromeda(depth=depth)
            self.assertEqual(model.Andromeda.attn_layers.depth, depth, f"Model depth is not correctly set for depth={depth}.")

    def test_model_dim(self):
        for dim in [1280, 2560, 5120]:
            model = Andromeda(dim=dim)
            self.assertEqual(model.Andromeda.attn_layers.dim, dim, f"Model dimension is not correctly set for dim={dim}.")

    def test_model_heads(self):
        for heads in [12, 24, 48]:
            model = Andromeda(heads=heads)
            self.assertEqual(model.Andromeda.attn_layers.heads, heads, f"Number of heads is not correctly set for heads={heads}.")

    def test_model_dim_head(self):
        for dim_head in [64, 128, 256]:
            model = Andromeda(dim_head=dim_head)
            self.assertEqual(model.Andromeda.attn_layers.dim_head, dim_head, f"Head dimension is not correctly set for dim_head={dim_head}.")

    def test_model_alibi_num_heads(self):
        for alibi_num_heads in [6, 12, 24]:
            model = Andromeda(alibi_num_heads=alibi_num_heads)
            self.assertEqual(model.Andromeda.attn_layers.alibi_num_heads, alibi_num_heads, f"Number of alibi heads is not correctly set for alibi_num_heads={alibi_num_heads}.")

    def test_model_shift_tokens(self):
        for shift_tokens in [0, 1, 2]:
            model = Andromeda(shift_tokens=shift_tokens)
            self.assertEqual(model.Andromeda.attn_layers.shift_tokens, shift_tokens, f"Number of shift tokens is not correctly set for shift_tokens={shift_tokens}.")

    def test_model_use_abs_pos_emb(self):
        for use_abs_pos_emb in [True, False]:
            model = Andromeda(use_abs_pos_emb=use_abs_pos_emb)
            self.assertEqual(model.Andromeda.use_abs_pos_emb, use_abs_pos_emb, f"Use absolute position embedding flag is not correctly set for use_abs_pos_emb={use_abs_pos_emb}.")

    def test_model_alibi_pos_bias(self):
        for alibi_pos_bias in [True, False]:
            model = Andromeda(alibi_pos_bias=alibi_pos_bias)
            self.assertEqual(model.Andromeda.attn_layers.alibi_pos_bias, alibi_pos_bias, f"Alibi position bias flag is not correctly set for alibi_pos_bias={alibi_pos_bias}.")

    def test_model_rotary_xpos(self):
        for rotary_xpos in [True, False]:
            model = Andromeda(rotary_xpos=rotary_xpos)
            self.assertEqual(model.Andromeda.attn_layers.rotary_xpos, rotary_xpos, f"Rotary position flag is not correctly set for rotary_xpos={rotary_xpos}.")

    def test_model_attn_flash(self):
        for attn_flash in [True, False]:
            model = Andromeda(attn_flash=attn_flash)
            self.assertEqual(model.Andromeda.attn_layers.attn_flash, attn_flash, f"Attention flash flag is not correctly set for attn_flash={attn_flash}.")


if __name__ == '__main__':
    unittest.main()