import csv

import streamlit as st
import chromadb
from chromadb.utils import embedding_functions
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain_community.llms import HuggingFacePipeline
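# Rough dependency set for this app (an assumption; versions are not pinned in the source):
#   pip install streamlit chromadb sentence-transformers transformers langchain-community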
# Streamlit UI elements
st.title("ChromaDB and HuggingFace Pipeline Integration")
query = st.text_input("Enter your query:", value="director")
# Load documents and metadata from the CSV (first column: text, second: item id)
with open('./data.csv', encoding="utf-8") as file:
    lines = csv.reader(file)
    documents = []
    metadatas = []
    ids = []
    for i, line in enumerate(lines):
        if i == 0:  # skip the header row
            continue
        documents.append(line[0])
        metadatas.append({"item_id": line[1]})
        ids.append(str(i))  # 1-based row number doubles as a stable document id
# Persist embeddings under ./db and embed with all-mpnet-base-v2
chroma_client = chromadb.PersistentClient(path="db")
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-mpnet-base-v2")
collection = chroma_client.get_or_create_collection(name="my_collection", embedding_function=sentence_transformer_ef)

# upsert (rather than add) so Streamlit reruns don't fail on duplicate ids
collection.upsert(
    documents=documents,
    metadatas=metadatas,
    ids=ids
)
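# NOTE: Streamlit re-executes this script on every interaction, so the CSV is
# re-embedded on each rerun. A minimal sketch (an assumption, not in the
# original code) to do the ingestion once per process would be:
#
#     @st.cache_resource
#     def get_collection():
#         client = chromadb.PersistentClient(path="db")
#         ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-mpnet-base-v2")
#         coll = client.get_or_create_collection(name="my_collection", embedding_function=ef)
#         coll.upsert(documents=documents, metadatas=metadatas, ids=ids)
#         return coll
#
#     collection = get_collection()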
if st.button("Search"):
    # Query the collection for the single closest document
    results = collection.query(
        query_texts=[query],
        n_results=1,
        include=['documents', 'distances', 'metadatas']
    )
    st.write("Query Results:")
    st.write(results['metadatas'])

    # Log the full structure of results for debugging
    st.write("Results Structure:")
    st.write(results)
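    # `results` holds parallel lists, one inner list per query text. The shape
    # below is illustrative (an assumption about this data, not actual output):
    # {'ids': [['3']], 'documents': [['...']],
    #  'metadatas': [[{'item_id': '...'}]], 'distances': [[0.42]]}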
    if 'documents' in results and results['documents']:
        # Check that results['documents'] has the expected nested-list shape
        if len(results['documents']) > 0 and isinstance(results['documents'][0], list) and len(results['documents'][0]) > 0:
            context = results['documents'][0][0]
            st.write("Context:")
            st.write(context)
            # Load tokenizer and model (downloaded from the Hugging Face Hub on first run)
            tokenizer = AutoTokenizer.from_pretrained("MBZUAI/LaMini-T5-738M")
            model = AutoModelForSeq2SeqLM.from_pretrained("MBZUAI/LaMini-T5-738M")

            # Wrap the model in a text2text-generation pipeline
            pipe = pipeline(
                "text2text-generation",
                model=model,
                tokenizer=tokenizer,
                max_length=512
            )
            local_llm = HuggingFacePipeline(pipeline=pipe)
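            # Reloading the 738M-parameter model on every button press is slow;
            # wrapping the load in @st.cache_resource (see the ingestion note
            # above) would keep it in memory across reruns.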
            # Build a grounded prompt from the retrieved context
            prompt = f"""
Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
{context}
Question: {query}
Helpful Answer:
"""

            # Generate the answer
            answer = local_llm.invoke(prompt)
            st.write("Answer:")
            st.write(answer)
        else:
            st.write("No valid context found in the results.")
    else:
        st.write("No documents found for the query.")