vilarin committed on
Commit
970d940
·
verified ·
1 Parent(s): 12a34f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -23
app.py CHANGED
@@ -5,6 +5,16 @@ import torch
5
  import gradio as gr
6
  from threading import Thread
7
 
 
 
 
 
 
 
 
 
 
 
8
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
9
 
10
  TITLE = "<h1><center>Mistral-lab</center></h1>"
@@ -15,25 +25,32 @@ PLACEHOLDER = """
15
  </center>
16
  """
17
 
18
- from huggingface_hub import snapshot_download
19
- from pathlib import Path
 
 
 
 
 
 
 
 
 
20
 
 
 
21
  mistral_models_path = Path.home().joinpath('mistral_models', '8B-Instruct')
22
  mistral_models_path.mkdir(parents=True, exist_ok=True)
23
 
24
  snapshot_download(repo_id="mistralai/Ministral-8B-Instruct-2410", allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"], local_dir=mistral_models_path)
25
 
26
- from mistral_inference.transformer import Transformer
27
- from mistral_inference.generate import generate
28
-
29
- from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
30
- from mistral_common.protocol.instruct.messages import AssistantMessage, UserMessage
31
- from mistral_common.protocol.instruct.request import ChatCompletionRequest
32
-
33
- device = "cuda" # for GPU usage or "cpu" for CPU usage
34
-
35
  tokenizer = MistralTokenizer.from_file(f"{mistral_models_path}/tekken.json")
36
- model = Transformer.from_folder(mistral_models_path)
 
 
 
37
 
38
 
39
  @spaces.GPU()
@@ -64,12 +81,23 @@ def stream_chat(
64
  eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
65
 
66
  result = tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])
67
-
68
- return result
 
 
69
 
70
- chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
 
 
 
 
 
 
 
 
 
71
 
72
- with gr.Blocks(theme="ocean") as demo:
73
  gr.HTML(TITLE)
74
  gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
75
  gr.ChatInterface(
@@ -95,13 +123,6 @@ with gr.Blocks(theme="ocean") as demo:
95
  render=False,
96
  ),
97
  ],
98
- examples=[
99
- ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
100
- ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
101
- ["Tell me a random fun fact about the Roman Empire."],
102
- ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
103
- ],
104
- cache_examples=False,
105
  )
106
 
107
 
 
5
  import gradio as gr
6
  from threading import Thread
7
 
8
+ from huggingface_hub import snapshot_download
9
+ from pathlib import Path
10
+
11
+ from mistral_inference.transformer import Transformer
12
+ from mistral_inference.generate import generate
13
+
14
+ from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
15
+ from mistral_common.protocol.instruct.messages import AssistantMessage, UserMessage
16
+ from mistral_common.protocol.instruct.request import ChatCompletionRequest
17
+
18
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
19
 
20
  TITLE = "<h1><center>Mistral-lab</center></h1>"
 
25
  </center>
26
  """
27
 
28
+ CSS = """
29
+ .duplicate-button {
30
+ margin: auto !important;
31
+ color: white !important;
32
+ background: black !important;
33
+ border-radius: 100vh !important;
34
+ }
35
+ h3 {
36
+ text-align: center;
37
+ }
38
+ """
39
 
40
+
41
+ # download model
42
  mistral_models_path = Path.home().joinpath('mistral_models', '8B-Instruct')
43
  mistral_models_path.mkdir(parents=True, exist_ok=True)
44
 
45
  snapshot_download(repo_id="mistralai/Ministral-8B-Instruct-2410", allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"], local_dir=mistral_models_path)
46
 
47
+ # tokenizer
48
+ device = "cuda" if torch.cuda.is_available() else "cpu" # for GPU usage or "cpu" for CPU usage
 
 
 
 
 
 
 
49
  tokenizer = MistralTokenizer.from_file(f"{mistral_models_path}/tekken.json")
50
+ model = Transformer.from_folder(
51
+ mistral_models_path,
52
+ device=device,
53
+ dtype=torch.bfloat16)
54
 
55
 
56
  @spaces.GPU()
 
81
  eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
82
 
83
  result = tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])
84
+
85
+ for i in range(len(result)):
86
+ time.sleep(0.05)
87
+ yield result[: i + 1]
88
 
89
+ chatbot = gr.Chatbot(
90
+ height=600,
91
+ placeholder=PLACEHOLDER,
92
+ examples=[
93
+ {"text": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."},
94
+ {"text": "What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."},
95
+ {"text": "Tell me a random fun fact about the Roman Empire."},
96
+ {"text": "Show me a code snippet of a website's sticky header in CSS and JavaScript."},
97
+ ],
98
+ )
99
 
100
+ with gr.Blocks(theme="ocean", css=CSS) as demo:
101
  gr.HTML(TITLE)
102
  gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
103
  gr.ChatInterface(
 
123
  render=False,
124
  ),
125
  ],
 
 
 
 
 
 
 
126
  )
127
 
128