import os
import urllib.parse
import requests
from bs4 import BeautifulSoup
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging
import feedparser
# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# Define device and load model and tokenizer
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"
# Load model and tokenizer
try:
logger.debug("Attempting to load the model and tokenizer")
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
logger.debug("Model and tokenizer loaded successfully")
except Exception as e:
logger.error(f"Error loading model and tokenizer: {e}")
model = None
tokenizer = None
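# Optional memory-saving variant (a sketch, not used above): on a CUDA device the model
# could be loaded in half precision to roughly halve memory use. torch_dtype is a standard
# from_pretrained parameter; whether float16 is appropriate depends on the hardware, so this
# is left as a commented alternative rather than the default.
# model = AutoModelForCausalLM.from_pretrained(
#     MODEL_NAME,
#     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
# ).to(DEVICE)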
# Function to fetch news from Google News RSS feed
def fetch_news(term, num_results=2):
logger.debug(f"Fetching news for term: {term}")
encoded_term = urllib.parse.quote(term)
url = f"https://news.google.com/rss/search?q={encoded_term}"
feed = feedparser.parse(url)
results = []
for entry in feed.entries[:num_results]:
results.append({"link": entry.link, "text": entry.title})
logger.debug(f"Fetched news results: {results}")
return results
# Function to perform a Google search and return the results (currently unused; fetch_news is used for web search)
def search(term, num_results=2, lang="en", timeout=5, safe="active", ssl_verify=None):
logger.debug(f"Starting search for term: {term}")
escaped_term = urllib.parse.quote_plus(term)
start = 0
all_results = []
max_chars_per_page = 8000
with requests.Session() as session:
while start < num_results:
try:
resp = session.get(
url="https://www.google.com/search",
headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"},
params={
"q": term,
"num": num_results - start,
"hl": lang,
"start": start,
"safe": safe,
},
timeout=timeout,
verify=ssl_verify,
)
resp.raise_for_status()
soup = BeautifulSoup(resp.text, "html.parser")
result_block = soup.find_all("div", attrs={"class": "g"})
if not result_block:
start += 1
continue
for result in result_block:
link = result.find("a", href=True)
if link:
link = link["href"]
try:
webpage = session.get(link, headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"})
webpage.raise_for_status()
visible_text = extract_text_from_webpage(webpage.text)
if len(visible_text) > max_chars_per_page:
visible_text = visible_text[:max_chars_per_page] + "..."
all_results.append({"link": link, "text": visible_text})
except requests.exceptions.RequestException as e:
logger.error(f"Error fetching or processing {link}: {e}")
all_results.append({"link": link, "text": None})
else:
all_results.append({"link": None, "text": None})
start += len(result_block)
except Exception as e:
logger.error(f"Error during search: {e}")
break
logger.debug(f"Search results: {all_results}")
return all_results
# Function to extract visible text from HTML content
def extract_text_from_webpage(html_content):
soup = BeautifulSoup(html_content, "html.parser")
    # Drop non-content tags before extracting the text
    for tag in soup(["script", "style", "header", "footer", "nav"]):
        tag.extract()
    # Use a space separator so text from adjacent elements is not run together
    visible_text = soup.get_text(separator=" ", strip=True)
return visible_text
# Function to format the prompt for the language model
def format_prompt(user_prompt, chat_history):
logger.debug(f"Formatting prompt with user prompt: {user_prompt} and chat history: {chat_history}")
prompt = ""
for item in chat_history:
prompt += f"User: {item[0]}\nAssistant: {item[1]}\n"
prompt += f"User: {user_prompt}\nAssistant:"
logger.debug(f"Formatted prompt: {prompt}")
return prompt
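# Alternative prompt formatting (a sketch, not wired in above): instruct-tuned checkpoints
# such as Phi-3-mini usually ship a chat template, so the same history could be rendered with
# tokenizer.apply_chat_template instead of the plain "User:/Assistant:" transcript. This helper
# assumes the loaded tokenizer exposes such a template and is provided for illustration only.
def format_prompt_with_chat_template(user_prompt, chat_history):
    messages = []
    for user_msg, assistant_msg in chat_history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": user_prompt})
    # add_generation_prompt appends the assistant-turn marker expected by the model
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)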
# Function for model inference
def model_inference(
user_prompt,
chat_history,
web_search,
temperature,
max_new_tokens,
repetition_penalty,
top_p,
tokenizer # Pass tokenizer as an argument
):
logger.debug(f"Starting model inference with user prompt: {user_prompt}, chat history: {chat_history}, web_search: {web_search}")
if not isinstance(user_prompt, dict):
logger.error("Invalid input format. Expected a dictionary.")
return "Invalid input format. Expected a dictionary."
if "files" not in user_prompt:
user_prompt["files"] = []
    if not user_prompt["files"]:
        # Build the prompt, optionally prepending recent news results as context
        if web_search:
            logger.debug("Performing news search")
            news_results = fetch_news(user_prompt["text"])
            news2 = ' '.join([f"Link: {res['link']}\nText: {res['text']}\n\n" for res in news_results])
            formatted_prompt = format_prompt(f"{user_prompt['text']} [NEWS] {news2}", chat_history)
        else:
            formatted_prompt = format_prompt(user_prompt["text"], chat_history)
        if model:
            inputs = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE)
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                temperature=temperature,
                top_p=top_p
            )
            # Decode only the newly generated tokens so the prompt is not echoed back
            response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
        else:
            response = "Model is not available. Please try again later."
        logger.debug(f"Model response: {response}")
        return response
    else:
        return "Image input not supported in this implementation."
# Define standalone Gradio components (note: the gr.Interface below declares its own inputs, so these are not wired in)
max_new_tokens = gr.Slider(
minimum=1,
maximum=16000,
value=2048,
step=64,
interactive=True,
label="Maximum number of new tokens to generate",
)
repetition_penalty = gr.Slider(
minimum=0.01,
maximum=5.0,
value=1,
step=0.01,
interactive=True,
label="Repetition penalty",
info="1.0 is equivalent to no penalty",
)
decoding_strategy = gr.Radio(
[
"Greedy",
"Top P Sampling",
],
value="Top P Sampling",
label="Decoding strategy",
interactive=True,
info="Higher values are equivalent to sampling more low-probability tokens.",
)
temperature = gr.Slider(
minimum=0.0,
maximum=2.0,
value=0.5,
step=0.05,
visible=True,
interactive=True,
label="Sampling temperature",
info="Higher values will produce more diverse outputs.",
)
top_p = gr.Slider(
minimum=0.01,
maximum=0.99,
value=0.9,
step=0.01,
visible=True,
interactive=True,
label="Top P",
info="Higher values are equivalent to sampling more low-probability tokens.",
)
# Create a chatbot interface
chatbot = gr.Chatbot(
label="OpenGPT-4o-Chatty",
show_copy_button=True,
likeable=True,
layout="panel"
)
# Chat handler wired into the Gradio interface defined below
def chat_interface(user_input, history, web_search, decoding_strategy, temperature, max_new_tokens, repetition_penalty, top_p):
    # Ensure the tokenizer is accessible within the function scope
    global tokenizer
    # Note: decoding_strategy is received from the UI but not forwarded;
    # model_inference currently always samples (do_sample=True)
    # Wrap the user input in a dictionary as expected by the model_inference function
    user_prompt = {"text": user_input, "files": []}
# Perform model inference
response = model_inference(
user_prompt=user_prompt,
chat_history=history,
web_search=web_search,
temperature=temperature,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
top_p=top_p,
tokenizer=tokenizer # Pass tokenizer to the model_inference function
)
# Update history with the user input and model response
history.append((user_input, response))
# Return the response and updated history
return response, history
# Define the Gradio interface components
interface = gr.Interface(
fn=chat_interface,
inputs=[
gr.Textbox(label="User Input", placeholder="Type your message here..."),
gr.State([]), # Initialize the chat history as an empty list
gr.Checkbox(label="Perform Web Search"),
gr.Radio(["Greedy", "Top P Sampling"], label="Decoding strategy"),
gr.Slider(minimum=0.0, maximum=2.0, step=0.05, label="Sampling temperature", value=0.5),
gr.Slider(minimum=1, maximum=16000, step=64, label="Maximum number of new tokens to generate", value=2048),
gr.Slider(minimum=0.01, maximum=5.0, step=0.01, label="Repetition penalty", value=1),
gr.Slider(minimum=0.01, maximum=0.99, step=0.01, label="Top P", value=0.9)
],
outputs=[
gr.Textbox(label="Assistant Response"),
gr.State([]) # Update the chat history
],
live=True
)
# Launch the Gradio interface
interface.launch()