# Captured page header (not code): "Spaces: Running"
# [1] Core Imports (Updated Packages)
import uuid

import gradio as gr
import nltk
import validators
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents.stuff import create_stuff_documents_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_community.document_loaders import UnstructuredURLLoader
from langchain_core.prompts import ChatPromptTemplate
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_huggingface import HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# Fetch the NLTK "punkt" data once at import time, without console output.
# NOTE(review): presumably needed by the URL loader's text partitioning — confirm.
nltk.download("punkt", quiet=True)

# [2] Initialize Components
# Chunking parameters, in characters. Text is split on blank lines first and
# on single newlines second; no finer-grained separator is configured.
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 100
CHUNK_SEPARATORS = ["\n\n", "\n"]

# Sentence-embedding model used to index and query the chunks.
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

text_splitter = RecursiveCharacterTextSplitter(
    separators=CHUNK_SEPARATORS,
    chunk_overlap=CHUNK_OVERLAP,
    chunk_size=CHUNK_SIZE,
)

# Updated embeddings initialization
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
# [3] Model Setup
# Seq2seq checkpoint that generates the answers; the tokenizer and the model
# weights are both loaded from the same name.
MODEL_NAME = "google/flan-t5-large"

# Generation settings forwarded to the transformers pipeline: sampling is
# enabled at temperature 0.6 and each answer is capped at 800 new tokens.
_GENERATION_KWARGS = {
    "max_new_tokens": 800,
    "temperature": 0.6,
    "do_sample": True,
}

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

pipe = pipeline(
    task="text2text-generation",
    model=model,
    tokenizer=tokenizer,
    **_GENERATION_KWARGS,
)

# Updated pipeline wrapper: exposes the transformers pipeline as a LangChain LLM.
llm = HuggingFacePipeline(pipeline=pipe)
# [4] Prompt Template
# System instructions sent with every question. "{context}" is a template
# placeholder for the retrieved documents; "{input}" (below) is the question.
_SYSTEM_INSTRUCTIONS = (
    "Generate a clear concise most simplest understanding language answer "
    "in about 3-5 bullet or more if you need more to explain points, "
    "using ONLY the context below.\n\nContext: {context}"
)

prompt_template = ChatPromptTemplate.from_messages(
    [
        ("system", _SYSTEM_INSTRUCTIONS),
        ("human", "{input}"),
    ]
)
# [5] Processing Function
def process_inputs(urls_str, question):
    """Answer ``question`` using content fetched from the given URLs.

    The URLs are validated, loaded, de-duplicated, split into chunks and
    indexed in a per-request Chroma collection; the five most similar chunks
    are then passed to the LLM together with the question.

    Args:
        urls_str: Comma-separated list of URLs. ``None`` is treated as empty.
        question: The user's question. ``None`` is treated as empty.

    Returns:
        The generated answer on success, otherwise a human-readable error
        message. This function never raises: every failure is reported as a
        returned string so the UI can display it.

    NOTE(review): the leading "β" in the error messages looks like a
    mis-encoded symbol; it is kept unchanged until the intended character is
    confirmed.
    """
    try:
        print("\n=== New Request ===")

        # Validate inputs. Callers outside the UI may pass None, so coerce
        # before stripping instead of failing with AttributeError.
        if not (urls_str or "").strip() or not (question or "").strip():
            print("Missing inputs")
            return "β Please provide both URLs and a question"

        # dict.fromkeys drops repeated URLs while keeping first-seen order,
        # so the same page is not fetched twice.
        urls = list(
            dict.fromkeys(
                url.strip() for url in urls_str.split(",") if url.strip()
            )
        )
        if not urls:
            # Input consisted only of separators, e.g. ",".
            print("Missing inputs")
            return "β Please provide both URLs and a question"
        print(f"Processing {len(urls)} URLs")

        # Validate URLs
        for url in urls:
            if not validators.url(url):
                print(f"Invalid URL: {url}")
                return f"β Invalid URL format: {url}"

        # Load documents
        try:
            loader = UnstructuredURLLoader(urls=urls)
            docs = loader.load()
            print(f"Loaded {len(docs)} documents")
        except Exception as e:
            print(f"Document load failed: {e}")
            return f"β Failed to load documents: {e}"

        # Process documents: drop empty pages and pages with identical text.
        unique_docs = {}
        for doc in docs:
            text = doc.page_content.strip()
            if text:
                unique_docs[text] = doc
        if not unique_docs:
            print("No content found")
            return "β No content found in the provided URLs"

        split_docs = text_splitter.split_documents(list(unique_docs.values()))
        print(f"Split into {len(split_docs)} chunks")
        if not split_docs:
            print("No content found")
            return "β No content found in the provided URLs"

        # Create vector store. A unique collection name keeps this request's
        # chunks separate from those of earlier requests, which would
        # otherwise share the default collection of the in-process client.
        try:
            vectorstore = Chroma.from_documents(
                documents=split_docs,
                embedding=embeddings,
                collection_name=f"rag-{uuid.uuid4().hex}",
            )
            retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
            print("Vector store created")
        except Exception as e:
            print(f"Vector store error: {e}")
            return f"β Vector store error: {e}"

        # Create chain
        try:
            print("Creating RAG chain")
            rag_chain = create_retrieval_chain(
                retriever,
                create_stuff_documents_chain(
                    llm=llm,
                    prompt=prompt_template,
                ),
            )
            print(f"Processing question: {question}")
            response = rag_chain.invoke({"input": question})
            print("Answer generated successfully")
            return response["answer"]
        except Exception as e:
            print(f"Generation error: {e}")
            return f"β Generation error: {e}"
        finally:
            # Best-effort cleanup so per-request collections do not pile up;
            # a cleanup failure must not replace the answer or error above.
            try:
                vectorstore.delete_collection()
            except Exception as e:
                print(f"Vector store cleanup failed: {e}")
    except Exception as e:
        print(f"Unexpected error: {e}")
        return f"β Unexpected error: {e}"
# [6] Gradio Interface (Fixed parameters)
# Hint text shown inside the empty URL box.
_URL_PLACEHOLDER = (
    "https://example.com, https://another-site.org\n"
    "Some websites may not work as they won't allow to fetch data from their site.\n"
    "Try other websites in that case."
)

# One clickable example row: [urls, question].
_EXAMPLE_ROWS = [
    [
        "https://generativeai.net/, https://www.ibm.com/think/topics/generative-ai",
        "What are the key benefits of generative AI?",
    ],
]

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# RAG Chat Interface")

    # NOTE(review): the original indentation was lost, so the nesting of the
    # button and the answer box below is a reconstruction — confirm the layout.
    with gr.Row():
        with gr.Column():
            url_input = gr.Textbox(
                label="Paste URLs (comma-separated)",
                placeholder=_URL_PLACEHOLDER,
                lines=3,
            )
            question_input = gr.Textbox(
                label="Your Question",
                placeholder="Type your question here...",
                lines=3,
            )
            submit_btn = gr.Button("Get Answer", variant="primary")

        # Read-only box for the generated answer (autoscroll=True was removed).
        answer_output = gr.Textbox(
            label="Generated Answer",
            interactive=False,
            lines=10,
        )

    # The examples and the button feed the same two input components.
    rag_inputs = [url_input, question_input]

    gr.Examples(examples=_EXAMPLE_ROWS, inputs=rag_inputs)

    submit_btn.click(
        fn=process_inputs,
        inputs=rag_inputs,
        outputs=answer_output,
    )

# [7] Launch
# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()