import torch
from langchain.chains import RetrievalQA
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEndpoint
# import pip
# def install(package):
#     if hasattr(pip, 'main'):
#         pip.main(['install', package])
#     else:
#         pip._internal.main(['install', package])
# # Temporary fix for incompatibility between langchain_huggingface and sentence-transformers<2.6
# install("sentence-transformers==2.2.2")
# Check for GPU availability and set the appropriate device for computation.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# DEVICE = "cpu"

# Global variables
conversation_retrieval_chain = None
chat_history = []
llm_hub = None
embeddings = None

# Function to initialize the language model and its embeddings
def init_llm():
    global llm_hub, embeddings

    # Set up the environment variable for HuggingFace and initialize the desired model.
    # tokenfile = open("api_token.txt")
    # api_token = tokenfile.readline().replace("\n", "")
    # tokenfile.close()
    # os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_token

    # Repo name for the model
    # model_id = "tiiuae/falcon-7b-instruct"
    model_id = "microsoft/Phi-3.5-mini-instruct"
    # model_id = "meta-llama/Llama-3.2-1B-Instruct"
    # model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"

    # Load the model through the Hugging Face Inference API endpoint.
    llm_hub = HuggingFaceEndpoint(
        repo_id=model_id,
        temperature=0.1,
        max_new_tokens=600,
        model_kwargs={"max_length": 600},
    )
    llm_hub.client.api_url = "https://api-inference.huggingface.co/models/" + model_id
    # llm_hub.invoke('foo bar')

    # Initialize embeddings using a pre-trained model to represent the text data.
    embeddings_model = "sentence-transformers/multi-qa-distilbert-cos-v1"
    # embeddings_model = "sentence-transformers/all-MiniLM-L6-v2"
    embeddings = HuggingFaceInstructEmbeddings(
        model_name=embeddings_model,
        model_kwargs={"device": DEVICE},
    )
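
    # Optional sanity check (a sketch, not part of the original flow): embed_query returns
    # a list of floats whose length equals the embedding model's output dimension.
    # vec = embeddings.embed_query("hello world")
    # print(len(vec))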


# Function to process a PDF document
def process_document(document_path):
    global conversation_retrieval_chain

    # Load the document
    loader = PyPDFLoader(document_path)
    documents = loader.load()

    # Split the document into chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    texts = text_splitter.split_documents(documents)

    # Create an embeddings database using Chroma from the split text chunks.
    db = Chroma.from_documents(texts, embedding=embeddings)

    # Build the QA chain, which uses the LLM and the retriever to answer questions.
    # By default, the vector store retriever uses similarity search.
    # If the underlying vector store supports maximum marginal relevance (MMR) search,
    # you can specify that as the search type (search_type="mmr").
    # You can also pass search kwargs such as k, which controls how many retrieved
    # chunks are sent to the LLM.
    retriever = db.as_retriever(search_type="mmr", search_kwargs={'k': 6, 'lambda_mult': 0.25})
    conversation_retrieval_chain = RetrievalQA.from_chain_type(
        llm=llm_hub,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=False,
        input_key="question"
        # chain_type_kwargs={"prompt": prompt}  # uncomment this if you are using a prompt template
    )
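
    # Sketch of the optional prompt template mentioned above (an assumption, not part of the
    # original configuration): the "stuff" chain expects "context" and "question" variables,
    # and the template text here is illustrative only.
    # from langchain.prompts import PromptTemplate
    # prompt = PromptTemplate(
    #     input_variables=["context", "question"],
    #     template="Use the following context to answer the question.\n\n"
    #              "{context}\n\nQuestion: {question}\nAnswer:",
    # )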


# Function to process a user prompt
def process_prompt(prompt, chat_history):
    global conversation_retrieval_chain
    # global chat_history

    # Query the model
    output = conversation_retrieval_chain.invoke({"question": prompt, "chat_history": chat_history})
    answer = output["result"]

    # Update the chat history
    chat_history.append((prompt, answer))

    # Return the model's response
    return answer


# Initialize the language model
init_llm()
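
# Example usage (a minimal sketch): "sample.pdf" is a placeholder path and this block is
# not part of the original module; it only runs when the file is executed directly.
if __name__ == "__main__":
    process_document("sample.pdf")  # hypothetical local PDF
    history = []
    print(process_prompt("What is this document about?", history))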