dingusagar committed
Create app.py
app.py
ADDED
@@ -0,0 +1,177 @@
import gradio as gr
import subprocess
import time
from ollama import chat
from ollama import ChatResponse

# Default model
OLLAMA_MODEL = "llama3.2:3b"

# Load the BERT model
from transformers import pipeline, DistilBertTokenizerFast

# Path to your locally saved model
# bert_model_path = "fine_tuned_aita_classifier"
bert_model_path = "dingusagar/distillbert-aita-classifier"

tokenizer = DistilBertTokenizerFast.from_pretrained(bert_model_path)
classifier = pipeline(
    "text-classification",
    model=bert_model_path,  # Path to your locally saved model
    tokenizer=tokenizer,    # Use the tokenizer saved with the model
    truncation=True
)

bert_label_map = {
    'LABEL_0': 'YTA',
    'LABEL_1': 'NTA',
}

def ask_bert(prompt):
    print("Getting response from fine-tuned BERT")
    result = classifier([prompt])[0]
    label = bert_label_map.get(result['label'])
    confidence = f"{result['score']*100:.2f}"
    return label, confidence
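
# For reference: a transformers text-classification pipeline returns a list of
# dicts such as [{'label': 'LABEL_1', 'score': 0.97}], so ask_bert maps that to
# ('NTA', '97.00'). The LABEL_0/LABEL_1 mapping above is assumed to match this
# particular fine-tuned checkpoint.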

def start_ollama_server():
    # Start Ollama server in the background
    print("Starting Ollama server...")
    subprocess.Popen(["ollama", "serve"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    time.sleep(5)  # Give some time for the server to start

    # Pull the required model
    print(f"Pulling the model: {OLLAMA_MODEL}")
    subprocess.run(["ollama", "pull", OLLAMA_MODEL], check=True)

    print("Starting the required model...")
    subprocess.Popen(["ollama", "run", OLLAMA_MODEL], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    print("Ollama model started.")
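
# The fixed time.sleep(5) above can race against server startup. A minimal
# sketch of a more robust wait, polling Ollama's default local endpoint
# (port 11434) until it answers:
#
#   import urllib.request
#   for _ in range(30):
#       try:
#           urllib.request.urlopen("http://127.0.0.1:11434")
#           break
#       except OSError:
#           time.sleep(1)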

def ask_ollama(question, expected_class=""):
    print("Getting response from Ollama")
    classify_and_explain_prompt = f"""
### You are an unbiased expert from the subreddit community r/AmItheAsshole. In this community people post their life situations and ask if they are the asshole or not.
The community uses the following acronyms.
AITA : Am I the asshole? Usually posted in the question.
YTA : You are the asshole in this situation.
NTA : You are not the asshole in this situation.

### Your task is to predict whether most users would tag the given situation as YTA or NTA, and give your personal opinion. Do not try to be nice; give a brutally honest and unbiased view. Base your decision entirely on the given text.
Use second-person terms like "you" in the explanation.

### The output format is as follows:
"YTA" or "NTA", followed by a short explanation.


### Situation : {question}
### Output :{expected_class}"""

    explain_only_prompt = f"""
### You know about the subreddit community r/AmItheAsshole. In this community people post their life situations and ask if they are the asshole or not.
The community uses the following acronyms.
AITA : Am I the asshole? Usually posted in the question.
YTA : You are the asshole in this situation.
NTA : You are not the asshole in this situation.

### Your task is to explain why a particular situation was tagged as NTA or YTA by most users. I will give you the situation as well as the NTA or YTA tag; just give your explanation for the label. Do not try to be nice; give a brutally honest and unbiased view. Base your explanation entirely on the given text and the label tag.
Use second-person terms like "you" in the explanation.

### The output format is just the explanation for the label.


### Situation : {question}
### Label Tag : {expected_class}
### Explanation for {expected_class} :"""

    if expected_class == "":
        prompt = classify_and_explain_prompt
    else:
        prompt = explain_only_prompt

    print(f"Prompt to llama : {prompt}")
    response: ChatResponse = chat(model=OLLAMA_MODEL, messages=[
        {
            'role': 'user',
            'content': prompt,
        },
    ])
    print(response['message']['content'])
    return response['message']['content']
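
# Two calling modes: ask_ollama("AITA for ...?") runs the full
# classify-and-explain prompt, while ask_ollama("AITA for ...?", "NTA") only
# asks the model to justify the supplied label.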

def gradio_interface(prompt, selected_model):
    if selected_model == MODEL_CHOICE_LLAMA:
        # No expected class: let Llama both classify and explain
        response = ask_ollama(prompt)
    elif selected_model == MODEL_CHOICE_BERT:
        response, confidence = ask_bert(prompt)
        response = f"{response} with {confidence}% confidence"
    elif selected_model == MODEL_CHOICE_BERT_LLAMA:
        bert_response, confidence = ask_bert(prompt)
        ollama_response = ask_ollama(prompt, expected_class=bert_response)
        response = f"{bert_response} with {confidence}% confidence.\n{ollama_response}"
    else:
        response = "Something went wrong. Select the correct model configuration from settings."
    return response

MODEL_CHOICE_BERT_LLAMA = "Fine-tuned BERT (classification) + Llama 3.2 3B (explanation)"
MODEL_CHOICE_BERT = "Fine-tuned BERT (classification only)"
MODEL_CHOICE_LLAMA = "Llama 3.2 3B (classification + explanation)"

MODEL_OPTIONS = [MODEL_CHOICE_BERT_LLAMA, MODEL_CHOICE_LLAMA, MODEL_CHOICE_BERT]

# Example texts
EXAMPLES = [
    "I refused to invite my coworker to my birthday party even though we’re part of the same friend group. AITA?",
    "I didn't attend my best friend's wedding because I couldn't afford the trip. Now they are mad at me. AITA?",
    "I told my coworker they were being unprofessional during a meeting in front of everyone. AITA?",
    "I told my kid that she should become an engineer like me, even though she is into painting and wants to pursue art. AITA?"
]

# Build the Gradio app
# with gr.Blocks(theme="JohnSmith9982/small_and_pretty") as demo:
with gr.Blocks(theme=gr.themes.Default(primary_hue=gr.themes.colors.green, secondary_hue=gr.themes.colors.purple)) as demo:
    gr.Markdown("# AITA Classifier")
    gr.Markdown(
        """### Ask this AI app if you are wrong in a situation. Describe the conflict you experienced, give both sides of the story, and find out whether you are right (NTA) or the a**hole (YTA). Inspired by the subreddit [r/AmItheAsshole](https://www.reddit.com/r/AmItheAsshole/), this app tries to provide honest and unbiased assessments of users' life situations.
<sub>**Disclaimer:** The responses generated by this AI model are based on training data derived from subreddit posts and do not represent the views or opinions of the creators or authors. This was our fun little project; please don't take the generated responses too seriously :)</sub>
""")

    with gr.Row():
        model_selector = gr.Dropdown(
            label="Selected Model",
            choices=MODEL_OPTIONS,
            value=MODEL_CHOICE_BERT_LLAMA
        )

    with gr.Row():
        input_prompt = gr.Textbox(label="Enter your situation here", placeholder="Am I the a**hole for...", lines=5)

    with gr.Row():
        # Add example texts
        example = gr.Examples(
            examples=EXAMPLES,
            inputs=input_prompt,
            label="Want to quickly try some example situations?",
        )

    with gr.Row():
        submit_button = gr.Button("Check A**hole or not!", variant="primary")

    with gr.Row():
        output_response = gr.Textbox(label="Response", lines=10, placeholder="Result will be YTA (you are the a**hole) or NTA (not the a**hole)")

    # Link the button click to the interface function
    submit_button.click(gradio_interface, inputs=[input_prompt, model_selector], outputs=output_response)

# Launch the app
if __name__ == "__main__":
    start_ollama_server()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
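
# A minimal sketch of calling the app programmatically once it is running; the
# endpoint name below is an assumption (check client.view_api() for the actual
# name exposed by this Gradio version):
#
#   from gradio_client import Client
#
#   client = Client("http://localhost:7860")
#   result = client.predict(
#       "AITA for skipping my friend's wedding?",      # input_prompt
#       "Fine-tuned BERT (classification only)",       # model_selector
#       api_name="/gradio_interface",
#   )
#   print(result)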