Govind commited on
Commit
04f126d
1 Parent(s): 8bd3390

created app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -0
app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import warnings
3
+
4
+ # Ensure no GPU is used by setting the environment variable
5
+ # os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
6
+
7
+ # Disable ZeroGPU if running in Hugging Face's environment
8
+ # os.environ["HF_USE_ZeroGPU"] = "false"
9
+
10
+ # Suppress NVML initialization warning
11
+ warnings.filterwarnings("ignore", message="Can't initialize NVML")
12
+
13
+ import gradio as gr
14
+ from langchain.embeddings import HuggingFaceEmbeddings
15
+ from langchain.vectorstores import Chroma
16
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
17
+ from langchain.docstore.document import Document
18
+ from langchain.llms import HuggingFacePipeline
19
+ from langchain.chains import RetrievalQA
20
+ from transformers import AutoConfig, AutoTokenizer, pipeline, AutoModelForCausalLM
21
+ import torch
22
+ import re
23
+ import transformers
24
+ from torch import bfloat16
25
+ from langchain_community.document_loaders import DirectoryLoader
26
+ import spaces
27
+
28
+ # Initialize embeddings and ChromaDB
29
+ model_name = "sentence-transformers/all-mpnet-base-v2"
30
+ device = "cuda" if torch.cuda.is_available() else "cpu"
31
+ model_kwargs = {"device": device}
32
+ embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
33
+
34
+ # loader = DirectoryLoader('./pdf', glob="**/*.pdf", use_multithreading=True)
35
+ loader = DirectoryLoader('./pdf', glob="**/*.pdf", recursive=True, use_multithreading=True)
36
+ docs = loader.load()
37
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
38
+ all_splits = text_splitter.split_documents(docs)
39
+ vectordb = Chroma.from_documents(documents=all_splits, embedding=embeddings, persist_directory="pdf_db")
40
+ books_db = Chroma(persist_directory="./pdf_db", embedding_function=embeddings)
41
+
42
+ books_db_client = books_db.as_retriever()
43
+
44
+ # Initialize the model and tokenizer
45
+ model_name = "unsloth/Llama-3.2-3B-Instruct"
46
+
47
+ # bnb_config = transformers.BitsAndBytesConfig(
48
+ # load_in_4bit=True,
49
+ # bnb_4bit_quant_type='nf4',
50
+ # bnb_4bit_use_double_quant=True,
51
+ # bnb_4bit_compute_dtype=torch.bfloat16
52
+ # )
53
+
54
+ model_config = transformers.AutoConfig.from_pretrained(model_name, max_new_tokens=1024)
55
+
56
+
57
+ model = transformers.AutoModelForCausalLM.from_pretrained(
58
+ model_name,
59
+ trust_remote_code=True,
60
+ config=model_config,
61
+ # quantization_config=bnb_config,
62
+ device_map="auto" if device == "cuda" else None,
63
+ )
64
+
65
+
66
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
67
+
68
+ query_pipeline = transformers.pipeline(
69
+ "text-generation",
70
+ model=model,
71
+ tokenizer=tokenizer,
72
+ return_full_text=True,
73
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
74
+ device_map="auto" if device == "cuda" else None,
75
+ temperature=0.3,
76
+ top_p=0.8,
77
+ top_k=50,
78
+ repetition_penalty=1.2,
79
+ max_new_tokens=128
80
+ )
81
+
82
+
83
+ llm = HuggingFacePipeline(pipeline=query_pipeline)
84
+
85
+ books_db_client_retriever = RetrievalQA.from_chain_type(
86
+ llm=llm,
87
+ chain_type="stuff",
88
+ retriever=books_db_client,
89
+ verbose=True
90
+ )
91
+
92
+ # Function to retrieve answer using the RAG system
93
+ @spaces.GPU(duration=60)
94
+ def test_rag(query):
95
+ rag_query = f"You are an AI assistant with access to books knowledge.{query} Retrieve information only from the knowledge base provided. If you don't find relevant information in the knowledge base, do not respond with placeholder answers. Provide only clear and concise answers based on available knowledge."
96
+ books_retriever = books_db_client_retriever.run(rag_query)
97
+ corrected_text_match = re.search(r"Helpful Answer:(.*)", books_retriever, re.DOTALL)
98
+ if corrected_text_match:
99
+ return corrected_text_match.group(1).strip()
100
+ else:
101
+ return "No helpful answer found."
102
+
103
+ # Gradio interface
104
+ def respond(message, history):
105
+ response = test_rag(message)
106
+ return response
107
+
108
+ iface = gr.ChatInterface(
109
+ respond,
110
+ chatbot=gr.Chatbot(height=700),
111
+ textbox=gr.Textbox(placeholder="Ask me anything about the content of the PDF(s):", container=False, scale=7),
112
+ title="RAG Chatbot",
113
+ cache_examples=True,
114
+ )
115
+
116
+ iface.launch()