ziyingsk committed
Commit 2ecb22f · verified · 1 Parent(s): 4fdab16

Upload app.py

Files changed (1):
  app.py +138 -0
app.py ADDED
@@ -0,0 +1,138 @@
+ import os
+ import streamlit as st
+ from dotenv import load_dotenv
+ import itertools
+ from pinecone import Pinecone
+ from langchain_community.llms import HuggingFaceHub
+ from langchain.chains import LLMChain
+ from langchain_community.document_loaders import PyPDFDirectoryLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.prompts import PromptTemplate
+ from sentence_transformers import SentenceTransformer
+ import torch
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+ import logging
+
+ # Set up the environment; Pinecone is the vector database
+ load_dotenv()  # Load variables from the .env file
+ cache_dir = os.getenv("CACHE_DIR")  # Directory for the model cache
+ Huggingface_token = os.getenv("API_TOKEN")  # Hugging Face API token
+ pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))  # Pinecone API key
+ index = pc.Index(os.getenv("Index_Name"))  # Pinecone index name
+
+ # Initialize the embedding model (downloaded to cache_dir if one is set)
+ embedding_model = "all-mpnet-base-v2"  # See https://www.sbert.net/docs/pretrained_models.html
+
+ if cache_dir:
+     embedding = SentenceTransformer(embedding_model, cache_folder=cache_dir)
+ else:
+     embedding = SentenceTransformer(embedding_model)
+
+ # Read the PDF files, split them into chunks, and embed them
+ def read_doc(file_path):
+     file_loader = PyPDFDirectoryLoader(file_path)
+     documents = file_loader.load()
+     return documents
+
+ def chunk_data(docs, chunk_size=300, chunk_overlap=50):
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+     doc = text_splitter.split_documents(docs)
+     return doc
+
+ # Save embeddings to the database in batches
+ def chunks(iterable, batch_size=100):
+     """A helper function to break an iterable into chunks of size batch_size."""
+     it = iter(iterable)
+     chunk = tuple(itertools.islice(it, batch_size))
+     while chunk:
+         yield chunk
+         chunk = tuple(itertools.islice(it, batch_size))
+
+ # Streamlit interface: title and file upload
+ st.title("RAG-Anwendung (RAG Application)")
+ st.caption("Diese Anwendung kann Ihnen helfen, kostenlos Fragen zu PDF-Dateien zu stellen. (This application can help you ask questions about PDF files for free.)")
+
+ uploaded_file = st.file_uploader("Wählen Sie eine PDF-Datei, das Laden kann eine Weile dauern. (Choose a PDF file, loading might take a while.)", type="pdf")
+ if uploaded_file is not None:
+     # Ensure the temp directory exists and is empty
+     temp_dir = "tempDir"
+     if os.path.exists(temp_dir):
+         for file in os.listdir(temp_dir):
+             file_path = os.path.join(temp_dir, file)
+             if os.path.isfile(file_path):
+                 os.remove(file_path)
+             elif os.path.isdir(file_path):
+                 os.rmdir(file_path)  # Only removes empty directories
+
+     os.makedirs(temp_dir, exist_ok=True)
+
+     # Save the uploaded file temporarily, then load, chunk, and embed it
+     temp_file_path = os.path.join(temp_dir, uploaded_file.name)
+     with open(temp_file_path, "wb") as f:
+         f.write(uploaded_file.getbuffer())
+     doc = read_doc(temp_dir + "/")
+     documents = chunk_data(docs=doc)
+     texts = [document.page_content for document in documents]
+     pdf_vectors = embedding.encode(texts)
+     vector_count = len(documents)
+     example_data_generator = map(lambda i: (f'id-{i}', pdf_vectors[i].tolist(), {"text": texts[i]}), range(vector_count))
+     # Clear any previously indexed document before upserting the new vectors
+     if 'ns1' in index.describe_index_stats()['namespaces']:
+         index.delete(delete_all=True, namespace='ns1')
+     for ids_vectors_chunk in chunks(example_data_generator, batch_size=100):
+         index.upsert(vectors=ids_vectors_chunk, namespace='ns1')
+
+ # Retrieve the context related to the query
+ sample_query = st.text_input("Stellen Sie eine Frage zu dem PDF: (Ask a question related to the PDF:)")
+ if st.button("Abschicken (Submit)"):
+     if uploaded_file is not None and sample_query:
+         query_vector = embedding.encode(sample_query).tolist()
+         # Query the same namespace the vectors were upserted into
+         query_search = index.query(vector=query_vector, top_k=5, include_metadata=True, namespace='ns1')
+
+         matched_contents = [match["metadata"]["text"] for match in query_search["matches"]]
+
+         # Rerank the retrieved passages with a cross-encoder
+         rerank_model = "BAAI/bge-reranker-v2-m3"
+         if cache_dir:
+             tokenizer = AutoTokenizer.from_pretrained(rerank_model, cache_dir=cache_dir)
+             model = AutoModelForSequenceClassification.from_pretrained(rerank_model, cache_dir=cache_dir)
+         else:
+             tokenizer = AutoTokenizer.from_pretrained(rerank_model)
+             model = AutoModelForSequenceClassification.from_pretrained(rerank_model)
+         model.eval()
+
+         pairs = [[sample_query, content] for content in matched_contents]
+         with torch.no_grad():
+             inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=300)
+             scores = model(**inputs, return_dict=True).logits.view(-1).float()
+         matched_contents = [content for _, content in sorted(zip(scores.tolist(), matched_contents), key=lambda x: x[0], reverse=True)]
+         matched_contents = matched_contents[0]  # Keep only the top-ranked passage
+         del model
+         torch.cuda.empty_cache()
+
+         # Display matched contents after reranking
+         st.markdown("### Möglicherweise relevante Abschnitte aus dem PDF (Potentially relevant sections from the PDF):")
+         st.write(matched_contents)
+
+         # Get the answer from the LLM
+         query_model = "meta-llama/Meta-Llama-3-8B-Instruct"
+         llm_huggingface = HuggingFaceHub(repo_id=query_model, huggingfacehub_api_token=Huggingface_token, model_kwargs={"temperature": 0.7, "max_length": 500})
+
+         prompt_template = PromptTemplate(input_variables=['query', 'context'], template="{query}, Beim Beantworten der Frage bitte mit dem Wort 'Antwort:' beginnen, unter Berücksichtigung des folgenden Kontexts: \n\n{context}")
+
+         prompt = prompt_template.format(query=sample_query, context=matched_contents)
+         chain = LLMChain(llm=llm_huggingface, prompt=prompt_template)
+         result = chain.run(query=sample_query, context=matched_contents)
+
+         # Polish the answer: the endpoint may echo the prompt, so remove it and
+         # keep only the text after the 'Antwort:' marker requested in the prompt
+         result = result.replace(prompt, "")
+         special_start = "Antwort:"
+         start_index = result.find(special_start)
+         if start_index != -1:
+             result = result[start_index + len(special_start):].lstrip()
+         else:
+             result = result.lstrip()
+
+         # Display the final answer with a note about limitations
+         st.markdown("### Antwort (Answer):")
+         st.write(result)
+         st.markdown("**Hinweis:** Aufgrund begrenzter Rechenleistung kann das große Sprachmodell möglicherweise keine vollständige Antwort liefern. (Note: Due to limited computational power, the large language model might not be able to provide a complete response.)")