GoidaAlignment committed
Commit 74de454 · verified · 1 Parent(s): af7afe0

Update app.py

Files changed (1):
  app.py +54 -17
app.py CHANGED
@@ -1,30 +1,67 @@
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch

 # Load the tokenizer and model
-model_name = "GoidaAlignment/GOIDA-0.5B"  # Set this to your model path
+model_name = "GoidaAlignment/GOIDA-0.5B"  # Replace with your model
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(model_name)
+model = model.to("cuda" if torch.cuda.is_available() else "cpu")

-def generate_response(prompt):
-    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
-    outputs = model.generate(inputs["input_ids"], max_length=200, num_return_sequences=1)
-    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return response
+# Template function for formatting the dialogue
+def apply_chat_template(chat, add_generation_prompt=True):
+    formatted_chat = ""
+    for message in chat:
+        role = message["role"]
+        content = message["content"]
+        if role == "user":
+            formatted_chat += f"User: {content}\n"
+        elif role == "assistant":
+            formatted_chat += f"Assistant: {content}\n"
+    if add_generation_prompt:
+        formatted_chat += "Assistant: "
+    return formatted_chat

-# Gradio interface
-with gr.Blocks() as demo:
-    gr.Markdown("# Enter a prompt and the model will respond.")
+# Response generation function
+def generate_response(user_input, chat_history):
+    chat_history.append({"role": "user", "content": user_input})
+    formatted_chat = apply_chat_template(chat_history, add_generation_prompt=True)
+
+    # Tokenize and move the tensors to the model's device
+    inputs = tokenizer(formatted_chat, return_tensors="pt", add_special_tokens=False)
+    inputs = {key: tensor.to(model.device) for key, tensor in inputs.items()}

-    with gr.Row():
-        with gr.Column():
-            prompt_input = gr.Textbox(label="Your prompt", lines=4, placeholder="Enter text")
-        with gr.Column():
-            output = gr.Textbox(label="Model response", lines=6, interactive=False)
+    # Generate
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=64,
+        temperature=0.7,
+        top_p=0.9,
+        do_sample=True
+    )

-    submit_button = gr.Button("Generate")
-    submit_button.click(generate_response, inputs=prompt_input, outputs=output)
+    # Decode only the newly generated tokens, skipping the prompt
+    decoded_output = tokenizer.decode(outputs[0][inputs["input_ids"].size(1):], skip_special_tokens=True)
+    chat_history.append({"role": "assistant", "content": decoded_output})
+    pairs = [(chat_history[i]["content"], chat_history[i + 1]["content"]) for i in range(0, len(chat_history) - 1, 2)]
+    return pairs, chat_history  # gr.Chatbot expects (user, assistant) pairs, not a bare string
+
+# Gradio interface
+with gr.Blocks() as demo:
+    gr.Markdown("# Chatbot based on the GOIDAAAA model\nInteract with the language model.")
+
+    chatbot = gr.Chatbot()
+    user_input = gr.Textbox(placeholder="Enter your message...")
+    clear = gr.Button("Clear chat")
+
+    chat_history = gr.State([])  # State holding the chat history
+
+    user_input.submit(
+        generate_response,
+        [user_input, chat_history],
+        [chatbot, chat_history]
+    )
+    clear.click(lambda: ([], []), None, [chatbot, chat_history])

-# Launch the app
 if __name__ == "__main__":
     demo.launch()
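
A note on the hand-rolled apply_chat_template helper added in this commit: recent transformers releases ship a built-in tokenizer.apply_chat_template that does the same job using a chat template bundled with the tokenizer. Whether the GOIDA-0.5B tokenizer actually defines such a template is an assumption here; if it does, the helper could be replaced with a sketch like this:

    # Minimal sketch, assuming GoidaAlignment/GOIDA-0.5B bundles a chat_template
    # (not verified); otherwise the manual "User:/Assistant:" helper above is still needed.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("GoidaAlignment/GOIDA-0.5B")

    chat = [
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi! How can I help?"},
        {"role": "user", "content": "Tell me a joke."},
    ]

    # tokenize=False returns the formatted prompt as a string, mirroring the helper;
    # add_generation_prompt=True appends the opener for the assistant's next turn.
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    print(prompt)

The same dict-based history also maps directly onto gr.Chatbot(type="messages") in newer Gradio releases, which accepts {"role": ..., "content": ...} dicts and would make the tuple-pairing step in generate_response unnecessary.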