minhnguyent546 committed on
Commit
5faeb60
·
unverified ·
1 Parent(s): cdb6de6

feat: load model when the user enters the chat

Browse files
Files changed (1) hide show
  1. app.py +6 -10
app.py CHANGED
@@ -33,13 +33,7 @@ model_paths = {
33
  'filename': 'Med-Alpaca-2-7B-chat.F16.gguf',
34
  },
35
  }
36
-
37
- model = Llama.from_pretrained(
38
- **model_paths[DEFAULT_MODEL],
39
- n_ctx=4096,
40
- n_threads=4,
41
- cache_dir='./.hf_cache'
42
- )
43
 
44
  def generate_alpaca_prompt(
45
  instruction: str,
@@ -77,6 +71,8 @@ def chat_completion(
77
  top_k: int,
78
  top_p: float,
79
  ):
 
 
80
  prompt = generate_alpaca_prompt(instruction=message)
81
  response_iterator = model(
82
  prompt,
@@ -93,7 +89,7 @@ def chat_completion(
93
  partial_response += token['choices'][0]['text']
94
  yield partial_response
95
 
96
- def on_model_changed(model_name: str):
97
  global model
98
  if 'model' in globals():
99
  del model
@@ -102,7 +98,7 @@ def on_model_changed(model_name: str):
102
  **model_paths[model_name],
103
  n_ctx=4096,
104
  n_threads=4,
105
- cache_dir='./hf-cache'
106
  )
107
 
108
  app_title_mark = gr.Markdown(f"""<center><font size=18>{model_name}</center>""")
@@ -167,7 +163,7 @@ def main() -> None:
167
  ],
168
  )
169
 
170
- model_radio.change(on_model_changed, inputs=[model_radio], outputs=[app_title_mark, chatbot])
171
 
172
  demo.queue(api_open=False, default_concurrency_limit=20)
173
  demo.launch(max_threads=5, share=os.environ.get('GRADIO_SHARE', False))
 
33
  'filename': 'Med-Alpaca-2-7B-chat.F16.gguf',
34
  },
35
  }
36
+ model = None
 
 
 
 
 
 
37
 
38
  def generate_alpaca_prompt(
39
  instruction: str,
 
71
  top_k: int,
72
  top_p: float,
73
  ):
74
+ if model is None:
75
+ reload_model(DEFAULT_MODEL)
76
  prompt = generate_alpaca_prompt(instruction=message)
77
  response_iterator = model(
78
  prompt,
 
89
  partial_response += token['choices'][0]['text']
90
  yield partial_response
91
 
92
+ def reload_model(model_name: str):
93
  global model
94
  if 'model' in globals():
95
  del model
 
98
  **model_paths[model_name],
99
  n_ctx=4096,
100
  n_threads=4,
101
+ cache_dir='./.hf_cache'
102
  )
103
 
104
  app_title_mark = gr.Markdown(f"""<center><font size=18>{model_name}</center>""")
 
163
  ],
164
  )
165
 
166
+ model_radio.change(reload_model, inputs=[model_radio], outputs=[app_title_mark, chatbot])
167
 
168
  demo.queue(api_open=False, default_concurrency_limit=20)
169
  demo.launch(max_threads=5, share=os.environ.get('GRADIO_SHARE', False))