CSB261 committed on
Commit
d8731c3
1 Parent(s): edfa641

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -5
app.py CHANGED
@@ -1,6 +1,10 @@
1
  import gradio as gr
2
  from huggingface_hub import InferenceClient
3
  import os
 
 
 
 
4
 
5
  MODELS = {
6
  "Zephyr 7B Beta": "HuggingFaceH4/zephyr-7b-beta",
@@ -11,16 +15,32 @@ MODELS = {
11
  "Mixtral 8x7B": "mistralai/Mistral-7B-Instruct-v0.3",
12
  "Mixtral Nous-Hermes": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
13
  "Cohere Command R+": "CohereForAI/c4ai-command-r-plus",
14
- "Cohere Aya-23-35B": "CohereForAI/aya-23-35B"
 
15
  }
16
 
17
  def get_client(model_name):
 
 
18
  model_id = MODELS[model_name]
19
  hf_token = os.getenv("HF_TOKEN")
20
  if not hf_token:
21
  raise ValueError("HF_TOKEN environment variable is required")
22
  return InferenceClient(model_id, token=hf_token)
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  def respond(
25
  message,
26
  chat_history,
@@ -31,7 +51,14 @@ def respond(
31
  system_message,
32
  ):
33
  try:
34
- client = get_client(model_name)
 
 
 
 
 
 
 
35
  except ValueError as e:
36
  chat_history.append((message, str(e)))
37
  return chat_history
@@ -44,7 +71,6 @@ def respond(
44
 
45
  try:
46
  if "Cohere" in model_name:
47
- # Cohere 모델을 위한 비스트리밍 처리
48
  response = client.chat_completion(
49
  messages,
50
  max_tokens=max_tokens,
@@ -55,7 +81,6 @@ def respond(
55
  chat_history.append((message, assistant_message))
56
  yield chat_history
57
  else:
58
- # 다른 모델들을 위한 스트리밍 처리
59
  stream = client.chat_completion(
60
  messages,
61
  max_tokens=max_tokens,
@@ -115,4 +140,4 @@ with gr.Blocks() as demo:
115
  clear_button.click(clear_conversation, outputs=chatbot, queue=False)
116
 
117
  if __name__ == "__main__":
118
- demo.launch()
 
1
  import gradio as gr
2
  from huggingface_hub import InferenceClient
3
  import os
4
import openai  # OpenAI API client (added for ChatGPT-4o-mini support)

# Configure the OpenAI API key from the environment.
openai.api_key = os.getenv("OPENAI_API_KEY")
8
 
9
  MODELS = {
10
  "Zephyr 7B Beta": "HuggingFaceH4/zephyr-7b-beta",
 
15
  "Mixtral 8x7B": "mistralai/Mistral-7B-Instruct-v0.3",
16
  "Mixtral Nous-Hermes": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
17
  "Cohere Command R+": "CohereForAI/c4ai-command-r-plus",
18
+ "Cohere Aya-23-35B": "CohereForAI/aya-23-35B",
19
+ "ChatGPT-4o-mini": "gpt-4o-mini" # ChatGPT-4o-mini 모델 추가
20
  }
21
 
22
def get_client(model_name):
    """Return an InferenceClient for *model_name*, or None for the OpenAI model.

    Args:
        model_name: Display name of the model; must be a key of MODELS.

    Returns:
        An authenticated huggingface_hub.InferenceClient, or None for
        "ChatGPT-4o-mini" (that path calls the OpenAI API directly and
        needs no HF client).

    Raises:
        ValueError: If the model name is unknown or HF_TOKEN is unset.
            The caller catches ValueError, so unknown names must not
            surface as KeyError (the previous MODELS[model_name] lookup did).
    """
    if model_name == "ChatGPT-4o-mini":
        # OpenAI path: no Hugging Face client is needed.
        return None
    if model_name not in MODELS:
        # Fix: an unknown name used to raise KeyError, which respond()'s
        # `except ValueError` did not catch, crashing the app instead of
        # showing the error in the chat history.
        raise ValueError(f"Unknown model: {model_name}")
    model_id = MODELS[model_name]
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN environment variable is required")
    return InferenceClient(model_id, token=hf_token)
30
 
31
def call_openai_api(content, system_message, max_tokens, temperature, top_p, model="gpt-4o-mini"):
    """Send a single-turn chat request to the OpenAI API and return the reply text.

    Args:
        content: The user message to send.
        system_message: System prompt prepended to the conversation.
        max_tokens: Maximum number of tokens in the completion.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        model: OpenAI model identifier. Defaults to "gpt-4o-mini"
            (previously hard-coded inside the function).

    Returns:
        The assistant's reply content as a string.
    """
    # NOTE(review): openai.ChatCompletion is the legacy (<1.0) API surface,
    # removed in openai>=1.0 in favor of openai.OpenAI().chat.completions.create.
    # Confirm the pinned openai version before deploying.
    response = openai.ChatCompletion.create(
        model=model,
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": content},
        ],
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    return response.choices[0].message['content']
43
+
44
  def respond(
45
  message,
46
  chat_history,
 
51
  system_message,
52
  ):
53
  try:
54
+ if model_name == "ChatGPT-4o-mini":
55
+ assistant_message = call_openai_api(message, system_message, max_tokens, temperature, top_p)
56
+ chat_history.append((message, assistant_message))
57
+ yield chat_history
58
+ else:
59
+ client = get_client(model_name)
60
+ if client is None:
61
+ raise ValueError(f"No client available for model: {model_name}")
62
  except ValueError as e:
63
  chat_history.append((message, str(e)))
64
  return chat_history
 
71
 
72
  try:
73
  if "Cohere" in model_name:
 
74
  response = client.chat_completion(
75
  messages,
76
  max_tokens=max_tokens,
 
81
  chat_history.append((message, assistant_message))
82
  yield chat_history
83
  else:
 
84
  stream = client.chat_completion(
85
  messages,
86
  max_tokens=max_tokens,
 
140
  clear_button.click(clear_conversation, outputs=chatbot, queue=False)
141
 
142
  if __name__ == "__main__":
143
+ demo.launch()