# Copyright 2024 AniMemory Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch
from safetensors.torch import load_file
from transformers.models.t5.configuration_t5 import T5Config
from transformers.models.t5.modeling_t5 import T5Stack


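class AniMemoryT5(torch.nn.Module):
    """T5 encoder stack paired with a separate, standalone token-embedding table.

    The embedding table (250002 tokens, 4096 dimensions, padding index 1) is loaded
    from its own safetensors file. Its output is fed to the encoder through
    ``inputs_embeds``, so its width must match the encoder's ``d_model``.
    """

    def __init__(self, config: T5Config, embed_tokens=None):
        super().__init__()
        self.encoder = T5Stack(config, embed_tokens)
        self.embed_tokens_encoder = torch.nn.Embedding(250002, 4096, padding_idx=1)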

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        subfolder="",
        embed_tokens=None,
        emb_name="weights.safetensors",
        torch_dtype=torch.float16,
    ):
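        """Load the T5 encoder and the standalone embedding table.

        ``pretrained_model_name_or_path`` must be a local directory: the embedding
        weights are read with ``safetensors.torch.load_file`` from
        ``<path>/<subfolder>/<emb_name>``, which does not download from the Hub.
        """
        config = T5Stack.config_class.from_pretrained(
            pretrained_model_name_or_path, subfolder=subfolder
        )
        model = cls(config=config, embed_tokens=embed_tokens)
        # Replace the randomly initialised encoder with the pretrained weights.
        model.encoder = T5Stack.from_pretrained(
            pretrained_model_name_or_path, subfolder=subfolder
        )
        # load_file returns a state dict (tensor name -> tensor), not a path.
        embed_state_dict = load_file(
            os.path.join(pretrained_model_name_or_path, subfolder, emb_name)
        )
        model.embed_tokens_encoder.load_state_dict(embed_state_dict)
        model.encoder.to(torch_dtype)
        model.embed_tokens_encoder.to(torch_dtype)
        # Store the dtype on the instance instead of the class, so separate
        # models do not share it.
        model.dtype = torch_dtype
        return model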

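    def to(self, *args, **kwargs):
        # Reuse PyTorch's (private) argument parser so that every calling
        # convention of nn.Module.to (device, dtype, tensor, keyword arguments)
        # is understood; only the device and dtype are needed here.
        device, dtype, _non_blocking, _convert_to_format = torch._C._nn._parse_to(
            *args, **kwargs
        )
        super().to(*args, **kwargs)
        # Remember the last requested dtype/device. getattr avoids an
        # AttributeError when only one of the two has ever been set.
        self.dtype = dtype if dtype is not None else getattr(self, "dtype", None)
        self.device = device if device is not None else getattr(self, "device", None)
        return self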

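    def make_attn_mask(self, attn_mask):
        """Expand a ``[batch, seq_len]`` padding mask to ``[batch * num_heads, seq_len, seq_len]``.

        Every query row of every head receives the same key mask. This helper is
        not called by ``forward``, which passes the 2-D mask straight to the encoder.
        """
        seq_len = attn_mask.shape[1]
        # Use ``self.num_head`` if it has been set on the instance, otherwise the
        # encoder's head count.
        num_heads = getattr(self, "num_head", self.encoder.config.num_heads)
        # [B, L] -> [B, 1, L]: the padding mask becomes a single query row.
        query = attn_mask.unsqueeze(1).float()
        # [B, 1, L] -> [B, L, L] -> [B, 1, L, L] -> [B, H, L, L]
        attn_mask = (
            query.repeat([1, seq_len, 1]).unsqueeze(1).repeat([1, num_heads, 1, 1])
        )
        # Fold the heads into the batch dimension: [B * H, L, L].
        attn_mask = attn_mask.view([-1, seq_len, seq_len])
        return attn_mask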

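    def forward(self, text, attention_mask):
        """Encode token ids into per-token hidden states.

        Args:
            text: ``LongTensor`` of token ids, shape ``[batch, seq_len]``.
            attention_mask: ``[batch, seq_len]`` mask, 1 for real tokens and 0 for padding.

        Returns:
            The same ``[batch, seq_len, d_model]`` tensor twice.
        """
        # Look up embeddings with the standalone table instead of the encoder's own.
        embeddings = self.embed_tokens_encoder(text)
        encoder_outputs = self.encoder(
            inputs_embeds=embeddings,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        # Use the output of the second-to-last encoder block rather than the last
        # layer, then apply the encoder's final layer norm to it.
        hidden_states = encoder_outputs.hidden_states[-2]
        hidden_states = self.encoder.final_layer_norm(hidden_states)
        return hidden_states, hidden_states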
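

# Illustrative sketch only: ``_expand_padding_mask`` is a hypothetical helper,
# not part of the AniMemory or diffusers API. It produces the same
# [batch * num_heads, seq_len, seq_len] tensor as ``AniMemoryT5.make_attn_mask``
# but builds it with ``expand`` (a broadcast view) instead of two ``repeat``
# calls, which avoids materialising the intermediate repeated tensors.
# ``num_heads`` is passed explicitly here as an assumption instead of being
# read from the model.
def _expand_padding_mask(attn_mask: torch.Tensor, num_heads: int) -> torch.Tensor:
    batch, seq_len = attn_mask.shape
    # [B, L] -> [B, 1, 1, L] -> broadcast view of shape [B, H, L, L]
    mask = attn_mask.float()[:, None, None, :].expand(
        batch, num_heads, seq_len, seq_len
    )
    # The expanded view is not contiguous, so reshape copies it into [B * H, L, L].
    return mask.reshape(batch * num_heads, seq_len, seq_len)


# Usage (kept in comments so nothing runs at import time):
#   mask = torch.tensor([[1, 1, 0]])
#   _expand_padding_mask(mask, num_heads=2).shape  # torch.Size([2, 3, 3])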