rasyosef committed on
Commit
24445e1
•
1 Parent(s): 82b68a6

Create app.py

Browse files
Files changed (1)
  1. app.py +146 -0
app.py ADDED
@@ -0,0 +1,146 @@
+ import gradio as gr
+
+ from langchain.text_splitter import CharacterTextSplitter
+ from langchain.document_loaders import UnstructuredFileLoader
+ from langchain.vectorstores.faiss import FAISS
+ from langchain.embeddings import HuggingFaceEmbeddings
+
+ from langchain.chains import RetrievalQA
+ from langchain.prompts.prompt import PromptTemplate
+ from langchain.vectorstores.base import VectorStoreRetriever
+ from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
+
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+ import torch
+
+ # Prompt template
+ template = """Instruction:
+ You are an AI assistant for answering questions about the provided context.
+ You are given the following extracted parts of a long document and a question. Provide a detailed answer.
+ If you don't know the answer, just say "Hmm, I'm not sure." Don't try to make up an answer.
+ =======
+ {context}
+ =======
+ Chat History:
+
+ {question}
+ Output:"""
+
+ QA_PROMPT = PromptTemplate(
+     template=template,
+     input_variables=["question", "context"]
+ )
+
+ # Returns a FAISS vector store given a txt file
+ def prepare_vector_store(filename):
+     # Load the raw text
+     loader = UnstructuredFileLoader(filename)
+     raw_documents = loader.load()
+     print(raw_documents[:1000])
+
+     # Split the text into overlapping chunks
+     text_splitter = CharacterTextSplitter(
+         separator="\n\n",
+         chunk_size=400,
+         chunk_overlap=100,
+         length_function=len
+     )
+
+     documents = text_splitter.split_documents(raw_documents)
+     print(documents[:3])
+
+     # Embed the chunks and index them in a FAISS vector store
+     embeddings = HuggingFaceEmbeddings()
+     vectorstore = FAISS.from_documents(documents, embeddings)
+     print(embeddings, vectorstore)
+
+     return vectorstore
+
+ # Load the Phi-2 model from the Hugging Face Hub
+ model_id = "microsoft/phi-2"
+
+ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True)
+ phi2 = pipeline("text-generation", tokenizer=tokenizer, model=model, max_new_tokens=128)  # runs on CPU; the model was loaded with device_map="cpu" above
+
+ phi2.tokenizer.pad_token_id = phi2.model.config.eos_token_id
+ hf_model = HuggingFacePipeline(pipeline=phi2)
+
+ # Retrieval QA chain
+ def get_retrieval_qa_chain(filename):
+     llm = hf_model
+     retriever = VectorStoreRetriever(
+         vectorstore=prepare_vector_store(filename)
+     )
+     model = RetrievalQA.from_chain_type(
+         llm=llm,
+         retriever=retriever,
+         chain_type_kwargs={"prompt": QA_PROMPT, "verbose": True},
+         verbose=True,
+     )
+     print(filename)
+     return model
+
+ # Question answering chain over the default document
+ qa_chain = get_retrieval_qa_chain(filename="Oppenheimer-movie-wiki.txt")
+
+ # Generates a response using the question answering chain defined earlier
+ def generate(question, chat_history):
+     # Fold the chat history and the new question into a single query string
+     query = ""
+     for req, res in chat_history:
+         query += f"User: {req}\n"
+         query += f"Assistant: {res}\n"
+     query += f"User: {question}"
+
+     result = qa_chain.invoke({"query": query})
+     response = result["result"].strip()
+     response = response.split("\n\n")[0].strip()
+
+     # Keep only the assistant's own turn, dropping any extra turns the model invents
+     if "User:" in response:
+         response = response.split("User:")[0].strip()
+     if "INPUT:" in response:
+         response = response.split("INPUT:")[0].strip()
+     if "Assistant:" in response:
+         response = response.split("Assistant:")[1].strip()
+
+     chat_history.append((question, response))
+
+     return "", chat_history
+
+ # Replaces the retriever in the question answering chain whenever a new file is uploaded
+ def upload_file(qa_chain):
+     def uploader(file):
+         print(file)
+         qa_chain.retriever = VectorStoreRetriever(
+             vectorstore=prepare_vector_store(file)
+         )
+         return file
+     return uploader
+
+ with gr.Blocks() as demo:
+     gr.Markdown("""
+     # RAG-Phi-2 Chatbot demo
+     ### This chatbot uses the Phi-2 language model and retrieval-augmented generation to let you add domain-specific knowledge by uploading a txt file.
+     """)
+
+     file_output = gr.File(label="txt file")
+     upload_button = gr.UploadButton(
+         label="Click to upload a txt file",
+         file_types=["text"],
+         file_count="single"
+     )
+     upload_button.upload(upload_file(qa_chain), upload_button, file_output)
+
+     gr.Markdown("""
+     ### Upload a txt file containing the text data you would like to augment the model with.
+     If you don't have one, a default document is already loaded: the Wikipedia page for the Oppenheimer movie.
+     """)
+
+     chatbot = gr.Chatbot(label="RAG Phi-2 Chatbot")
+     msg = gr.Textbox(label="Message", placeholder="Enter text here")
+
+     clear = gr.ClearButton([msg, chatbot])
+     msg.submit(fn=generate, inputs=[msg, chatbot], outputs=[msg, chatbot])
+
+ demo.launch()
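
For reference, a minimal sketch (not part of this commit) of how the same RetrievalQA chain can be queried without the Gradio UI. It assumes the objects defined in app.py above demo.launch() are available in the current session, for example in a notebook; "my-notes.txt" is a hypothetical file name.

# Minimal sketch, not part of the commit: query the RetrievalQA chain directly.
# Assumes qa_chain, prepare_vector_store and VectorStoreRetriever from app.py
# are already defined in the current session (everything above demo.launch()).

# A single-turn query uses the same "User: ..." format that generate() builds
result = qa_chain.invoke({"query": "User: Who directed the Oppenheimer movie?"})
print(result["result"].strip())

# A multi-turn query folds earlier turns into one string, exactly as generate()
# does with chat_history
query = (
    "User: Who directed the Oppenheimer movie?\n"
    "Assistant: Christopher Nolan.\n"
    "User: When was it released?"
)
print(qa_chain.invoke({"query": query})["result"].strip())

# Swapping the knowledge base mirrors what upload_file() does on upload;
# "my-notes.txt" is a hypothetical file name
qa_chain.retriever = VectorStoreRetriever(vectorstore=prepare_vector_store("my-notes.txt"))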