vilarin committed on
Commit
423ddc8
·
verified ·
1 Parent(s): 3975a20

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -39
app.py CHANGED
@@ -2,19 +2,16 @@ import os
2
  import time
3
  import spaces
4
  import torch
5
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
  import gradio as gr
7
  from threading import Thread
8
 
9
- MODEL_LIST = ["mistralai/Ministral-8B-Instruct-2410"]
10
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
11
- MODEL = os.environ.get("MODEL_ID")
12
 
13
- TITLE = "<h1><center>Mistral-Nemo</center></h1>"
14
 
15
  PLACEHOLDER = """
16
  <center>
17
- <p>The Mistral-8B is a pretrained generative text model of 8B parameters trained jointly by Mistral AI.</p>
18
  </center>
19
  """
20
 
@@ -31,14 +28,26 @@ h3 {
31
  }
32
  """
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  device = "cuda" # for GPU usage or "cpu" for CPU usage
35
 
36
- tokenizer = AutoTokenizer.from_pretrained(MODEL)
37
- model = AutoModelForCausalLM.from_pretrained(
38
- MODEL,
39
- torch_dtype=torch.bfloat16,
40
- device_map="auto",
41
- ignore_mismatched_sizes=True)
42
 
43
  @spaces.GPU()
44
  def stream_chat(
@@ -55,42 +64,31 @@ def stream_chat(
55
 
56
  conversation = []
57
  for prompt, answer in history:
58
- conversation.extend([
59
- {"role": "user", "content": prompt},
60
- {"role": "assistant", "content": answer},
61
- ])
62
-
63
- conversation.append({"role": "user", "content": message})
64
 
65
- input_text=tokenizer.apply_chat_template(conversation, tokenize=False)
66
- inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
67
- streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
68
 
69
- generate_kwargs = dict(
70
- input_ids=inputs,
71
- max_new_tokens = max_new_tokens,
72
- do_sample = False if temperature == 0 else True,
 
 
 
73
  top_p = top_p,
74
  top_k = top_k,
75
- temperature = temperature,
76
- streamer=streamer,
77
  repetition_penalty=penalty,
78
- pad_token_id = 10,
79
- )
80
-
81
- with torch.no_grad():
82
- thread = Thread(target=model.generate, kwargs=generate_kwargs)
83
- thread.start()
84
-
85
- buffer = ""
86
- for new_text in streamer:
87
- buffer += new_text
88
- yield buffer
89
-
90
 
91
  chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
92
 
93
- with gr.Blocks(css=CSS, theme="Nymbo/Nymbo_Theme") as demo:
94
  gr.HTML(TITLE)
95
  gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
96
  gr.ChatInterface(
 
2
  import time
3
  import spaces
4
  import torch
 
5
  import gradio as gr
6
  from threading import Thread
7
 
 
8
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
9
 
10
+ TITLE = "<h1><center>Mistral-lab</center></h1>"
11
 
12
  PLACEHOLDER = """
13
  <center>
14
+ <p>Chat with Mistral AI LLM.</p>
15
  </center>
16
  """
17
 
 
28
  }
29
  """
30
 
31
+ from huggingface_hub import snapshot_download
32
+ from pathlib import Path
33
+
34
+ mistral_models_path = Path.home().joinpath('mistral_models', '8B-Instruct')
35
+ mistral_models_path.mkdir(parents=True, exist_ok=True)
36
+
37
+ snapshot_download(repo_id="mistralai/Ministral-8B-Instruct-2410", allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"], local_dir=mistral_models_path)
38
+
39
+ from mistral_inference.transformer import Transformer
40
+ from mistral_inference.generate import generate
41
+
42
+ from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
43
+ from mistral_common.protocol.instruct.messages import AssistantMessage, UserMessage
44
+ from mistral_common.protocol.instruct.request import ChatCompletionRequest
45
+
46
  device = "cuda" # for GPU usage or "cpu" for CPU usage
47
 
48
+ tokenizer = MistralTokenizer.from_file(f"{mistral_models_path}/tekken.json")
49
+ model = Transformer.from_folder(mistral_models_path)
50
+
 
 
 
51
 
52
  @spaces.GPU()
53
  def stream_chat(
 
64
 
65
  conversation = []
66
  for prompt, answer in history:
67
+ conversation.append(UserMessage(content=prompt))
68
+ conversation.append(AssistantMessage(content=answer))
69
+ conversation.append(UserMessage(content=message))
 
 
 
70
 
71
+ completion_request = ChatCompletionRequest(messages=conversation)
 
 
72
 
73
+ tokens = tokenizer.encode_chat_completion(completion_request).tokens
74
+
75
+ out_tokens, _ = generate(
76
+ [tokens],
77
+ model,
78
+ max_tokens=max_new_tokens,
79
+ temperature=temperature,
80
  top_p = top_p,
81
  top_k = top_k,
 
 
82
  repetition_penalty=penalty,
83
+ eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
84
+
85
+ result = tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])
86
+
87
+ return result
 
 
 
 
 
 
 
88
 
89
  chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
90
 
91
+ with gr.Blocks(css=CSS, theme="ocean") as demo:
92
  gr.HTML(TITLE)
93
  gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
94
  gr.ChatInterface(