# Hugging Face Spaces status: Runtime error
import hashlib
import os
import re
import warnings

# Ensure no GPU is used by setting the environment variable
# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
# Disable ZeroGPU if running in Hugging Face's environment
# os.environ["HF_USE_ZeroGPU"] = "false"
# Suppress NVML initialization warning
warnings.filterwarnings("ignore", message="Can't initialize NVML")

import gradio as gr
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
from langchain.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from transformers import AutoConfig, AutoTokenizer, pipeline, AutoModelForCausalLM
import torch
import transformers
from torch import bfloat16
from langchain_community.document_loaders import DirectoryLoader
import spaces
# Initialize embeddings and ChromaDB
PDF_DIR = "./pdf"  # source PDFs, searched recursively
PERSIST_DIR = "./pdf_db"  # on-disk Chroma index

model_name = "sentence-transformers/all-mpnet-base-v2"
device = "cuda" if torch.cuda.is_available() else "cpu"
model_kwargs = {"device": device}
embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)


def _chunk_ids(chunks):
    """Return one deterministic, batch-unique ID per chunk.

    Each ID is the SHA-256 of the chunk's ``source`` metadata plus its text,
    so the same chunk gets the same ID on every start-up. Identical chunks
    from the same source get a ``-<n>`` suffix so IDs stay unique within a
    single write.
    """
    ids = []
    occurrences = {}
    for chunk in chunks:
        key = f"{chunk.metadata.get('source', '')}\0{chunk.page_content}"
        digest = hashlib.sha256(key.encode("utf-8")).hexdigest()
        count = occurrences.get(digest, 0)
        occurrences[digest] = count + 1
        ids.append(digest if count == 0 else f"{digest}-{count}")
    return ids


loader = DirectoryLoader(PDF_DIR, glob="**/*.pdf", recursive=True, use_multithreading=True)
docs = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
all_splits = text_splitter.split_documents(docs)

if all_splits:
    # The index persists across restarts. Without explicit IDs LangChain
    # assigns random ones, so every start-up would store another copy of each
    # chunk. Stable IDs let Chroma overwrite instead. (Copies written by
    # earlier runs under random IDs are not removed.)
    vectordb = Chroma.from_documents(
        documents=all_splits,
        embedding=embeddings,
        ids=_chunk_ids(all_splits),
        persist_directory=PERSIST_DIR,
    )
else:
    # Nothing new to index: do not hand Chroma an empty batch. Fall back to
    # whatever an earlier run already persisted.
    warnings.warn(
        f"No PDF text found under {PDF_DIR!r}; answering only from what is "
        f"already stored in {PERSIST_DIR!r}."
    )
    vectordb = Chroma(persist_directory=PERSIST_DIR, embedding_function=embeddings)

# Same store as vectordb, kept under its original name; no second client is
# opened on the directory.
books_db = vectordb
books_db_client = books_db.as_retriever()
# Initialize the model and tokenizer
model_name = "unsloth/Llama-3.2-3B-Instruct"

# Half precision on GPU, full precision on CPU. The dtype must be given to
# from_pretrained: transformers.pipeline() ignores torch_dtype when it is
# handed an already-loaded model, which would leave the weights in float32.
model_dtype = torch.float16 if device == "cuda" else torch.float32

# bnb_config = transformers.BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type='nf4',
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_compute_dtype=torch.bfloat16
# )

# Output length is set by the pipeline's max_new_tokens, so no generation
# settings are stored on the model config.
model_config = transformers.AutoConfig.from_pretrained(model_name)

# trust_remote_code is deliberately not enabled. Llama is a built-in
# transformers architecture, so enabling it would only allow executing
# Python code shipped in the Hub repository.
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    config=model_config,
    torch_dtype=model_dtype,
    # quantization_config=bnb_config,
    device_map="auto" if device == "cuda" else None,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Text-generation pipeline around the already-loaded model. Dtype and device
# placement come from the model object itself. A torch_dtype/device_map given
# here would only be used if the pipeline loaded the model from a name.
query_pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    # test_rag() looks for the answer after the prompt's "Helpful Answer:"
    # marker, so the prompt must be returned together with the generation.
    return_full_text=True,
    # temperature/top_p/top_k only take effect when sampling is on. Enable it
    # explicitly rather than relying on the checkpoint's generation config.
    do_sample=True,
    temperature=0.3,
    top_p=0.8,
    top_k=50,
    repetition_penalty=1.2,
    max_new_tokens=128,
)
# Expose the transformers pipeline to LangChain, then build a RetrievalQA
# chain over the PDF retriever. The "stuff" chain type places all retrieved
# chunks into a single prompt. verbose=True turns on LangChain's chain logging.
llm = HuggingFacePipeline(pipeline=query_pipeline)
_retrieval_qa_settings = {
    "llm": llm,
    "chain_type": "stuff",
    "retriever": books_db_client,
    "verbose": True,
}
books_db_client_retriever = RetrievalQA.from_chain_type(**_retrieval_qa_settings)
# Function to retrieve answer using the RAG system
# The chain output is expected to contain the prompt followed by the model's
# generation, because the pipeline is built with return_full_text=True. The
# answer is whatever follows "Helpful Answer:". That marker is assumed to come
# from the RetrievalQA prompt template (LangChain's default) — confirm if a
# custom prompt is ever configured.
_HELPFUL_ANSWER_PATTERN = re.compile(r"Helpful Answer:(.*)", re.DOTALL)
_NO_ANSWER_MESSAGE = "No helpful answer found."


def test_rag(query):
    """Answer ``query`` from the PDF knowledge base via the RetrievalQA chain.

    The question is wrapped in fixed instructions before being sent to
    ``books_db_client_retriever``.

    NOTE(review): RetrievalQA presumably retrieves with this whole wrapped
    string, so the boilerplate instructions also influence which chunks are
    found. Consider moving them into the chain's prompt and passing only the
    bare question.

    Args:
        query: The user's question.

    Returns:
        The stripped text after "Helpful Answer:" in the chain output, or
        "No helpful answer found." if the marker is missing or nothing
        follows it.
    """
    rag_query = (
        "You are an AI assistant with access to a knowledge base containing "
        f"books and other materials. {query} Summarize the retrieved answer for "
        "this query. Retrieve information only from the provided knowledge base. "
        "If no relevant information is found in the knowledge base, do not "
        "provide placeholder answers. Respond only with clear and concise "
        "answers based on the available knowledge."
    )
    # invoke() replaces the deprecated Chain.run(); RetrievalQA reads its
    # input from "query" and writes its answer to "result".
    output = books_db_client_retriever.invoke({"query": rag_query})["result"]
    match = _HELPFUL_ANSWER_PATTERN.search(output)
    if match is None:
        return _NO_ANSWER_MESSAGE
    answer = match.group(1).strip()
    # An empty generation would otherwise be shown as a blank chat reply.
    return answer or _NO_ANSWER_MESSAGE
# Gradio interface
def respond(message, history):
    """Chat callback for ``gr.ChatInterface``: answer ``message`` with the RAG chain.

    ``history`` (the earlier turns passed in by the chat UI) is not used, so
    every question is answered on its own.
    """
    return test_rag(message)
# Build the chat widgets first, then wire them into a ChatInterface that
# routes every submitted message through respond(), and start the server.
chat_window = gr.Chatbot(height=700)
question_box = gr.Textbox(
    placeholder="Ask me anything about the content of the PDF(s):",
    container=False,
    scale=7,
)
iface = gr.ChatInterface(
    respond,
    chatbot=chat_window,
    textbox=question_box,
    title="RAG Chatbot",
    # NOTE(review): no examples are passed, so caching likely has nothing to do.
    cache_examples=True,
)
iface.launch()