import os
import urllib.parse
import requests
from bs4 import BeautifulSoup
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging
import feedparser

# Set up logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Define device and model name
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"

# Load model and tokenizer
try:
    logger.debug("Attempting to load the model and tokenizer")
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    logger.debug("Model and tokenizer loaded successfully")
except Exception as e:
    logger.error(f"Error loading model and tokenizer: {e}")
    model = None
    tokenizer = None


# Function to fetch news from the Google News RSS feed
def fetch_news(term, num_results=2):
    logger.debug(f"Fetching news for term: {term}")
    encoded_term = urllib.parse.quote(term)
    url = f"https://news.google.com/rss/search?q={encoded_term}"
    feed = feedparser.parse(url)
    results = []
    for entry in feed.entries[:num_results]:
        results.append({"link": entry.link, "text": entry.title})
    logger.debug(f"Fetched news results: {results}")
    return results


# Function to perform a Google search and return the results
def search(term, num_results=2, lang="en", timeout=5, safe="active", ssl_verify=None):
    logger.debug(f"Starting search for term: {term}")
    escaped_term = urllib.parse.quote_plus(term)  # requests also URL-encodes params below
    start = 0
    all_results = []
    max_chars_per_page = 8000
    with requests.Session() as session:
        while start < num_results:
            try:
                resp = session.get(
                    url="https://www.google.com/search",
                    headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"},
                    params={
                        "q": term,
                        "num": num_results - start,
                        "hl": lang,
                        "start": start,
                        "safe": safe,
                    },
                    timeout=timeout,
                    verify=ssl_verify,
                )
                resp.raise_for_status()
                soup = BeautifulSoup(resp.text, "html.parser")
                result_block = soup.find_all("div", attrs={"class": "g"})
                if not result_block:
                    start += 1
                    continue
                for result in result_block:
                    link = result.find("a", href=True)
                    if link:
                        link = link["href"]
                        try:
                            webpage = session.get(
                                link,
                                headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"},
                            )
                            webpage.raise_for_status()
                            visible_text = extract_text_from_webpage(webpage.text)
                            if len(visible_text) > max_chars_per_page:
                                visible_text = visible_text[:max_chars_per_page] + "..."
all_results.append({"link": link, "text": visible_text}) except requests.exceptions.RequestException as e: logger.error(f"Error fetching or processing {link}: {e}") all_results.append({"link": link, "text": None}) else: all_results.append({"link": None, "text": None}) start += len(result_block) except Exception as e: logger.error(f"Error during search: {e}") break logger.debug(f"Search results: {all_results}") return all_results # Function to extract visible text from HTML content def extract_text_from_webpage(html_content): soup = BeautifulSoup(html_content, "html.parser") for tag in soup(["script", "style", "header", "footer", "nav"]): tag.extract() visible_text = soup.get_text(strip=True) return visible_text # Function to format the prompt for the language model def format_prompt(user_prompt, chat_history): logger.debug(f"Formatting prompt with user prompt: {user_prompt} and chat history: {chat_history}") prompt = "" for item in chat_history: prompt += f"User: {item[0]}\nAssistant: {item[1]}\n" prompt += f"User: {user_prompt}\nAssistant:" logger.debug(f"Formatted prompt: {prompt}") return prompt # Function for model inference def model_inference( user_prompt, chat_history, web_search, temperature, max_new_tokens, repetition_penalty, top_p, tokenizer # Pass tokenizer as an argument ): logger.debug(f"Starting model inference with user prompt: {user_prompt}, chat history: {chat_history}, web_search: {web_search}") if not isinstance(user_prompt, dict): logger.error("Invalid input format. Expected a dictionary.") return "Invalid input format. Expected a dictionary." if "files" not in user_prompt: user_prompt["files"] = [] if not user_prompt["files"]: if web_search: logger.debug("Performing news search") news_results = fetch_news(user_prompt["text"]) news2 = ' '.join([f"Link: {res['link']}\nText: {res['text']}\n\n" for res in news_results]) formatted_prompt = format_prompt(f"{user_prompt['text']} [NEWS] {news2}", chat_history) inputs = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE) if model: outputs = model.generate( **inputs, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, do_sample=True, temperature=temperature, top_p=top_p ) response = tokenizer.decode(outputs[0], skip_special_tokens=True) else: response = "Model is not available. Please try again later." logger.debug(f"Model response: {response}") return response else: formatted_prompt = format_prompt(user_prompt["text"], chat_history) inputs = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE) if model: outputs = model.generate( **inputs, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, do_sample=True, temperature=temperature, top_p=top_p ) response = tokenizer.decode(outputs[0], skip_special_tokens=True) else: response = "Model is not available. Please try again later." logger.debug(f"Model response: {response}") return response else: return "Image input not supported in this implementation." 
# Define Gradio interface components
# NOTE: the standalone components below (and the Chatbot) are not passed to the
# gr.Interface at the bottom of the file, which builds its own input components.
max_new_tokens = gr.Slider(
    minimum=1,
    maximum=16000,
    value=2048,
    step=64,
    interactive=True,
    label="Maximum number of new tokens to generate",
)
repetition_penalty = gr.Slider(
    minimum=0.01,
    maximum=5.0,
    value=1.0,
    step=0.01,
    interactive=True,
    label="Repetition penalty",
    info="1.0 is equivalent to no penalty",
)
decoding_strategy = gr.Radio(
    ["Greedy", "Top P Sampling"],
    value="Top P Sampling",
    label="Decoding strategy",
    interactive=True,
    info="Choose between greedy decoding and top-p (nucleus) sampling.",
)
temperature = gr.Slider(
    minimum=0.0,
    maximum=2.0,
    value=0.5,
    step=0.05,
    visible=True,
    interactive=True,
    label="Sampling temperature",
    info="Higher values will produce more diverse outputs.",
)
top_p = gr.Slider(
    minimum=0.01,
    maximum=0.99,
    value=0.9,
    step=0.01,
    visible=True,
    interactive=True,
    label="Top P",
    info="Higher values are equivalent to sampling more low-probability tokens.",
)

# Create a chatbot interface
chatbot = gr.Chatbot(
    label="OpenGPT-4o-Chatty",
    show_copy_button=True,
    likeable=True,
    layout="panel",
)


# Define the Gradio callback
def chat_interface(user_input, history, web_search, decoding_strategy, temperature, max_new_tokens, repetition_penalty, top_p):
    # Ensure the tokenizer is accessible within the function scope
    global tokenizer
    # decoding_strategy is accepted from the UI but not currently used by model_inference,
    # which always samples (do_sample=True).
    # Wrap the user input in a dictionary as expected by the model_inference function
    user_prompt = {"text": user_input, "files": []}
    # Perform model inference
    response = model_inference(
        user_prompt=user_prompt,
        chat_history=history,
        web_search=web_search,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
        tokenizer=tokenizer,  # Pass tokenizer to the model_inference function
    )
    # Update history with the user input and model response
    history.append((user_input, response))
    # Return the response and updated history
    return response, history


# Define the Gradio interface
interface = gr.Interface(
    fn=chat_interface,
    inputs=[
        gr.Textbox(label="User Input", placeholder="Type your message here..."),
        gr.State([]),  # Initialize the chat history as an empty list
        gr.Checkbox(label="Perform Web Search"),
        gr.Radio(["Greedy", "Top P Sampling"], label="Decoding strategy"),
        gr.Slider(minimum=0.0, maximum=2.0, step=0.05, label="Sampling temperature", value=0.5),
        gr.Slider(minimum=1, maximum=16000, step=64, label="Maximum number of new tokens to generate", value=2048),
        gr.Slider(minimum=0.01, maximum=5.0, step=0.01, label="Repetition penalty", value=1.0),
        gr.Slider(minimum=0.01, maximum=0.99, step=0.01, label="Top P", value=0.9),
    ],
    outputs=[
        gr.Textbox(label="Assistant Response"),
        gr.State([]),  # Update the chat history
    ],
    live=True,
)

# Launch the Gradio interface
interface.launch()
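
# Quick sanity check without the UI (a sketch; the prompt and values below are arbitrary
# examples). Run it in place of interface.launch() when debugging the generation pipeline
# directly:
#
#   history = []
#   reply = model_inference(
#       user_prompt={"text": "What's new in open-source LLMs?", "files": []},
#       chat_history=history,
#       web_search=True,
#       temperature=0.5,
#       max_new_tokens=256,
#       repetition_penalty=1.0,
#       top_p=0.9,
#       tokenizer=tokenizer,
#   )
#   print(reply)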