Hjgugugjhuhjggg committed
Commit e22bf4e (verified) · Parent: c3e3686

Update app.py

Files changed (1): app.py (+65 -46)
app.py CHANGED
@@ -17,7 +17,6 @@ hf_hub_download(
 )
 
 # Inference function
-@spaces.GPU()
 def respond(
     message,
     history: list[tuple[str, str]],
@@ -28,59 +27,79 @@ def respond(
     top_p,
     top_k,
     repeat_penalty,
+    use_gpu: bool = True  # Parameter to choose between GPU and CPU
 ):
     chat_template = MessagesFormatterType.GEMMA_2
 
-    llm = Llama(
-        model_path=f"models/{model}",
-        flash_attn=True,
-        n_gpu_layers=81,
-        n_batch=1024,
-        n_ctx=8192,
-    )
-    provider = LlamaCppPythonProvider(llm)
-
-    agent = LlamaCppAgent(
-        provider,
-        system_prompt=f"{system_message}",
-        predefined_messages_formatter_type=chat_template,
-        debug_output=True
-    )
-
-    settings = provider.get_provider_default_settings()
-    settings.temperature = temperature
-    settings.top_k = top_k
-    settings.top_p = top_p
-    settings.max_tokens = max_tokens
-    settings.repeat_penalty = repeat_penalty
-    settings.stream = True
-
-    messages = BasicChatHistory()
-
-    for msn in history:
-        user = {
-            'role': Roles.user,
-            'content': msn[0]
-        }
-        assistant = {
-            'role': Roles.assistant,
-            'content': msn[1]
-        }
-        messages.add_message(user)
-        messages.add_message(assistant)
-
-    stream = agent.get_chat_response(
-        message,
-        llm_sampling_settings=settings,
-        chat_history=messages,
-        returns_streaming_generator=True,
-        print_output=False
-    )
-
-    outputs = ""
-    for output in stream:
-        outputs += output
-        yield outputs
+    try:
+        # If the GPU is unavailable, fall back to the CPU
+        if use_gpu:
+            llm = Llama(
+                model_path=f"models/{model}",
+                flash_attn=True,
+                n_gpu_layers=81,
+                n_batch=1024,
+                n_ctx=8192,
+            )
+        else:
+            llm = Llama(
+                model_path=f"models/{model}",
+                flash_attn=False,  # Disable GPU usage
+                n_batch=1024,
+                n_ctx=8192,
+            )
+
+        provider = LlamaCppPythonProvider(llm)
+
+        agent = LlamaCppAgent(
+            provider,
+            system_prompt=f"{system_message}",
+            predefined_messages_formatter_type=chat_template,
+            debug_output=True
+        )
+
+        settings = provider.get_provider_default_settings()
+        settings.temperature = temperature
+        settings.top_k = top_k
+        settings.top_p = top_p
+        settings.max_tokens = max_tokens
+        settings.repeat_penalty = repeat_penalty
+        settings.stream = True
+
+        messages = BasicChatHistory()
+
+        for msn in history:
+            user = {
+                'role': Roles.user,
+                'content': msn[0]
+            }
+            assistant = {
+                'role': Roles.assistant,
+                'content': msn[1]
+            }
+            messages.add_message(user)
+            messages.add_message(assistant)
+
+        stream = agent.get_chat_response(
+            message,
+            llm_sampling_settings=settings,
+            chat_history=messages,
+            returns_streaming_generator=True,
+            print_output=False
+        )
+
+        outputs = ""
+        for output in stream:
+            outputs += output
+            yield outputs
+    except Exception as e:
+        # If the GPU quota is exceeded, report the error and retry on the CPU
+        if "You have exceeded your GPU quota" in str(e):
+            print("GPU quota exceeded, switching to CPU mode.")
+            yield "Error: Exceeded GPU quota, switching to CPU. Please wait a moment..."
+            yield from respond(message, history, model, system_message, max_tokens, temperature, top_p, top_k, repeat_penalty, use_gpu=False)
+        else:
+            yield f"An error occurred: {str(e)}"
 
 # Create the Gradio interface
 def create_interface(model_name, description):
@@ -137,4 +156,4 @@ with demo:
     interface.render()
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
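
A note on the retry path: because respond is a generator, the committed return respond(..., use_gpu=False) would end the stream without ever emitting the CPU output, so the fallback is rendered above with yield from, which delegates to the retry and keeps its chunks flowing to the caller. A minimal, self-contained sketch of the pattern, with a hypothetical load_model standing in for the Llama(...) construction and a raised error simulating the quota failure:

def load_model(gpu: bool):
    # Hypothetical stand-in for Llama(...); raising on the GPU path
    # simulates the Space exceeding its GPU quota.
    if gpu:
        raise RuntimeError("You have exceeded your GPU quota")
    class CpuModel:
        def stream(self, prompt):
            yield from prompt.split()  # dummy token stream
    return CpuModel()

def generate(prompt: str, use_gpu: bool = True):
    try:
        model = load_model(gpu=use_gpu)
        for token in model.stream(prompt):
            yield token
    except Exception as e:
        if use_gpu and "You have exceeded your GPU quota" in str(e):
            # Delegation keeps the stream alive; a bare `return` would
            # stop it here. The use_gpu guard prevents infinite recursion.
            yield from generate(prompt, use_gpu=False)
        else:
            yield f"An error occurred: {e}"

print(list(generate("fallback works")))  # ['fallback', 'works'] via the CPU path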