# RAG-RUBIK / app.py
# Last commit by Pavan2k4: 521c000 "Update app.py" (verified)
# (Hugging Face file view: raw | history | blame — 3.76 kB)
import os
import gradio as gr
from langchain.chains import RetrievalQA
from langchain_pinecone import Pinecone
from langchain_openai import ChatOpenAI
from langchain_community.llms import HuggingFacePipeline
from langchain_community.embeddings import HuggingFaceEmbeddings
#from dotenv import load_dotenv
import torch
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, pipeline, AutoTokenizer
from huggingface_hub import login
# Authenticate with the Hugging Face Hub only when a token is configured:
# huggingface_hub.login(token=None) falls back to an interactive prompt
# (widget or terminal), which nobody can answer in a headless deployment.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)

# Initialize Embedding Model (used both for the Pinecone retriever below).
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# Pinecone Retriever
# The API key comes from the PINCE_CONE_LIGHT secret; an unset OR empty value
# is treated as missing so the app fails at startup with an actionable message
# instead of later inside the Pinecone client.
api_key = os.getenv("PINCE_CONE_LIGHT")
if not api_key:
    raise ValueError(
        "Pinecone API key missing. Set the PINCE_CONE_LIGHT environment variable."
    )
pc = Pinecone(
    pinecone_api_key=api_key,
    embedding=embedding_model,
    index_name='rag-rubic',
    namespace='vectors_lightmodel',
)
retriever = pc.as_retriever()
# LLM Options: dropdown label -> model identifier handed to load_llm().
# Order matters: the first entry is the default selection.
_LLM_CHOICES = (
    ("OpenAI", "gpt-4o-mini"),  # served through the OpenAI API (ChatOpenAI)
    ("Microsoft-Phi", "microsoft/Phi-3.5-mini-instruct"),  # Hugging Face Hub checkpoint, run locally
    ("DeepSeek-R1", "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"),  # Hugging Face Hub checkpoint, run locally
    ("Intel-tinybert", "Intel/dynamic_tinybert"),  # NOTE(review): not loaded as a text-generation model by load_llm
)
llm_options = dict(_LLM_CHOICES)
def load_llm(name, model_name, *, max_new_tokens=50):
    """Load the LLM behind a dropdown option; called only when it is needed.

    Args:
        name: Dropdown label (a key of ``llm_options``); selects the backend.
        model_name: Model identifier for that backend -- an OpenAI model name
            or a Hugging Face Hub checkpoint id.
        max_new_tokens: Maximum number of tokens the local Hugging Face
            models generate per answer (keyword-only; default 50).

    Returns:
        A ``ChatOpenAI`` for ``"OpenAI"``, a ``HuggingFacePipeline`` wrapping a
        greedy text-generation pipeline for the Phi / DeepSeek options, or
        ``None`` when ``name`` matches no known backend.

    Raises:
        ValueError: for the tinybert option, which cannot be wrapped as a
            text-generation pipeline.
    """
    if name == "OpenAI":
        openai_api_key = os.getenv("OPEN_AI_KEY")
        # Use the requested model rather than a hard-coded name so that
        # llm_options remains the single source of truth.
        return ChatOpenAI(model=model_name, openai_api_key=openai_api_key)

    if "Phi" in name or "DeepSeek" in name:
        model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", torch_dtype=torch.float16
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        # Greedy decoding (do_sample=False): a sampling temperature would be
        # ignored, and max_new_tokens already bounds the output (it takes
        # precedence over max_length), so neither extra knob is passed.
        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            eos_token_id=tokenizer.eos_token_id,
            return_full_text=False,  # return only the generated answer, not the prompt
            do_sample=False,
            num_return_sequences=1,
            max_new_tokens=max_new_tokens,
        )
        return HuggingFacePipeline(pipeline=pipe)

    if "tinybert" in name:
        # Fail fast. This option previously built a "feature-extraction"
        # pipeline; HuggingFacePipeline only handles generation-style tasks,
        # so RetrievalQA broke only once the first question was asked.
        raise ValueError(
            f"{name} ({model_name}) cannot be used as a generative LLM: "
            "it cannot be wrapped as a text-generation pipeline."
        )

    return None
# Default LLM: the first entry of llm_options is loaded eagerly at startup so
# the chain is ready before the first message arrives.
selected_llm = tuple(llm_options)[0]
llm = load_llm(name=selected_llm, model_name=llm_options[selected_llm])

# Retrieval-augmented QA chain: retrieved Pinecone documents + the LLM above.
qa = RetrievalQA.from_llm(retriever=retriever, llm=llm)
# Chatbot function
def _loaded_model_id(model):
    """Return the model identifier of an already-loaded LLM wrapper, or None.

    HuggingFacePipeline wrappers expose the checkpoint via
    ``model.pipeline.model.name_or_path``; ChatOpenAI stores its model as
    ``model_name``. Either value is comparable to the values of llm_options.
    """
    if model is None:
        return None
    if hasattr(model, "pipeline"):
        return model.pipeline.model.name_or_path
    return getattr(model, "model_name", None)


def chatbot(selected_llm, user_input, chat_history):
    """Answer one user message with the RAG chain, switching LLMs on demand.

    Args:
        selected_llm: Dropdown label, a key of ``llm_options``.
        user_input: The user's message from the textbox.
        chat_history: Chat history list held in a ``gr.State``; one
            ``(user_message, bot_message)`` pair is appended per turn, in place.

    Returns:
        ``(chat_history, "")`` -- the updated history for the Chatbot
        component and an empty string that clears the input textbox.

    Raises:
        ValueError: if ``selected_llm`` has no loadable backend.
    """
    global llm, qa

    # Nothing to ask: leave the history untouched and just clear the box.
    if not (user_input or "").strip():
        return chat_history, ""

    target_model = llm_options[selected_llm]
    # Compare model ids, not the dropdown label, so an unchanged selection
    # does not reload the model on every message.
    if _loaded_model_id(llm) != target_model:
        new_llm = load_llm(selected_llm, target_model)
        if new_llm is None:
            raise ValueError(f"Unsupported LLM option: {selected_llm!r}")
        # RetrievalQA keeps its own reference to the LLM, so the chain must be
        # rebuilt; rebinding ``llm`` alone would keep answering with the old model.
        llm = new_llm
        qa = RetrievalQA.from_llm(llm=llm, retriever=retriever)

    response = qa.invoke({"query": user_input})
    answer = response.get("result", "No response received.")
    # Tuple-format gr.Chatbot renders each pair as (user bubble, bot bubble).
    chat_history.append((user_input, answer))
    return chat_history, ""
# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 RAG-Powered Chatbot")
    llm_selector = gr.Dropdown(choices=list(llm_options.keys()), value=selected_llm, label="Choose an LLM")
    # Per-session chat history list; chatbot() appends to it in place.
    chat_history = gr.State([])
    chatbot_ui = gr.Chatbot()
    user_input = gr.Textbox(label="💬 Type your message and press Enter:")
    send_button = gr.Button("Send")
    # Clicking "Send" and pressing Enter in the textbox run the same handler.
    for _trigger in (send_button.click, user_input.submit):
        _trigger(
            chatbot,
            inputs=[llm_selector, user_input, chat_history],
            outputs=[chatbot_ui, user_input],
        )

if __name__ == "__main__":
    demo.launch()