import os
import sys
from pprint import pprint
from typing import Any, Dict

import gradio as gr
from gradio.themes.base import Base
from huggingface_hub import InferenceClient
from kaggle_secrets import UserSecretsClient
from pymongo import MongoClient

# Document loaders, embeddings, and vector stores live in langchain_community;
# the prompt/runnable primitives live in langchain_core.
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import MongoDBAtlasVectorSearch
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
directory_path = "/kaggle/input/rag-dataset/RAG"
sys.path.append(directory_path)
print("sys.path =", sys.path)

my_txts = os.listdir(directory_path)
print("my_txts =", my_txts)

# One TextLoader per .txt file in the dataset directory.
loaders = []
for my_txt in my_txts:
    my_txt_path = os.path.join(directory_path, my_txt)
    text_loader = TextLoader(my_txt_path)
    loaders.append(text_loader)

print("len(loaders) =", len(loaders))
print(loaders)

# Load every file; loader.load() returns a list of Documents.
data = []
for loader in loaders:
    loaded_text = loader.load()
    data.append(loaded_text)

print("len(data) =", len(data), "\n")
print(data[0])

# Split each file's documents into overlapping 500-character chunks.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)

docs = []
for doc in data:
    chunk = text_splitter.split_documents(doc)
    docs.append(chunk)

# Flatten the per-file chunk lists into a single list of Documents.
merged_documents = []
for doc in docs:
    merged_documents.extend(doc)

print("len(merged_documents) =", len(merged_documents))
print(merged_documents)

user_secrets = UserSecretsClient()

MONGO_URI = user_secrets.get_secret("MONGO_URI")
cluster = MongoClient(MONGO_URI)

DB_NAME = "files"
COLLECTION_NAME = "files_collection"

MONGODB_COLLECTION = cluster[DB_NAME][COLLECTION_NAME]
vector_search_index = "vector_index"

HF_TOKEN = user_secrets.get_secret("hugging_face")
embedding_model = HuggingFaceInferenceAPIEmbeddings(
    api_key=HF_TOKEN, model_name="sentence-transformers/all-mpnet-base-v2"
)
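
# The connection below assumes an Atlas Vector Search index named
# "vector_index" already exists on files.files_collection. A sketch of a
# matching index definition (all-mpnet-base-v2 produces 768-dimensional
# embeddings; "embedding" is LangChain's default vector field name):
#
# {
#   "fields": [
#     {"type": "vector", "path": "embedding", "numDimensions": 768,
#      "similarity": "cosine"}
#   ]
# }
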
vector_search = MongoDBAtlasVectorSearch.from_connection_string(
    connection_string=MONGO_URI,
    namespace=f"{DB_NAME}.{COLLECTION_NAME}",
    embedding=embedding_model,
    index_name=vector_search_index,
)
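
# Note: merged_documents is never written to Atlas in this section, so the
# searches below assume the collection was populated in an earlier run. If it
# is empty, one way to ingest the chunks would be the following sketch (kept
# commented out so it does not re-insert documents on every run):
#
# vector_search = MongoDBAtlasVectorSearch.from_documents(
#     documents=merged_documents,
#     embedding=embedding_model,
#     collection=MONGODB_COLLECTION,
#     index_name=vector_search_index,
# )
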
query = "why EfficientNetB0?" |
|
results = vector_search.similarity_search(query=query, k=25) |
|
|
|
print("\n") |
|
print(results) |
|
|
|
|
|
k = 10
score_threshold = 0.80

# score_threshold is only honored by the "similarity_score_threshold"
# search type; with plain "similarity" it is silently ignored.
retriever_1 = vector_search.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={"k": k, "score_threshold": score_threshold},
)

hf_client = InferenceClient(api_key=HF_TOKEN)

prompt = PromptTemplate.from_template(
    """Use the following pieces of context to answer the question at the end.

START OF CONTEXT:
{context}
END OF CONTEXT

START OF QUESTION:
{question}
END OF QUESTION

If you do not know the answer, just say that you do not know.
NEVER assume things.
"""
)

def format_docs(docs):
    # Concatenate the retrieved chunks into a single context string.
    return "\n\n".join(doc.page_content for doc in docs)

def generate_response(input_dict: Dict[str, Any]) -> str:
    # Fill the RAG prompt with the retrieved context and the question.
    formatted_prompt = prompt.format(**input_dict)

    # huggingface_hub's OpenAI-compatible chat completion API.
    response = hf_client.chat.completions.create(
        model="Qwen/Qwen2.5-1.5B-Instruct",
        messages=[
            {"role": "system", "content": formatted_prompt},
            {"role": "user", "content": input_dict["question"]},
        ],
        max_tokens=1000,
        temperature=0.2,
    )

    return response.choices[0].message.content

# The retriever feeds formatted context in parallel with the raw question;
# the resulting dict is piped into the LLM call.
rag_chain = (
    {
        "context": retriever_1 | RunnableLambda(format_docs),
        "question": RunnablePassthrough(),
    }
    | RunnableLambda(generate_response)
)

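# Since rag_chain is a Runnable, it also exposes .batch for answering several
# queries at once, e.g. (illustrative, not part of the original notebook):
# answers = rag_chain.batch(["what is scaling?", "why EfficientNetB0?"])
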
query = "what is scaling?" |
|
answer = rag_chain.invoke(query) |
|
|
|
print("\nQuestion:", query) |
|
print("Answer:", answer) |
|
|
|
|
|
documents = retriever_1.invoke(query) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
query = "How the GUI was implemented?" |
|
answer = rag_chain.invoke(query) |
|
|
|
print("\nQuestion:", query) |
|
print("Answer:", answer) |
|
|
|
|
|
documents = retriever_1.invoke(query) |
|
|
|
print("\nSource documents:") |
|
from pprint import pprint |
|
pprint(results) |
|
|
|
query = "How the GUI was implemented?" |
|
answer = rag_chain.invoke(query) |
|
|
|
print("\nQuestion:", query) |
|
print("Answer:", answer) |
|
|
|
|
|
documents = retriever_1.invoke(query) |
|
formatted_docs = format_docs(documents)
print("\nSource Documents:\n", formatted_docs)