import os
import openai
import gradio as gr
import pdfplumber
import boto3
from llama_index.core import Document, VectorStoreIndex, Settings, QueryBundle
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.core.postprocessor import MetadataReplacementPostProcessor
from llama_index.core.node_parser import SentenceWindowNodeParser
from dotenv import load_dotenv

load_dotenv("config.env")
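# config.env is expected to provide the API key read below, e.g.
# (placeholder value, not from the source):
#   OPENAI_API_KEY=sk-...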

# Read the OpenAI API key from the environment
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
openai.api_key = OPENAI_API_KEY

# AWS S3 setup
s3_bucket_name = "sagemaker-studio-gm4vm5dimae"
s3_client = boto3.client('s3')

# Directory to store downloaded PDFs
resume_path = 'resumes'
os.makedirs(resume_path, exist_ok=True)
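
# Note: boto3 resolves AWS credentials through its standard chain
# (environment variables, ~/.aws/credentials, or an attached IAM role);
# nothing AWS-related is read from config.env. The bucket name above is
# hard-coded and assumed to already exist.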

# Function to download PDFs from S3
def download_pdfs_from_s3(bucket_name, local_path):
    # list_objects_v2 returns at most 1,000 keys per call, enough for a small resume set
    objects = s3_client.list_objects_v2(Bucket=bucket_name)
    for obj in objects.get('Contents', []):
        file_name = obj['Key']
        if file_name.endswith('/'):
            continue  # skip folder placeholder keys
        local_file_path = os.path.join(local_path, file_name)
        # Create any subdirectories implied by the S3 key prefix
        os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
        s3_client.download_file(bucket_name, file_name, local_file_path)
        print(f"Downloaded {file_name} to {local_file_path}")

# Download PDFs
download_pdfs_from_s3(s3_bucket_name, resume_path)

# Function to load PDFs using pdfplumber
def load_pdfs_with_pdfplumber(directory):
    documents = []
    for filename in os.listdir(directory):
        if filename.lower().endswith(".pdf"):
            try:
                with pdfplumber.open(os.path.join(directory, filename)) as pdf:
                    text = ""
                    for page in pdf.pages:
                        # extract_text() can return None for image-only pages
                        text += page.extract_text() or ""
                documents.append(Document(text=text))
            except Exception as e:
                print(f"Error processing {filename}: {e}")
    return documents
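
# Optional tweak (not in the original source): attach the source filename as
# metadata, e.g. Document(text=text, metadata={"file_name": filename}), so
# answers can cite which resume a chunk came from.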

# Load documents from the resume directory using pdfplumber
documents = load_pdfs_with_pdfplumber(resume_path)
print(f"Number of documents: {len(documents)}")

# Set up the LLM (GPT-4o)
llm = OpenAI(model="gpt-4o", temperature=0.9)

# Set up the embedding model
embed_model = OpenAIEmbedding(model="text-embedding-3-large")

# Create a sentence-window node parser with default settings
sentence_node_parser = SentenceWindowNodeParser.from_defaults(
    window_size=3,
    window_metadata_key="window",
    original_text_metadata_key="original_text",
)

# Configure global settings
Settings.llm = llm
Settings.embed_model = embed_model
Settings.node_parser = sentence_node_parser

# Create index
index = VectorStoreIndex.from_documents(documents)
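
# from_documents() applies the configured node parser, so each resume is split
# into single-sentence nodes that carry a window of neighboring sentences
# (3 on each side) in their "window" metadata, then embedded with
# text-embedding-3-large.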

# Custom reranking function: ask the LLM to score each chunk's relevance
def custom_rerank(nodes, query):
    rerank_prompt = (
        "Given the following query and text chunks, rate each chunk's relevance "
        "to the query on a scale of 1-10, where 10 is most relevant.\n\n"
        f"Query: {query}\n\n"
    )
    for i, node in enumerate(nodes):
        rerank_prompt += f"Chunk {i+1}:\n{node.get_content()}\n\n"
    rerank_prompt += "Provide your ratings as a comma-separated list of numbers, e.g., '7,4,9,2,6'"
    response = llm.complete(rerank_prompt)
    try:
        ratings = [int(r.strip()) for r in response.text.split(',')]
        if len(ratings) != len(nodes):
            raise ValueError("Number of ratings does not match number of nodes")
        # Sort nodes by rating, highest first (stable sort keyed on the rating only)
        sorted_nodes = [node for _, node in sorted(zip(ratings, nodes), key=lambda x: x[0], reverse=True)]
        return sorted_nodes[:5]  # Return top 5 reranked nodes
    except Exception as e:
        print(f"Error in reranking: {e}; returning original order")
        return nodes[:5]
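
# Note: llama_index also ships a built-in LLM-as-judge reranker
# (llama_index.core.postprocessor.LLMRerank) that works along similar lines;
# the hand-rolled version above keeps the prompt visible and editable.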

# Create query engine (used for retrieval only; synthesis happens in chatbot() below)
query_engine = index.as_query_engine(
    similarity_top_k=20,
    node_postprocessors=[
        # Replace each sentence node's text with its surrounding window before synthesis
        MetadataReplacementPostProcessor(target_metadata_key="window")
    ],
)

# Chatbot function
def chatbot(message, history):
    # Flatten the Gradio (user, assistant) history into a text transcript
    history_text = "\n".join([f"Human: {h[0]}\nAI: {h[1]}" for h in history])
    full_query = f"Given the following chat history:\n{history_text}\n\nHuman: {message}\nAI:"
    # Retrieve nodes (retrieve() expects a QueryBundle)
    retrieved_nodes = query_engine.retrieve(QueryBundle(full_query))
    # Apply custom reranking
    reranked_nodes = custom_rerank(retrieved_nodes, full_query)
    # Synthesize an answer from the reranked nodes
    context = "\n".join([node.get_content() for node in reranked_nodes])
    response = llm.complete(
        f"Using the following context, answer the query:\n\nContext: {context}\n\nQuery: {full_query}"
    )
    return response.text

# Create Gradio interface
# (the retry/undo/clear button kwargs below require Gradio 4.x; they were removed in Gradio 5)
iface = gr.ChatInterface(
    chatbot,
    title="Resume Chatbot",
    description="Ask questions about resumes in the database.",
    theme="soft",
    examples=[
        "Out of all the resumes, tell me three that show experience in SQL.",
        "Give me key summary takeaways from the resumes with Project Management experience.",
        "Give me the names of 10 candidates who have more than two years of overall experience.",
    ],
    retry_btn=None,
    undo_btn="Delete Previous",
    clear_btn="Clear",
    chatbot=gr.Chatbot(height=400),
    textbox=gr.Textbox(scale=5),
)

# Launch the interface
iface.launch(share=True)
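
# A plausible requirements.txt for this Space (package names taken from the
# imports above; the Gradio pin is an assumption, not from the source):
#   gradio<5
#   openai
#   boto3
#   pdfplumber
#   python-dotenv
#   llama-index
#   llama-index-llms-openai
#   llama-index-embeddings-openai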