Govind committed • 04f126d
Parent(s): 8bd3390
created app.py
app.py
ADDED
@@ -0,0 +1,116 @@
import os
import warnings

# Ensure no GPU is used by setting the environment variable
# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

# Disable ZeroGPU if running in Hugging Face's environment
# os.environ["HF_USE_ZeroGPU"] = "false"

# Suppress NVML initialization warning
warnings.filterwarnings("ignore", message="Can't initialize NVML")

import gradio as gr
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
from langchain.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from transformers import AutoConfig, AutoTokenizer, pipeline, AutoModelForCausalLM
import torch
import re
import transformers
from torch import bfloat16
from langchain_community.document_loaders import DirectoryLoader
import spaces

# Initialize embeddings and ChromaDB
model_name = "sentence-transformers/all-mpnet-base-v2"
device = "cuda" if torch.cuda.is_available() else "cpu"
model_kwargs = {"device": device}
embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)

# Load all PDFs under ./pdf, split them into overlapping chunks, and index them in a persistent Chroma store
# loader = DirectoryLoader('./pdf', glob="**/*.pdf", use_multithreading=True)
loader = DirectoryLoader('./pdf', glob="**/*.pdf", recursive=True, use_multithreading=True)
docs = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
all_splits = text_splitter.split_documents(docs)
vectordb = Chroma.from_documents(documents=all_splits, embedding=embeddings, persist_directory="pdf_db")
books_db = Chroma(persist_directory="./pdf_db", embedding_function=embeddings)

books_db_client = books_db.as_retriever()

# Initialize the model and tokenizer
model_name = "unsloth/Llama-3.2-3B-Instruct"

# bnb_config = transformers.BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type='nf4',
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_compute_dtype=torch.bfloat16
# )

model_config = transformers.AutoConfig.from_pretrained(model_name, max_new_tokens=1024)

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    config=model_config,
    # quantization_config=bnb_config,
    device_map="auto" if device == "cuda" else None,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)

query_pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=True,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    device_map="auto" if device == "cuda" else None,
    temperature=0.3,
    top_p=0.8,
    top_k=50,
    repetition_penalty=1.2,
    max_new_tokens=128
)

llm = HuggingFacePipeline(pipeline=query_pipeline)

books_db_client_retriever = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=books_db_client,
    verbose=True
)

# Function to retrieve answer using the RAG system
@spaces.GPU(duration=60)
def test_rag(query):
    rag_query = (
        f"You are an AI assistant with access to books knowledge. {query} "
        "Retrieve information only from the knowledge base provided. If you don't find "
        "relevant information in the knowledge base, do not respond with placeholder answers. "
        "Provide only clear and concise answers based on available knowledge."
    )
    books_retriever = books_db_client_retriever.run(rag_query)
    # return_full_text=True means the generation includes the prompt, so keep only the text after "Helpful Answer:"
    corrected_text_match = re.search(r"Helpful Answer:(.*)", books_retriever, re.DOTALL)
    if corrected_text_match:
        return corrected_text_match.group(1).strip()
    else:
        return "No helpful answer found."

# Gradio interface
def respond(message, history):
    response = test_rag(message)
    return response

iface = gr.ChatInterface(
    respond,
    chatbot=gr.Chatbot(height=700),
    textbox=gr.Textbox(placeholder="Ask me anything about the content of the PDF(s):", container=False, scale=7),
    title="RAG Chatbot",
    cache_examples=True,
)

iface.launch()
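Note: this commit adds only app.py. For the Space to build, the imported packages also have to be declared in a requirements.txt. A minimal sketch, with package names inferred from the imports above (the exact list and versions are assumptions, not part of this commit):

    gradio
    spaces
    torch
    transformers
    accelerate
    langchain
    langchain-community
    sentence-transformers
    chromadb
    unstructured[pdf]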