File size: 6,276 Bytes
4fa4c7b
 
01af800
4fa4c7b
 
acd9509
01af800
4fa4c7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b49edc5
4fa4c7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2375d69
b49edc5
2375d69
 
01af800
6f92fa3
acd9509
 
 
 
 
 
 
 
6f92fa3
2375d69
 
4fa4c7b
 
01af800
2375d69
4fa4c7b
 
 
2375d69
01af800
 
2375d69
 
01af800
2375d69
 
 
 
01af800
2375d69
4fa4c7b
 
 
 
 
 
 
2375d69
6f92fa3
 
 
 
4fa4c7b
 
6f92fa3
 
 
 
 
 
 
 
 
 
 
 
 
acd9509
 
 
 
 
 
 
 
 
4fa4c7b
 
 
 
acd9509
 
 
 
 
 
6f92fa3
4fa4c7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import requests
import json
import torch
import os
from datetime import datetime, timedelta
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

class GigaChat:
    """Minimal client for a GigaChat-compatible chat-completions API.

    An OAuth token is requested from ``auth_url`` and the raw JSON response is
    cached in ``auth_file``; ``get_text`` then posts chat requests to
    ``gen_url`` with that token.
    """

    def __init__(self, auth_file='auth_token.json'):
        """Configure the endpoints and make sure a cached token file exists.

        Args:
            auth_file: Path of the JSON file holding the cached OAuth
                response. A new token is requested only when this file does
                not exist yet.
        """
        # Alternative endpoint (kept for reference):
        # url = "https://ngw.devices.sberbank.ru:9443/api/v2/oauth"
        self.auth_url = "https://api.mlrnd.ru/api/v2/oauth"

        # Alternative endpoint (kept for reference):
        # url = "https://gigachat.devices.sberbank.ru/api/v1/chat/completions"
        self.gen_url = "https://api.mlrnd.ru/api/v1/chat/completions"

        # Alternative scope (kept for reference): 'scope=GIGACHAT_API_CORP'
        self.payload = 'scope=API_v1'

        # Store the path: callers read ``model.auth_file`` to load the token.
        # (Previously this was set to None, so the path was lost and a new
        # token was requested on every construction.)
        self.auth_file = auth_file

        if not os.path.isfile(auth_file):
            self.gen_giga_token(auth_file)

    @classmethod
    def get_giga(cls, auth_file='auth_token.json'):
        """Alternate constructor; prints a marker line and returns a client."""
        print('got giga')
        return cls(auth_file)

    def gen_giga_token(self, auth_file):
        """Request a new OAuth token and write the JSON response to *auth_file*.

        The ``Authorization`` header value is read from the
        ``GIGACHAT_API_TOKEN`` environment variable.

        Raises:
            requests.HTTPError: if the OAuth endpoint returns an error status,
                so that an error payload is never cached as a token.
        """
        headers = {
            'Content-Type': 'application/x-www-form-urlencoded',
            'Accept': 'application/json',
            'RqUID': '1b519047-0ee9-4b63-8599-e5ffc9c77e72',
            'Authorization': os.getenv('GIGACHAT_API_TOKEN')
        }

        # NOTE(review): verify=False disables TLS certificate verification;
        # presumably needed for a private CA -- confirm before deploying.
        response = requests.request(
            "POST",
            self.auth_url,
            headers=headers,
            data=self.payload,
            verify=False
            )
        response.raise_for_status()

        # ensure_ascii=False may emit non-ASCII text, so pin the file encoding.
        with open(auth_file, 'w', encoding='utf-8') as f:
            json.dump(json.loads(response.text), f, ensure_ascii=False)

    def get_text(self, content, auth_token=None, params=None):
        """Request a chat completion and return the decoded JSON response.

        Args:
            content: List of ``{'role': ..., 'content': ...}`` message dicts.
            auth_token: Access token sent as ``Authorization: Bearer <token>``.
            params: Optional overrides for ``temperature`` (default 1),
                ``top_p`` (0.9), ``n`` (1), ``max_tokens`` (512) and
                ``repetition_penalty`` (1). Falsy values such as 0 also fall
                back to the default.

        Returns:
            The parsed JSON body of the API response.
        """
        if params is None:
            params = dict()

        payload = json.dumps(
            {
                "model": "Test_model",
                "messages": content,
                "temperature": params.get("temperature") or 1,
                "top_p": params.get("top_p") or 0.9,
                "n": params.get("n") or 1,
                "stream": False,
                "max_tokens": params.get("max_tokens") or 512,
                "repetition_penalty": params.get("repetition_penalty") or 1,
            }
        )
        headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json',
            'Authorization': f'Bearer {auth_token}'
        }

        # NOTE(review): TLS verification disabled here too -- confirm intent.
        response = requests.request("POST", self.gen_url, headers=headers, data=payload, verify=False)

        return json.loads(response.text)


def get_tinyllama():
    """Build a text-generation pipeline for the TinyLlama 1.1B chat model.

    Prints a marker line, then loads the model in float16 with
    ``device_map="auto"``.

    Returns:
        The ``transformers`` text-generation pipeline.
    """
    print('got llama')
    return pipeline(
        "text-generation",
        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        torch_dtype=torch.float16,
        device_map="auto",
    )

def get_qwen2ins1b():
    """Load Qwen2-1.5B-Instruct together with its tokenizer.

    Returns:
        dict with key ``'model'`` (the causal LM, loaded with
        ``torch_dtype="auto"`` and ``device_map="auto"``) and key
        ``'tokenizer'``.
    """
    checkpoint = "Qwen/Qwen2-1.5B-Instruct"
    qwen_model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        torch_dtype="auto",
        device_map="auto",
    )
    qwen_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return dict(model=qwen_model, tokenizer=qwen_tokenizer)

def response_tinyllama(
        model=None,
        messages=None,
        params=None
        ):
    """Generate the next assistant reply with the TinyLlama pipeline.

    Args:
        model: Text-generation pipeline (as returned by ``get_tinyllama``);
            its ``tokenizer`` renders the chat template.
        messages: Conversation history as a sequence of steps, where
            ``step[0]`` is the user text and ``step[1]``, when present, is
            the assistant reply. ``None`` is treated as an empty history.
        params: Optional generation settings ``max_tokens``, ``temperature``,
            ``top_p`` and ``repetition_penalty``. Missing or falsy values
            fall back to 512, 1, 0.9 and 1 respectively.

    Returns:
        The newly generated assistant text, stripped of surrounding
        whitespace.
    """
    if params is None:
        params = dict()
    if messages is None:
        messages = []

    messages_dict = [
        {
            "role": "system",
            "content": "You are a friendly and helpful chatbot",
        }
    ]
    for step in messages:
        messages_dict.append({'role': 'user', 'content': step[0]})
        if len(step) >= 2:
            messages_dict.append({'role': 'assistant', 'content': step[1]})

    prompt = model.tokenizer.apply_chat_template(messages_dict, tokenize=False, add_generation_prompt=True)
    outputs = model(
        prompt,
        max_new_tokens=params.get("max_tokens") or 512,
        temperature=params.get("temperature") or 1,
        top_p=params.get("top_p") or 0.9,
        repetition_penalty=params.get("repetition_penalty") or 1,
        # Return only the completion. Splitting the full text on
        # '<|assistant|>' and taking [1] picked the FIRST assistant turn,
        # which is a history reply whenever the chat has earlier turns.
        return_full_text=False,
        )

    return outputs[0]['generated_text'].strip()

def response_qwen2ins1b(
        model=None,
        messages=None,
        params=None
        ):
    """Generate the next assistant reply with Qwen2-1.5B-Instruct.

    Args:
        model: dict with ``'model'`` (causal LM) and ``'tokenizer'`` keys, as
            returned by ``get_qwen2ins1b``.
        messages: Conversation history as a sequence of steps, where
            ``step[0]`` is the user text and ``step[1]``, when present, is
            the assistant reply. ``None`` is treated as an empty history.
        params: Optional generation settings ``max_tokens``, ``temperature``,
            ``top_p`` and ``repetition_penalty``. Missing or falsy values
            fall back to 512, 1, 0.9 and 1 respectively. ``None`` (the
            default) is accepted.

    Returns:
        The decoded completion text (prompt tokens excluded).
    """
    if params is None:
        params = dict()
    if messages is None:
        messages = []

    tokenizer = model['tokenizer']
    llm = model['model']

    messages_dict = [
        {
            "role": "system",
            "content": "You are a friendly and helpful chatbot",
        }
    ]
    for step in messages:
        messages_dict.append({'role': 'user', 'content': step[0]})
        if len(step) >= 2:
            messages_dict.append({'role': 'assistant', 'content': step[1]})

    text = tokenizer.apply_chat_template(
        messages_dict,
        tokenize=False,
        add_generation_prompt=True
    )
    # The model was loaded with device_map="auto" and may sit on a GPU;
    # inputs must be on the same device as its weights.
    model_inputs = tokenizer([text], return_tensors="pt").to(llm.device)

    # Pass input_ids together with attention_mask.
    generated_ids = llm.generate(
        **model_inputs,
        max_new_tokens=params.get("max_tokens") or 512,
        temperature=params.get("temperature") or 1,
        top_p=params.get("top_p") or 0.9,
        repetition_penalty=params.get("repetition_penalty") or 1
    )
    # generate() returns prompt + completion; keep only the new tokens.
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

def response_gigachat(
        model=None,
        messages=None,
        model_params=None
        ):
    """Send the conversation to GigaChat and return the first reply text.

    Loads the cached OAuth response from ``model.auth_file``. The token is
    refreshed via ``model.gen_giga_token`` when it has expired or will
    expire within the next 60 seconds.

    Args:
        model: A ``GigaChat`` client; ``auth_file`` holds the cached token
            JSON and ``get_text`` performs the request.
        messages: Conversation history as a sequence of steps, where
            ``step[0]`` is the user text and ``step[1]``, when present, is
            the assistant reply. ``None`` is treated as an empty history.
        model_params: Optional generation settings forwarded unchanged to
            ``model.get_text``.

    Returns:
        ``resp["choices"][0]["message"]["content"]`` from the API response.
    """
    with open(model.auth_file, encoding='utf-8') as f:
        auth_token = json.load(f)

    # ``expires_at`` is divided by 1000, i.e. treated as epoch milliseconds.
    expires_at = datetime.fromtimestamp(auth_token['expires_at'] / 1000)
    # Refresh when the token is expired or expires within the next 60 s.
    # The previous check (``<= now - 60s``) only refreshed tokens that had
    # already been dead for a minute, so requests in that window failed.
    if expires_at <= datetime.now() + timedelta(seconds=60):
        model.gen_giga_token(model.auth_file)
        with open(model.auth_file, encoding='utf-8') as f:
            auth_token = json.load(f)

    if messages is None:
        messages = []

    content = []
    for step in messages:
        content.append({'role': 'user', 'content': step[0]})
        if len(step) >= 2:
            content.append({'role': 'assistant', 'content': step[1]})

    resp = model.get_text(content, auth_token['access_token'], model_params)

    return resp["choices"][0]["message"]["content"]