import logging
import os

import torch
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM

# Authenticate with the Hugging Face Hub; HF_TOKEN must be set in the environment.
login(token=os.getenv('HF_TOKEN'))

class Model(torch.nn.Module):
    """Wrapper around a Hugging Face tokenizer/model pair with batched
    streaming and non-streaming generation."""

    # Class-level count of instantiated models.
    number_of_models = 0

    # Models this wrapper has been used with.
    __model_list__ = [
        "Qwen/Qwen2-1.5B-Instruct",
        "lmsys/vicuna-7b-v1.5",
        "google-t5/t5-large",
        "mistralai/Mistral-7B-Instruct-v0.1",
        "meta-llama/Meta-Llama-3.1-8B-Instruct"
    ]

    def __init__(self, model_name="Qwen/Qwen2-1.5B-Instruct") -> None:
        super().__init__()

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.name = model_name

        # Many instruct models (e.g. Llama, Mistral) ship without a pad token;
        # fall back to the EOS token so batched padding works.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Decoder-only models need left padding for batched generation,
        # otherwise new tokens are appended after the padding.
        if model_name != "google-t5/t5-large":
            self.tokenizer.padding_side = "left"

        logging.info(f'Loading model {self.name}')

        if model_name == "google-t5/t5-large":
            # For T5 or any other Seq2Seq model
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name, torch_dtype=torch.bfloat16, device_map="auto"
            )
        else:
            # For GPT-like models or other causal language models
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name, torch_dtype=torch.bfloat16, device_map="auto"
            )

        logging.info(f'Loaded model {self.name}')

        self.model.eval()
        self.update()

    @classmethod
    def update(cls):
        cls.number_of_models += 1

    def return_model_name(self):
        return self.name
    
    def return_tokenizer(self):
        return self.tokenizer
    
    def return_model(self):
        return self.model
    
    def streaming(self, content_list, temp=0.001, max_length=500, do_sample=True):
        """Yield (original_batch_index, token_text) pairs one token at a time.

        Note: generated tokens are appended back onto the input between steps,
        which assumes a decoder-only (causal) model; it will not work for T5.
        """
        # Convert list of texts to padded input IDs plus an attention mask
        encoded = self.tokenizer(content_list, return_tensors="pt", padding=True, truncation=True)
        input_ids = encoded.input_ids.to(self.model.device)
        attention_mask = encoded.attention_mask.to(self.model.device)

        # Set up the initial generation parameters
        gen_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "do_sample": do_sample,
            "temperature": temp,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
            "max_new_tokens": 1,  # Generate one token at a time
            "return_dict_in_generate": True,
            "output_scores": True
        }

        # Generate and yield tokens one by one. `active_sequences` maps each
        # row of the shrinking batch back to its original position.
        generated_tokens = 0
        batch_size = input_ids.shape[0]
        active_sequences = torch.arange(batch_size)

        while generated_tokens < max_length and len(active_sequences) > 0:
            with torch.no_grad():
                output = self.model.generate(**gen_kwargs)

            next_tokens = output.sequences[:, -1].unsqueeze(-1)

            # Yield the newly generated token for each still-active sequence
            for i, token in zip(active_sequences.tolist(), next_tokens):
                yield i, self.tokenizer.decode(token[0], skip_special_tokens=True)

            # Append the new tokens so the next step conditions on them
            gen_kwargs["input_ids"] = torch.cat([gen_kwargs["input_ids"], next_tokens], dim=-1)
            gen_kwargs["attention_mask"] = torch.cat(
                [gen_kwargs["attention_mask"], torch.ones_like(next_tokens)], dim=-1
            )
            generated_tokens += 1

            # Drop sequences that just emitted EOS. `keep` is a boolean mask
            # over the *current* batch, so it indexes the running tensors and
            # the original-position bookkeeping consistently.
            keep = next_tokens.squeeze(-1) != self.tokenizer.eos_token_id
            active_sequences = active_sequences[keep.cpu()]
            if len(active_sequences) > 0:
                gen_kwargs["input_ids"] = gen_kwargs["input_ids"][keep]
                gen_kwargs["attention_mask"] = gen_kwargs["attention_mask"][keep]


    def gen(self, content_list, temp=0.001, max_length=500, do_sample=True):
        # Convert list of texts to padded input IDs plus an attention mask
        encoded = self.tokenizer(content_list, return_tensors="pt", padding=True, truncation=True)
        input_ids = encoded.input_ids.to(self.model.device)
        attention_mask = encoded.attention_mask.to(self.model.device)

        # Non-streaming generation
        outputs = self.model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_length,
            do_sample=do_sample,
            temperature=temp,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
        )

        # Seq2seq models (T5) return only the generated tokens; causal models
        # echo the prompt, so strip it off before decoding.
        if self.model.config.is_encoder_decoder:
            return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return self.tokenizer.batch_decode(outputs[:, input_ids.shape[1]:], skip_special_tokens=True)
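

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: exercises both
    # generation paths on the default model. Assumes HF_TOKEN is set and the
    # model fits in available memory; prompts are illustrative only.
    logging.basicConfig(level=logging.INFO)
    model = Model("Qwen/Qwen2-1.5B-Instruct")
    prompts = ["Translate to French: Hello, world.", "What is 2 + 2?"]

    # Batch generation: one decoded completion per prompt.
    for prompt, completion in zip(prompts, model.gen(prompts, max_length=50)):
        print(f"{prompt!r} -> {completion!r}")

    # Streaming generation: tokens arrive tagged with their prompt index.
    for idx, token_text in model.streaming(prompts, max_length=50):
        print(f"[prompt {idx}] {token_text!r}")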