# LLAMA / app.py
# ariankhalfani's picture
# Create app.py
# ee1461b verified
# (Hugging Face Hub page header residue — kept as comments so the file parses)
import os
import sqlite3
import requests
import PyPDF2
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
import gradio as gr
# Configure Hugging Face API
# Serverless Inference API endpoint for the hosted Meta-Llama-3-70B-Instruct model.
huggingface_api_url = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-70B-Instruct"
# Token comes from the environment; if HUGGINGFACE_API_KEY is unset this is None
# and the header below becomes "Bearer None" — requests will then be rejected.
huggingface_api_key = os.getenv("HUGGINGFACE_API_KEY")
headers = {"Authorization": f"Bearer {huggingface_api_key}"}
# Function to query Hugging Face model
def query_huggingface(payload):
    """POST a payload to the Hugging Face Inference API and return parsed JSON.

    Args:
        payload: JSON-serializable request body, e.g. {"inputs": "..."}.

    Returns:
        The decoded JSON response (list or dict on success), or a dict of the
        form {"error": "..."} when the request fails or the body is not JSON.
    """
    try:
        # timeout prevents the app from hanging forever on a stalled endpoint
        response = requests.post(
            huggingface_api_url, headers=headers, json=payload, timeout=60
        )
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as exc:
        return {"error": str(exc)}
    except ValueError as exc:  # non-JSON body
        return {"error": f"Invalid JSON response: {exc}"}
# Function to extract text from PDF
def extract_text_from_pdf(pdf_file):
    """Extract and concatenate the text of every page in a PDF.

    Args:
        pdf_file: a file path or binary file-like object readable by PyPDF2.

    Returns:
        All extractable page text joined into one string; pages with no
        extractable text (e.g. scanned images) are skipped.
    """
    pdf_reader = PyPDF2.PdfReader(pdf_file)
    pages_text = []
    for page in pdf_reader.pages:
        # extract_text() may return None or "" for image-only pages; the
        # original `text += page.extract_text()` raised TypeError on None.
        page_text = page.extract_text()
        if page_text:
            pages_text.append(page_text)
    return "".join(pages_text)
# Initialize SQLite database
def init_db():
    """Create the `context` table in storage_warehouse.db if it is missing."""
    connection = sqlite3.connect('storage_warehouse.db')
    try:
        connection.execute(
            '''
            CREATE TABLE IF NOT EXISTS context (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT,
                content TEXT
            )
            '''
        )
        connection.commit()
    finally:
        connection.close()
# Add context to the database
def add_context(name, content):
    """Insert one named context document into the database.

    Args:
        name: label for the document (typically the source file name).
        content: full extracted text of the document.
    """
    conn = sqlite3.connect('storage_warehouse.db')
    try:
        # Parameterized query keeps arbitrary file names/content safe.
        conn.execute('INSERT INTO context (name, content) VALUES (?, ?)', (name, content))
        conn.commit()
    finally:
        # Close even when the INSERT raises — the original leaked the
        # connection on any sqlite3 error.
        conn.close()
# Retrieve context from the database
def get_context():
    """Return the `content` field of every stored context row as a list."""
    connection = sqlite3.connect('storage_warehouse.db')
    try:
        rows = connection.execute('SELECT content FROM context').fetchall()
    finally:
        connection.close()
    return [content for (content,) in rows]
# Function to create or update the FAISS index
def update_faiss_index():
    """Rebuild the FAISS L2 index over every context stored in the database.

    Returns:
        (index, contexts): a faiss.IndexFlatL2 holding one embedding per
        stored context, and the parallel list of context strings.

    On an empty database this returns an empty index sized to the model's
    embedding dimension — the original crashed here on first launch because
    encoding an empty list yields a tensor with no usable shape[1].
    """
    contexts = get_context()
    dim = model.get_sentence_embedding_dimension()
    index = faiss.IndexFlatL2(dim)
    if contexts:
        embeddings = model.encode(contexts, convert_to_tensor=True)
        index.add(embeddings.cpu().numpy())
    return index, contexts
# Retrieve relevant context from the FAISS index
def retrieve_relevant_context(index, contexts, query, top_k=5):
    """Return up to ``top_k`` stored contexts most similar to ``query``.

    Args:
        index: FAISS index whose rows parallel ``contexts``.
        contexts: list of context strings, in index insertion order.
        query: user question to embed and search with.
        top_k: maximum number of contexts to return (default 5).

    Returns:
        List of matching context strings; empty when nothing is stored.
    """
    if not contexts:
        return []
    query_embedding = model.encode([query], convert_to_tensor=True).cpu().numpy()
    # Clamp k: asking FAISS for more neighbors than it holds pads the
    # result with index -1, which the original silently mapped to
    # contexts[-1] (the wrong document).
    k = min(top_k, len(contexts))
    distances, indices = index.search(query_embedding, k)
    return [contexts[i] for i in indices[0] if 0 <= i < len(contexts)]
# Initialize the database and FAISS model
# Module-level startup: ensure the SQLite schema exists, load the sentence
# embedding model, and build the initial FAISS index from any stored contexts.
init_db()
model = SentenceTransformer('all-MiniLM-L6-v2')  # sentence-embedding model used for all encodes
faiss_index, context_list = update_faiss_index()  # module globals read by chatbot_response
# Function to handle chatbot responses
def chatbot_response(question):
    """Answer ``question`` using retrieved context and the hosted LLM.

    Args:
        question: the user's question text.

    Returns:
        The model's generated text, or a fallback/error message string.
    """
    relevant_contexts = retrieve_relevant_context(faiss_index, context_list, question)
    user_input = f"question: {question} context: {' '.join(relevant_contexts)}"
    response = query_huggingface({"inputs": user_input})
    fallback = "Sorry, I couldn't generate a response."
    # The HF Inference API returns [{"generated_text": ...}] on success and
    # {"error": ...} on failure; the original called .get() on the list and
    # raised AttributeError on every successful call.
    if isinstance(response, list) and response:
        return response[0].get("generated_text", fallback)
    if isinstance(response, dict):
        if "error" in response:
            return f"Error: {response['error']}"
        return response.get("generated_text", fallback)
    return fallback
# Function to handle PDF uploads
def handle_pdf_upload(pdf_file):
    """Extract a PDF's text, store it, and rebuild the shared FAISS index.

    Args:
        pdf_file: a file path (str) or a file-like object with a ``.name``
            attribute (Gradio hands back either depending on configuration).

    Returns:
        Human-readable status message for the UI.
    """
    # Without `global`, the original rebuilt the index into locals and the
    # chatbot kept searching the stale module-level index.
    global faiss_index, context_list
    context = extract_text_from_pdf(pdf_file)
    # The UI passes file.name (a plain str), which has no .name attribute.
    name = pdf_file if isinstance(pdf_file, str) else pdf_file.name
    add_context(name, context)
    faiss_index, context_list = update_faiss_index()  # refresh shared index
    return f"Context from {name} added to the database."
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# Storage Warehouse Customer Service Chatbot")
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Box():
                pdf_upload = gr.File(label="Upload PDF", file_types=["pdf"], interactive=True)
                upload_button = gr.Button("Upload")
                upload_status = gr.Textbox(label="Upload Status")

                def handle_upload(files):
                    """Ingest the uploaded PDF(s) and report status.

                    Gradio event handlers communicate through return values;
                    the original assigned upload_status.value, which has no
                    effect on the rendered component.
                    """
                    if files is None:
                        return "No file uploaded."
                    # gr.File yields a single file object unless configured
                    # for multiple files — normalize to a list either way.
                    if not isinstance(files, list):
                        files = [files]
                    return "\n".join(handle_pdf_upload(f) for f in files)

                upload_button.click(fn=handle_upload, inputs=pdf_upload, outputs=upload_status)
        with gr.Column(scale=8):
            chatbot = gr.Chatbot(label="Chatbot")
            question = gr.Textbox(label="Your question here:")
            submit_button = gr.Button("Submit")

            def handle_chat(user_input, history):
                """Append the new (question, answer) pair to the chat history.

                The original replaced the whole history with a single pair on
                every turn, losing the conversation.
                """
                bot_response = chatbot_response(user_input)
                return (history or []) + [[user_input, bot_response]]

            # Pass the chatbot state in so history accumulates across turns.
            submit_button.click(fn=handle_chat, inputs=[question, chatbot], outputs=chatbot)

if __name__ == "__main__":
    demo.launch()