sahilnishad committed on
Commit 3bac9f4 · verified · 1 Parent(s): b8c8c0a

Create app.py

Files changed (1)
app.py +114 -0
app.py ADDED
import os
import torch
import streamlit as st
from streamlit_chat import message
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain.chains import RetrievalQA
from langchain.vectorstores import Chroma
from langchain.llms import HuggingFacePipeline
from langchain.embeddings.base import Embeddings
from langchain.document_loaders import PDFMinerLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from constants import CHROMA_SETTINGS

st.set_page_config(layout="centered")

checkpoint = "MBZUAI/LaMini-T5-738M"

# The checkpoint is public, so a token is optional; read it from the
# environment rather than referencing an undefined name.
tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_auth_token=os.environ.get("HF_TOKEN"))
model = AutoModelForSeq2SeqLM.from_pretrained(
    checkpoint,
    device_map="auto",
    torch_dtype=torch.float32
)

class T5EncoderEmbeddings(Embeddings):
    """Mean-pooled LaMini-T5 encoder states, shared by indexing and querying.

    Chroma expects an Embeddings object (embed_documents / embed_query),
    not a bare function, so the pooling logic lives here instead of being
    redefined inside each caller.
    """

    def _embed(self, text):
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(model.device)
        with torch.no_grad():
            hidden = model.encoder(**inputs).last_hidden_state
        return hidden.mean(dim=1).squeeze(0).cpu().tolist()

    def embed_documents(self, texts):
        return [self._embed(text) for text in texts]

    def embed_query(self, text):
        return self._embed(text)

@st.cache_resource
def data_ingestion(filepath):
    loader = PDFMinerLoader(filepath)
    documents = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    texts = text_splitter.split_documents(documents)

    db = Chroma.from_documents(
        texts,
        embedding=T5EncoderEmbeddings(),
        persist_directory="db",
        client_settings=CHROMA_SETTINGS
    )
    db.persist()
    db = None  # drop the handle; the persisted index is reopened in qa_llm()

@st.cache_resource
def llm_pipeline():
    pipe = pipeline(
        'text2text-generation',
        model=model,
        tokenizer=tokenizer,
        max_length=256,
        do_sample=True,
        temperature=0.3,
        top_p=0.95
    )
    local_llm = HuggingFacePipeline(pipeline=pipe)
    return local_llm

@st.cache_resource
def qa_llm():
    llm = llm_pipeline()
    db = Chroma(
        persist_directory="db",
        embedding_function=T5EncoderEmbeddings(),
        client_settings=CHROMA_SETTINGS
    )
    retriever = db.as_retriever()
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True
    )
    return qa

def process_answer(instruction):
    qa = qa_llm()
    generated_text = qa(instruction)
    answer = generated_text['result']
    return answer

def display_conversation(history):
    for i in range(len(history["generated"])):
        message(history["past"][i], is_user=True, key=str(i) + "_user")
        message(history["generated"][i], key=str(i))

def main():
    st.markdown("<h1 style='text-align: center;'>Chat with your PDF</h1>", unsafe_allow_html=True)
    st.markdown("<h2 style='text-align: center;'>Upload your PDF</h2>", unsafe_allow_html=True)
    uploaded_file = st.file_uploader("", type=["pdf"])

    if uploaded_file is not None:
        os.makedirs("docs", exist_ok=True)  # ensure the upload directory exists
        filepath = "docs/" + uploaded_file.name
        with open(filepath, "wb") as temp_file:
            temp_file.write(uploaded_file.read())

        with st.spinner("Creating embeddings..."):
            data_ingestion(filepath)
        st.success("Embeddings created successfully!")

        user_input = st.text_input("", key="input")

        if "generated" not in st.session_state:
            st.session_state["generated"] = ["I am ready to help you"]
        if "past" not in st.session_state:
            st.session_state["past"] = ["Hey there!"]

        if user_input:
            answer = process_answer({'query': user_input})
            st.session_state["past"].append(user_input)
            st.session_state["generated"].append(answer)

        display_conversation(st.session_state)

if __name__ == "__main__":
    main()
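
Note: app.py imports CHROMA_SETTINGS from a constants module that is not included in this commit. Below is a minimal sketch of what that module likely contains, assuming the duckdb+parquet backend that pre-0.4 chromadb uses for local persistence; the reconstruction is hypothetical, and only the import name comes from the commit.

# constants.py (hypothetical sketch; not part of this commit)
from chromadb.config import Settings

# A local, persisted Chroma store; "db" matches the persist_directory in app.py.
CHROMA_SETTINGS = Settings(
    chroma_db_impl="duckdb+parquet",
    persist_directory="db",
    anonymized_telemetry=False,
)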
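
The mean-pooled encoder embeddings are the most fragile piece of the pipeline, so a quick smoke test outside Streamlit can be worthwhile. The sketch below is not part of the commit: it assumes app.py is importable as a module (which triggers the model download on first run) and that the checkpoint is t5-large based, giving 1024-dimensional vectors.

# smoke_test.py (hypothetical; importing app loads the tokenizer and model)
from app import T5EncoderEmbeddings

emb = T5EncoderEmbeddings()
vec = emb.embed_query("What is this document about?")
docs = emb.embed_documents(["first chunk", "second chunk"])

print(len(vec))                  # expected 1024 for a t5-large-based checkpoint
print(len(docs), len(docs[0]))   # 2 vectors, same dimensionality as the query

The app itself is started with streamlit run app.py.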