File size: 4,066 Bytes
114ce4a
 
 
 
 
 
 
127f3c4
2da0c1f
127f3c4
 
 
 
 
2da0c1f
127f3c4
fe4c7e0
2da0c1f
114ce4a
 
127f3c4
114ce4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe4c7e0
114ce4a
fe4c7e0
114ce4a
 
 
 
 
 
 
25c6d42
 
fdb4410
114ce4a
127f3c4
114ce4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe4c7e0
114ce4a
 
 
fe4c7e0
114ce4a
 
 
 
 
 
 
fe4c7e0
114ce4a
fe4c7e0
114ce4a
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch
from langchain.chains import RetrievalQA
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEndpoint
# import pip

# def install(package):
#     if hasattr(pip, 'main'):
#         pip.main(['install', package])
#     else:
#         pip._internal.main(['install', package])

# # Temporary fix for incompatibility between langchain_huggingface and sentence-transformers<2.6
# install("sentence-transformers==2.2.2")

# Compute device: the first CUDA GPU when torch reports one, otherwise the CPU.
# (To force CPU execution, assign DEVICE = "cpu" instead.)
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda:0"

# Module-level state shared across this module.
llm_hub = None  # language model handle; assigned by init_llm()
embeddings = None  # embedding model; assigned by init_llm()
conversation_retrieval_chain = None  # QA chain; assigned by process_document()
chat_history = []  # conversation history; starts empty

# Function to initialize the language model and its embeddings
def init_llm():
    """Create the language model client and the embedding model.

    Populates the module-level ``llm_hub`` (a ``HuggingFaceEndpoint`` for the
    ``microsoft/Phi-3.5-mini-instruct`` repository) and ``embeddings`` (a
    ``HuggingFaceInstructEmbeddings`` instance placed on ``DEVICE``).

    No API token is set here; presumably it is expected to be provided
    through the environment (e.g. HUGGINGFACEHUB_API_TOKEN) -- verify against
    the deployment setup.

    Returns:
        None. Results are published through the two globals.
    """
    global llm_hub, embeddings

    # Hugging Face repository of the chat model. Alternatives that have been
    # tried here: "tiiuae/falcon-7b-instruct",
    # "meta-llama/Llama-3.2-1B-Instruct",
    # "mistralai/Mixtral-8x7B-Instruct-v0.1".
    model_id = "microsoft/Phi-3.5-mini-instruct"
    inference_api_base = 'https://api-inference.huggingface.co/models/'

    llm_hub = HuggingFaceEndpoint(
        repo_id=model_id,
        temperature=0.1,
        max_new_tokens=600,
        model_kwargs={"max_length": 600},
    )
    # Point the underlying client explicitly at the hosted inference URL of
    # the selected model.
    llm_hub.client.api_url = inference_api_base + model_id

    # Pre-trained sentence-transformers model used to embed document text
    # (alternative: "sentence-transformers/all-MiniLM-L6-v2").
    embedding_model_name = "sentence-transformers/multi-qa-distilbert-cos-v1"
    embeddings = HuggingFaceInstructEmbeddings(
        model_name=embedding_model_name,
        model_kwargs={"device": DEVICE},
    )


# Function to process a PDF document
def process_document(document_path):
    """Index a PDF and (re)build the global retrieval QA chain over it.

    The file at ``document_path`` is loaded with ``PyPDFLoader``, split into
    overlapping chunks, embedded into a Chroma vector store using the
    module-level ``embeddings`` object, and wrapped in a ``RetrievalQA``
    chain driven by the module-level ``llm_hub``.

    Args:
        document_path: Path of the PDF file to load.

    Returns:
        None. The chain is stored in the module-level
        ``conversation_retrieval_chain``, replacing any previous value.
    """
    global conversation_retrieval_chain

    # Read the PDF into LangChain documents.
    loaded_docs = PyPDFLoader(document_path).load()

    # Break the documents into chunks (chunk_size=1024, chunk_overlap=64).
    splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    chunks = splitter.split_documents(loaded_docs)

    # Embed the chunks and store them in a Chroma vector database.
    vector_store = Chroma.from_documents(chunks, embedding=embeddings)

    # A vector-store retriever defaults to plain similarity search; here
    # maximum marginal relevance ("mmr") is requested instead. ``k`` is the
    # number of search results handed to the LLM.
    retriever = vector_store.as_retriever(
        search_type="mmr",
        search_kwargs={"k": 6, "lambda_mult": 0.25},
    )

    # "stuff" chain: answers questions with the LLM over the retrieved chunks.
    conversation_retrieval_chain = RetrievalQA.from_chain_type(
        llm=llm_hub,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=False,
        input_key="question",
        # chain_type_kwargs={"prompt": prompt},  # uncomment when using a prompt template
    )


# Function to process a user prompt
def process_prompt(prompt, chat_history):
    """Answer a user prompt with the retrieval chain and record the exchange.

    Args:
        prompt: The user's question; sent to the chain under the
            ``"question"`` key.
        chat_history: List of earlier exchanges. It is passed to the chain
            under ``"chat_history"`` and is mutated in place: the new
            ``(prompt, answer)`` tuple is appended. This parameter shadows
            the module-level ``chat_history``, which is not modified here.

    Returns:
        The value stored under ``"result"`` in the chain's output.
    """
    # Query the chain built by process_document().
    payload = {"question": prompt, "chat_history": chat_history}
    reply = conversation_retrieval_chain.invoke(payload)["result"]

    # Remember this turn in the caller's history list.
    chat_history.append((prompt, reply))
    return reply

# Initialize the language model and embeddings once, at import time, so the
# globals llm_hub and embeddings are populated before process_document() uses them.
init_llm()