File size: 3,837 Bytes
d745fdc
ee1461b
 
 
5e8012a
ee1461b
 
 
 
 
d745fdc
4fe5e8e
ee1461b
 
 
 
d745fdc
 
ee1461b
 
 
 
 
5e8012a
 
 
 
ee1461b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e8012a
 
 
ee1461b
 
 
 
 
 
 
5e8012a
 
 
ee1461b
 
 
 
 
 
 
 
 
 
9874be5
d745fdc
ee1461b
 
d745fdc
9874be5
ee1461b
 
5e8012a
 
 
 
 
 
 
ee1461b
5e8012a
9874be5
 
d745fdc
173d79d
9874be5
 
173d79d
ee1461b
5e8012a
9874be5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from huggingface_hub import InferenceClient
import os
import sqlite3
import requests
import fitz  # PyMuPDF
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
import gradio as gr

# Configure Hugging Face API URL and headers
# Model served through the hosted HF Inference API. Requires HUGGINGFACE_API_KEY
# in the environment; if unset, the header becomes "Bearer None" and every
# request will fail with an auth error.
model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
huggingface_api_key = os.getenv("HUGGINGFACE_API_KEY")
headers = {"Authorization": f"Bearer {huggingface_api_key}"}

# Function to query Hugging Face model
def query_huggingface(payload):
    """POST *payload* to the HF Inference API for `model_name` and return the parsed JSON.

    The response may be a list of generation dicts on success or an error
    dict (e.g. {"error": ...}) on failure; callers must handle both shapes.
    """
    response = requests.post(
        f"https://api-inference.huggingface.co/models/{model_name}",
        headers=headers,
        json=payload,
        timeout=60,  # fix: without a timeout a stalled connection hangs the app forever
    )
    return response.json()

# Function to extract text from PDF
def extract_text_from_pdf(pdf_file):
    """Return the concatenated text of every page in an uploaded PDF.

    *pdf_file* is a readable binary file-like object (e.g. a Gradio upload).
    """
    pdf_document = fitz.open(stream=pdf_file.read(), filetype="pdf")
    try:
        # fix: the original never closed the document, leaking the handle on
        # every upload; iterate pages directly instead of indexing by number.
        text = "".join(page.get_text() for page in pdf_document)
    finally:
        pdf_document.close()
    return text

# Initialize SQLite database
def init_db():
    """Create the `context` table in storage_warehouse.db if it does not exist."""
    conn = sqlite3.connect('storage_warehouse.db')
    try:
        # fix: close the connection even if execute/commit raises (the
        # original leaked the connection on any error path)
        conn.execute('''
            CREATE TABLE IF NOT EXISTS context (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT,
                content TEXT
            )
        ''')
        conn.commit()
    finally:
        conn.close()

# Add context to the database
def add_context(name, content):
    """Insert one (name, content) context row into storage_warehouse.db."""
    conn = sqlite3.connect('storage_warehouse.db')
    try:
        # fix: guarantee the connection is closed even if the INSERT fails
        conn.execute('INSERT INTO context (name, content) VALUES (?, ?)', (name, content))
        conn.commit()
    finally:
        conn.close()

# Retrieve context from the database
def get_context():
    """Return all stored context strings, in insertion (rowid) order."""
    conn = sqlite3.connect('storage_warehouse.db')
    try:
        # fix: close the connection on every path, not just success
        rows = conn.execute('SELECT content FROM context').fetchall()
    finally:
        conn.close()
    return [row[0] for row in rows]

# Function to create or update the FAISS index
def update_faiss_index():
    """Rebuild the FAISS L2 index over all stored contexts.

    Returns (index, contexts); index is None when no context is stored yet.
    """
    contexts = get_context()
    if not contexts:
        return None, contexts

    # fix: encode straight to NumPy instead of the wasteful
    # torch-tensor -> .cpu() -> .numpy() round-trip; FAISS requires
    # contiguous float32, which we enforce explicitly.
    embeddings = np.ascontiguousarray(model.encode(contexts), dtype=np.float32)
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index, contexts

# Retrieve relevant context from the FAISS index
def retrieve_relevant_context(index, contexts, query, top_k=5):
    """Return up to *top_k* stored contexts nearest to *query* (L2 distance).

    Returns [] when the index is empty or not yet built.
    """
    if index is None or len(contexts) == 0:
        return []

    query_embedding = np.ascontiguousarray(model.encode([query]), dtype=np.float32)
    # fix: when top_k > number of stored contexts, FAISS pads `indices`
    # with -1 and the original `contexts[i]` silently returned the LAST
    # context for each pad. Cap k and drop any negative index.
    k = min(top_k, len(contexts))
    distances, indices = index.search(query_embedding, k)
    return [contexts[i] for i in indices[0] if i >= 0]

# Initialize the database and FAISS model
# Module-level side effects: creates/opens storage_warehouse.db, loads the
# sentence-transformers embedding model (downloads weights on first run),
# and builds the FAISS index from any previously stored context.
init_db()
model = SentenceTransformer('all-MiniLM-L6-v2')
faiss_index, context_list = update_faiss_index()

# Gradio interface for chatbot
def chatbot(question):
    """Answer *question* using the most relevant stored contexts as grounding."""
    matched = retrieve_relevant_context(faiss_index, context_list, question)
    prompt = f"question: {question} context: {' '.join(matched)}"
    result = query_huggingface({"inputs": prompt})

    fallback = "Sorry, I couldn't generate a response."
    # The inference API returns either a list of generation dicts or a
    # single dict (e.g. on error); handle both shapes.
    if isinstance(result, list):
        return result[0].get("generated_text", fallback)
    return result.get("generated_text", fallback)

# File upload function
def upload_pdf(file):
    """Ingest an uploaded PDF: persist its text as context and rebuild the index."""
    global faiss_index, context_list
    extracted = extract_text_from_pdf(file)
    add_context(file.name, extracted)
    faiss_index, context_list = update_faiss_index()
    return "PDF content added to context."

# Gradio interface
# Chat tab: free-text question in, model answer out.
iface = gr.Interface(
    fn=chatbot,
    inputs=gr.Textbox(),
    outputs=gr.Textbox(),
    title="Storage Warehouse Customer Service Chatbot"
)
# Upload tab: accepts a PDF whose extracted text is added to the retrieval context.
file_upload = gr.Interface(fn=upload_pdf, inputs=gr.File(), outputs=gr.Textbox(), title="Upload PDF for Context")

# Combine both tabs into one app; share=True exposes a public *.gradio.live URL.
app = gr.TabbedInterface([iface, file_upload], ["Chatbot", "Upload PDF"])
app.launch(share=True)