# SearchGPT / app.py
import fitz # PyMuPDF
import gradio as gr
import requests
from bs4 import BeautifulSoup
import urllib.parse
import random
import os
from dotenv import load_dotenv
import shutil
import tempfile
load_dotenv()  # Load environment variables from a .env file

# Read the API token from the environment instead of hard-coding it
HUGGINGFACE_API_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
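
# A small guard (not in the original): warn loudly at startup if the token is
# missing, rather than letting the first API call fail with a 401 later.
if not HUGGINGFACE_API_TOKEN:
    print("Warning: HUGGINGFACE_TOKEN is not set; Hugging Face API calls will fail.")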

_useragent_list = [
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Edge/91.0.864.59 Safari/537.36",
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Edge/91.0.864.59 Safari/537.36",
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Safari/537.36",
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Safari/537.36",
]

# Function to extract visible text from the HTML content of a webpage
def extract_text_from_webpage(html):
    print("Extracting text from webpage...")
    soup = BeautifulSoup(html, 'html.parser')
    for script in soup(["script", "style"]):
        script.extract()  # Remove scripts and styles
    text = soup.get_text()
    # Collapse the text: strip each line, split multi-phrase lines on double
    # spaces, and drop blanks
    lines = (line.strip() for line in text.splitlines())
    chunks = (phrase.strip() for line in lines for phrase in line.split("  "))
    text = '\n'.join(chunk for chunk in chunks if chunk)
    print(f"Extracted text length: {len(text)}")
    return text
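
# Illustrative example (hypothetical HTML) of what the function above produces:
#   html = "<html><body>\n<script>track()</script>\n<h1>Title</h1>\n<p>Body text.</p>\n</body></html>"
#   extract_text_from_webpage(html)  # -> "Title\nBody text."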

# Function to perform a Google search and retrieve results
def google_search(term, num_results=5, lang="en", timeout=5, safe="active", ssl_verify=None):
    """Performs a Google search and returns the results."""
    print(f"Searching for term: {term}")
    escaped_term = urllib.parse.quote_plus(term)
    start = 0
    all_results = []
    max_chars_per_page = 8000  # Limit characters taken from each page to stay under the token limit

    with requests.Session() as session:
        while start < num_results:
            print(f"Fetching search results starting from: {start}")
            try:
                # Choose a random user agent
                user_agent = random.choice(_useragent_list)
                headers = {
                    'User-Agent': user_agent
                }
                print(f"Using User-Agent: {headers['User-Agent']}")
                resp = session.get(
                    url="https://www.google.com/search",
                    headers=headers,
                    params={
                        "q": term,
                        "num": num_results - start,
                        "hl": lang,
                        "start": start,
                        "safe": safe,
                    },
                    timeout=timeout,
                    verify=ssl_verify,
                )
                resp.raise_for_status()
            except requests.exceptions.RequestException as e:
                print(f"Error fetching search results: {e}")
                break

            soup = BeautifulSoup(resp.text, "html.parser")
            result_block = soup.find_all("div", attrs={"class": "g"})
            if not result_block:
                print("No more results found.")
                break

            for result in result_block:
                link = result.find("a", href=True)
                if link:
                    link = link["href"]
                    print(f"Found link: {link}")
                    try:
                        webpage = session.get(link, headers=headers, timeout=timeout)
                        webpage.raise_for_status()
                        visible_text = extract_text_from_webpage(webpage.text)
                        if len(visible_text) > max_chars_per_page:
                            visible_text = visible_text[:max_chars_per_page] + "..."
                        all_results.append({"link": link, "text": visible_text})
                    except requests.exceptions.RequestException as e:
                        print(f"Error fetching or processing {link}: {e}")
                        all_results.append({"link": link, "text": None})
                else:
                    print("No link found in result.")
                    all_results.append({"link": None, "text": None})
            start += len(result_block)

    print(f"Total results fetched: {len(all_results)}")
    return all_results
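
# Each entry in the list returned by google_search has the shape
#   {"link": "https://...", "text": "visible page text..."}
# with None values when a result had no link or its page could not be fetched.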

# Function to format the prompt for the Hugging Face API
def format_prompt(query, search_results, instructions):
    formatted_results = ""
    for result in search_results:
        link = result["link"]
        text = result["text"]
        if link:
            formatted_results += f"URL: {link}\nContent: {text}\n{'-' * 80}\n"
        else:
            formatted_results += "No link found.\n" + '-' * 80 + '\n'
    prompt = f"{instructions}User Query: {query}\n\nWeb Search Results:\n{formatted_results}\n\nAssistant:"
    return prompt
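
# Note: format_prompt is not currently called anywhere; process_input uses
# format_prompt_with_instructions instead. The prompt it builds has the layout:
#   <instructions>User Query: <query>
#
#   Web Search Results:
#   URL: <link>
#   Content: <page text>
#   <80 dashes>
#
#   Assistant: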

# Function to generate text using the Hugging Face Inference API
def generate_text(input_text, temperature=0.7, repetition_penalty=1.0, top_p=0.9):
    print("Generating text using Hugging Face API...")
    endpoint = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"
    headers = {
        "Authorization": f"Bearer {HUGGINGFACE_API_TOKEN}",  # Token loaded from the environment
        "Content-Type": "application/json"
    }
    data = {
        "inputs": input_text,
        "parameters": {
            "max_new_tokens": 8000,  # Adjust as needed; the API may cap this lower
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
            "top_p": top_p
        }
    }
    try:
        response = requests.post(endpoint, headers=headers, json=data)
        response.raise_for_status()

        # Check whether the response body is JSON
        try:
            json_data = response.json()
        except ValueError:
            print("Response is not JSON.")
            return None

        # Extract the generated text from the response JSON
        if isinstance(json_data, list):
            # The Inference API typically returns a list of generations
            generated_text = json_data[0].get("generated_text") if json_data else None
        elif isinstance(json_data, dict):
            # Handle a dictionary response
            generated_text = json_data.get("generated_text")
        else:
            print("Unexpected response format.")
            return None

        if generated_text is not None:
            print("Text generation complete using Hugging Face API.")
            print(f"Generated text: {generated_text}")  # Debugging line
            return generated_text
        else:
            print("Generated text not found in response.")
            return None
    except requests.exceptions.RequestException as e:
        print(f"Error generating text using Hugging Face API: {e}")
        return None
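
# The serverless Inference API can return 503 while the model is loading, which
# generate_text reports as a request error. A minimal retry wrapper (a sketch,
# not part of the original app; the retry count and delay are arbitrary choices):
def generate_text_with_retry(input_text, retries=3, delay=10, **kwargs):
    import time  # Local import keeps the sketch self-contained
    for attempt in range(1, retries + 1):
        result = generate_text(input_text, **kwargs)
        if result is not None:
            return result
        if attempt < retries:
            print(f"Generation attempt {attempt} failed; retrying in {delay}s...")
            time.sleep(delay)
    return None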

# Function to read and extract text from a PDF
def read_pdf(file_obj):
    with fitz.open(file_obj.name) as document:
        text = ""
        for page_num in range(document.page_count):
            page = document.load_page(page_num)
            text += page.get_text()
    return text
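
# Gradio's File component passes an object whose .name attribute holds the path
# of the uploaded temp file, which is why read_pdf opens file_obj.name;
# fitz.open() also accepts a plain path string.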

# Function to format the prompt with instructions for text generation
def format_prompt_with_instructions(text, instructions):
    prompt = f"{instructions}{text}\n\nAssistant:"
    return prompt

# Function to save text to a PDF
def save_text_to_pdf(text, output_path):
    print(f"Saving text to PDF at {output_path}...")
    doc = fitz.open()  # Create a new PDF document
    page = doc.new_page()  # Create a new page

    # Set the page margins
    margin = 50  # 50 points
    page_width = page.rect.width
    page_height = page.rect.height
    text_width = page_width - 2 * margin
    text_height = page_height - 2 * margin

    # Define font size and line spacing
    font_size = 9
    line_spacing = 1 * font_size
    fontname = "times-roman"  # Use a PyMuPDF base-14 font name

    # Wrap the text into lines that fit within text_width
    lines = []
    current_line = ""
    current_line_width = 0
    words = text.split(" ")
    for word in words:
        word_width = fitz.get_text_length(word, fontname, font_size)
        if current_line_width + word_width <= text_width:
            current_line += word + " "
            current_line_width += word_width + fitz.get_text_length(" ", fontname, font_size)
        else:
            lines.append(current_line.strip())
            current_line = word + " "
            current_line_width = word_width + fitz.get_text_length(" ", fontname, font_size)
    if current_line:
        lines.append(current_line.strip())

    # Add the lines to the page, breaking to a new page at the bottom margin
    x = margin
    y = margin
    for line in lines:
        # y starts at the top margin, so the page is full once y passes
        # margin + text_height (the original compared against text_height alone,
        # which broke to a new page one margin-height too early)
        if y + line_spacing > margin + text_height:
            page = doc.new_page()
            y = margin  # Reset the y-coordinate for the new page
        page.insert_text((x, y), line, fontname=fontname, fontsize=font_size)
        y += line_spacing

    doc.save(output_path)  # Save the PDF to the specified output path
    print(f"Text saved to PDF at {output_path}")

# Function to process a PDF or a search query and generate a summary
def process_input(query_or_file, is_pdf, instructions, temperature, top_p, repetition_penalty):
    if is_pdf:
        print(f"Processing PDF: {query_or_file.name}")
        input_text = read_pdf(query_or_file)
    else:
        print(f"Processing search query: {query_or_file}")
        search_results = google_search(query_or_file)
        input_text = "\n\n".join(result["text"] for result in search_results if result["text"])

    # Split the input text into smaller chunks to stay within the token limit
    chunk_size = 1024  # In characters; adjust as needed
    text_chunks = [input_text[i:i + chunk_size] for i in range(0, len(input_text), chunk_size)]
    print(f"Total number of chunks: {len(text_chunks)}")

    # Generate a summary for each chunk and concatenate them, skipping chunks
    # whose generation failed (generate_text returns None on error)
    concatenated_summary = ""
    for chunk in text_chunks:
        prompt = format_prompt_with_instructions(chunk, instructions)
        chunk_summary = generate_text(prompt, temperature, repetition_penalty, top_p)
        if chunk_summary:
            concatenated_summary += f"{chunk_summary}\n\n"
    print("Final concatenated summary generated.")
    return concatenated_summary
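
# Note on chunk_size: as a rough rule of thumb, 1024 characters of English text
# is on the order of 250 tokens, so each chunk plus instructions should fit
# comfortably inside the model's context window (an estimate, not a guarantee).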

# Function to clear cached files
def clear_cache():
    try:
        # Clear the Gradio temp-file cache
        cache_dir = tempfile.gettempdir()
        shutil.rmtree(os.path.join(cache_dir, "gradio"), ignore_errors=True)
        # Clear any custom cache, e.g. the generated summary PDF
        if os.path.exists("output_summary.pdf"):
            os.remove("output_summary.pdf")
        # Add any other cache-clearing operations here
        print("Cache cleared successfully.")
        return "Cache cleared successfully."
    except Exception as e:
        print(f"Error clearing cache: {e}")
        return f"Error clearing cache: {e}"

# Function to build the two-tab Gradio interface
def summarization_interface():
    with gr.Blocks() as demo:
        gr.Markdown("# PDF and Web Summarization Tool")

        with gr.Tab("Summarize PDF"):
            pdf_file = gr.File(label="Upload PDF", file_types=[".pdf"])
            pdf_instructions = gr.Textbox(label="Instructions for Summarization", placeholder="Enter instructions for summarization", lines=3)
            pdf_temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.7, step=0.01)
            pdf_top_p = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.9, step=0.01)
            pdf_repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=0.5, maximum=2.0, value=1.0, step=0.1)
            pdf_summary_output = gr.Textbox(label="Concatenated Summary Output")
            pdf_summarize_button = gr.Button("Generate Summary")
            pdf_clear_cache_button = gr.Button("Clear Cache")

        with gr.Tab("Summarize Web Search"):
            search_query = gr.Textbox(label="Enter Search Query", placeholder="Enter search query")
            search_instructions = gr.Textbox(label="Instructions for Summarization", placeholder="Enter instructions for summarization", lines=3)
            search_temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.7, step=0.01)
            search_top_p = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.9, step=0.01)
            search_repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=0.5, maximum=2.0, value=1.0, step=0.1)
            search_summary_output = gr.Textbox(label="Concatenated Summary Output")
            search_summarize_button = gr.Button("Generate Summary")
            search_clear_cache_button = gr.Button("Clear Cache")

        # Bind functions to button clicks
        pdf_summarize_button.click(
            fn=lambda file, instructions, temperature, top_p, repetition_penalty: generate_and_save_summary(file, True, instructions, temperature, top_p, repetition_penalty),
            inputs=[pdf_file, pdf_instructions, pdf_temperature, pdf_top_p, pdf_repetition_penalty],
            outputs=[pdf_summary_output]
        )
        search_summarize_button.click(
            fn=lambda query, instructions, temperature, top_p, repetition_penalty: generate_and_save_summary(query, False, instructions, temperature, top_p, repetition_penalty),
            inputs=[search_query, search_instructions, search_temperature, search_top_p, search_repetition_penalty],
            outputs=[search_summary_output]
        )
        pdf_clear_cache_button.click(fn=clear_cache, inputs=None, outputs=pdf_summary_output)
        search_clear_cache_button.click(fn=clear_cache, inputs=None, outputs=search_summary_output)

    return demo
# Launch the Gradio interface
demo = summarization_interface()
demo.launch()