SUHHHH committed on
Commit 95b089d
1 Parent(s): 1754094

Update app.py

Files changed (1)
  1. app.py +76 -44
app.py CHANGED
@@ -1,6 +1,10 @@
 import gradio as gr
 from huggingface_hub import InferenceClient
 import os
+import openai  # added to use the OpenAI API
+
+# Set up the OpenAI API client
+openai.api_key = os.getenv("OPENAI_API_KEY")
 
 MODELS = {
     "Zephyr 7B Beta": "HuggingFaceH4/zephyr-7b-beta",
@@ -11,16 +15,32 @@ MODELS = {
     "Mixtral 8x7B": "mistralai/Mistral-7B-Instruct-v0.3",
     "Mixtral Nous-Hermes": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
     "Cohere Command R+": "CohereForAI/c4ai-command-r-plus",
-    "Cohere Aya-23-35B": "CohereForAI/aya-23-35B"
+    "Cohere Aya-23-35B": "CohereForAI/aya-23-35B",
+    "GPT-4o Mini": "gpt-4o-mini"  # add the GPT-4o Mini model
 }
 
 def get_client(model_name):
+    if model_name == "GPT-4o Mini":
+        return None  # no Hugging Face client is used for the OpenAI model
     model_id = MODELS[model_name]
     hf_token = os.getenv("HF_TOKEN")
     if not hf_token:
         raise ValueError("HF_TOKEN environment variable is required")
     return InferenceClient(model_id, token=hf_token)
 
+def call_openai_api(content, system_message, max_tokens, temperature, top_p):
+    response = openai.ChatCompletion.create(
+        model="gpt-4o-mini",  # use the OpenAI model
+        messages=[
+            {"role": "system", "content": system_message},
+            {"role": "user", "content": content},
+        ],
+        max_tokens=max_tokens,
+        temperature=temperature,
+        top_p=top_p,
+    )
+    return response.choices[0].message['content']
+
 def respond(
     message,
     chat_history,
@@ -30,52 +50,64 @@ def respond(
     top_p,
     system_message,
 ):
-    try:
-        client = get_client(model_name)
-    except ValueError as e:
-        chat_history.append((message, str(e)))
-        return chat_history
-
-    messages = [{"role": "system", "content": system_message}]
-    for human, assistant in chat_history:
-        messages.append({"role": "user", "content": human})
-        messages.append({"role": "assistant", "content": assistant})
-    messages.append({"role": "user", "content": message})
-
-    try:
-        if "Cohere" in model_name:
-            # non-streaming handling for the Cohere models
-            response = client.chat_completion(
-                messages,
-                max_tokens=max_tokens,
-                temperature=temperature,
-                top_p=top_p,
+    if model_name == "GPT-4o Mini":
+        try:
+            assistant_message = call_openai_api(
+                message, system_message, max_tokens, temperature, top_p
             )
-            assistant_message = response.choices[0].message.content
             chat_history.append((message, assistant_message))
             yield chat_history
-        else:
-            # streaming handling for the other models
-            stream = client.chat_completion(
-                messages,
-                max_tokens=max_tokens,
-                temperature=temperature,
-                top_p=top_p,
-                stream=True,
-            )
-            partial_message = ""
-            for response in stream:
-                if response.choices[0].delta.content is not None:
-                    partial_message += response.choices[0].delta.content
-                    if len(chat_history) > 0 and chat_history[-1][0] == message:
-                        chat_history[-1] = (message, partial_message)
-                    else:
-                        chat_history.append((message, partial_message))
-                    yield chat_history
-    except Exception as e:
-        error_message = f"An error occurred: {str(e)}"
-        chat_history.append((message, error_message))
-        yield chat_history
+        except Exception as e:
+            error_message = f"An error occurred with GPT-4o Mini: {str(e)}"
+            chat_history.append((message, error_message))
+            yield chat_history
+    else:
+        try:
+            client = get_client(model_name)
+        except ValueError as e:
+            chat_history.append((message, str(e)))
+            return chat_history
+
+        messages = [{"role": "system", "content": system_message}]
+        for human, assistant in chat_history:
+            messages.append({"role": "user", "content": human})
+            messages.append({"role": "assistant", "content": assistant})
+        messages.append({"role": "user", "content": message})
+
+        try:
+            if "Cohere" in model_name:
+                # non-streaming handling for the Cohere models
+                response = client.chat_completion(
+                    messages,
+                    max_tokens=max_tokens,
+                    temperature=temperature,
+                    top_p=top_p,
+                )
+                assistant_message = response.choices[0].message.content
+                chat_history.append((message, assistant_message))
+                yield chat_history
+            else:
+                # streaming handling for the other models
+                stream = client.chat_completion(
+                    messages,
+                    max_tokens=max_tokens,
+                    temperature=temperature,
+                    top_p=top_p,
+                    stream=True,
+                )
+                partial_message = ""
+                for response in stream:
+                    if response.choices[0].delta.content is not None:
+                        partial_message += response.choices[0].delta.content
+                        if len(chat_history) > 0 and chat_history[-1][0] == message:
+                            chat_history[-1] = (message, partial_message)
+                        else:
+                            chat_history.append((message, partial_message))
+                        yield chat_history
+        except Exception as e:
+            error_message = f"An error occurred: {str(e)}"
+            chat_history.append((message, error_message))
+            yield chat_history
 
 def clear_conversation():
     return []
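Note: call_openai_api relies on the legacy openai.ChatCompletion.create interface, which was removed in openai 1.0, so this commit only runs as written with a pre-1.0 package pinned (e.g. openai==0.28). A minimal sketch of the same helper against the v1 client, assuming the Space keeps OPENAI_API_KEY in its environment as above:

# Sketch only, not part of the commit: call_openai_api ported to openai>=1.0.
import os
from openai import OpenAI

client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

def call_openai_api(content, system_message, max_tokens, temperature, top_p):
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": content},
        ],
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    # v1 responses are typed objects, so dict access (message['content']) becomes attribute access
    return response.choices[0].message.content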
 
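The hunk ends at clear_conversation(), so the Gradio UI code in app.py is not part of this diff. A minimal sketch of how respond and clear_conversation are typically wired into a gr.Blocks app; the component names and defaults below, and the respond parameters elided from the hunk (model_name, max_tokens, temperature), are assumptions for illustration, not code from the commit:

# Illustrative wiring only; the real layout lives outside this diff.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    message = gr.Textbox(label="Message")
    model_name = gr.Dropdown(choices=list(MODELS.keys()), value="Zephyr 7B Beta", label="Model")
    max_tokens = gr.Slider(1, 2048, value=512, label="Max tokens")
    temperature = gr.Slider(0.1, 2.0, value=0.7, label="Temperature")
    top_p = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
    system_message = gr.Textbox(value="You are a helpful assistant.", label="System message")
    clear = gr.Button("Clear")

    # respond is a generator, so the chatbot re-renders on every yielded history
    message.submit(
        respond,
        inputs=[message, chatbot, model_name, max_tokens, temperature, top_p, system_message],
        outputs=chatbot,
    )
    clear.click(clear_conversation, outputs=chatbot)

demo.launch()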