dingusagar committed on
Commit cd08073 · verified · 1 Parent(s): 9e394f5

Update app.py

Files changed (1)
  1. app.py +78 -2
app.py CHANGED
@@ -3,6 +3,15 @@ import subprocess
 import time
 from ollama import chat
 from ollama import ChatResponse
+from huggingface_hub import InferenceClient
+import os
+
+# HF client
+hf_api_key = os.getenv("HF_API_KEY", None)
+if not hf_api_key:
+    print("HF_API_KEY environment variable not set, will not be able to use the Inference API")
+hf_client = InferenceClient(api_key=hf_api_key)
+
 
 # Default model
 OLLAMA_MODEL = "llama3.2:3b"
@@ -105,6 +114,65 @@ Use second person terms like you in the explanation.
         yield response
 
 
+
+def ask_hf_inference_client(question, expected_class=""):
+    print("Getting response from HF Inference API")
+    classify_and_explain_prompt = f"""
+### You are an unbiased expert from the subreddit community r/AmItheAsshole. In this community people post their life situations and ask if they are the asshole or not.
+The community uses the following acronyms.
+AITA : Am I the asshole? Usually posted in the question.
+YTA : You are the asshole in this situation.
+NTA : You are not the asshole in this situation.
+
+### The task for you is to label YTA or NTA for the given text. Give a short explanation for the label. Be brutally honest and unbiased. Base your explanation entirely on the given text only.
+
+If the label is YTA, also explain what the user could have done better.
+### The output format is as follows:
+"YTA" or "NTA", a short explanation.
+
+### Situation : {question}
+### Response :"""
+
+    explain_only_prompt = f"""
+### You know about the subreddit community r/AmItheAsshole. In this community people post their life situations and ask if they are the asshole or not.
+The community uses the following acronyms.
+AITA : Am I the asshole? Usually posted in the question.
+YTA : You are the asshole in this situation.
+NTA : You are not the asshole in this situation.
+
+### The task for you is to explain why a particular situation was tagged as NTA or YTA by most users. I will give the situation as well as the NTA or YTA tag. Just give your explanation for the label. Be nice but give a brutally honest and unbiased view. Base your explanation entirely on the given text and the label tag only. Do not assume anything extra.
+Use second person terms like you in the explanation.
+
+### Situation : {question}
+### Label Tag : {expected_class}
+### Explanation for {expected_class} :"""
+
+    if expected_class == "":
+        prompt = classify_and_explain_prompt
+    else:
+        prompt = explain_only_prompt
+
+    print(f"Prompt to HF Inference API : {prompt}")
+
+    messages = [
+        {
+            "role": "user",
+            "content": prompt
+        }
+    ]
+
+    stream = hf_client.chat.completions.create(
+        model="meta-llama/Llama-3.2-3B-Instruct",
+        messages=messages,
+        max_tokens=500,
+        stream=True
+    )
+
+    for chunk in stream:
+        yield chunk.choices[0].delta.content
+
+
+
 # Separate function for Ollama response
 def gradio_ollama_interface(prompt, bert_class=""):
     return ask_ollama(prompt, expected_class=bert_class)
@@ -123,14 +191,22 @@ def gradio_interface(prompt, selected_model):
         yield initial_response
         for chunk in ask_ollama(prompt, expected_class=bert_label_map[label]):
             yield initial_response + "\n" + chunk
+    elif selected_model == MODEL_CHOICE_BERT_LLAMA_HF_INFERENCE:
+        label, confidence = ask_bert(prompt)
+        initial_response = f"Response from BERT model: {bert_label_map_formatted[label]} with confidence {confidence}%\n\nGenerating explanation using Llama model...\n"
+        yield initial_response
+        for chunk in ask_hf_inference_client(prompt, expected_class=bert_label_map[label]):
+            initial_response += chunk
+            yield initial_response
     else:
         return "Something went wrong. Select the correct model configuration from settings."
 
 MODEL_CHOICE_BERT_LLAMA = "Fine-tuned BERT (classification) + Llama 3.2 3B (explanation)"
+MODEL_CHOICE_BERT_LLAMA_HF_INFERENCE = "Fine-tuned BERT (classification) + Llama 3.2 3B Inference API (fast explanation)"
 MODEL_CHOICE_BERT = "Fine-tuned BERT (classification only)"
 MODEL_CHOICE_LLAMA = "Llama 3.2 3B (classification + explanation)"
 
-MODEL_OPTIONS = [MODEL_CHOICE_BERT_LLAMA, MODEL_CHOICE_LLAMA, MODEL_CHOICE_BERT]
+MODEL_OPTIONS = [MODEL_CHOICE_BERT_LLAMA_HF_INFERENCE, MODEL_CHOICE_BERT_LLAMA, MODEL_CHOICE_LLAMA, MODEL_CHOICE_BERT]
 
 # Example texts
 EXAMPLES = [
@@ -186,5 +262,5 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue=gr.themes.colors.green, secon
 
 # Launch the app
 if __name__ == "__main__":
-    start_ollama_server()
+    # start_ollama_server()
     demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
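
A minimal standalone sketch of the Inference API streaming call this commit adds, handy for checking the HF_API_KEY setup outside the Gradio app. The model name and parameters are taken from the diff above; the example prompt and the None-guard on delta.content are illustrative assumptions, not part of the commit.

import os
from huggingface_hub import InferenceClient

# Same client construction as in app.py; requests fail if the key is missing or invalid.
client = InferenceClient(api_key=os.getenv("HF_API_KEY"))

stream = client.chat.completions.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    messages=[{"role": "user", "content": "AITA for leaving a party early without telling the host?"}],
    max_tokens=100,
    stream=True,
)

text = ""
for chunk in stream:
    # delta.content can be None on some stream events, so guard before concatenating.
    text += chunk.choices[0].delta.content or ""
print(text)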