# File size: 2,967 Bytes | 0ca7ed3
from dotenv import load_dotenv
import os
import streamlit as st
import requests
from typing import List
import json
import socket
from urllib3.connection import HTTPConnection
from app import embed_documents, retrieve_documents
# Load variables from a local .env file *before* reading the environment;
# otherwise a value defined only in .env is not visible to the lookup below.
load_dotenv()

# Base URL of the HTTP API; None when the variable is not set.
API_BASE_URL = os.environ.get("API_BASE_URL")

# Model / storage settings declared by this module.
embeddings_model_name = "all-MiniLM-L6-v2"
persist_directory = "db"
model = "tiiuae/falcon-7b-instruct"
from constants import CHROMA_SETTINGS
import chromadb
def list_of_collections():
    """Return the collections reported by a Chroma client built from CHROMA_SETTINGS.

    A new client is created on every call.
    """
    return chromadb.Client(CHROMA_SETTINGS).list_collections()
def main():
    """Render the Streamlit page with an upload/embed section and a retrieval section.

    The upload section passes the uploaded files to ``embed_documents`` when
    the "Embed" button is pressed. The retrieval section lets the user pick
    one of the names returned by ``get_collection_names`` and passes the
    typed query plus the selected name to ``retrieve_documents`` when the
    "Retrieve" button is pressed.

    Pressing a button with nothing to act on (no uploaded files, or a blank
    query) shows a warning instead of calling the backend helper.
    """
    st.title("PrivateGPT App: Document Embedding and Retrieval")

    # Document upload section
    st.header("Document Upload")
    files = st.file_uploader("Upload document", accept_multiple_files=True)
    # TODO: a user-supplied name via st.text_input("Collection Name") was
    # noted as "not working for some reason", so every upload is embedded
    # under the fixed literal name "collection_name".
    if st.button("Embed"):
        if files:
            embed_documents(files, "collection_name")
        else:
            # Nothing uploaded yet: skip the embedding call entirely.
            st.warning("Please upload at least one document before embedding.")

    # Query section
    st.header("Document Retrieval")
    collection_names = get_collection_names()
    selected_collection = st.selectbox("Select a document", collection_names)
    if selected_collection:
        query = st.text_input("Query")
        if st.button("Retrieve"):
            if query and query.strip():
                # The query is forwarded unchanged; strip() is only used to
                # detect whitespace-only input.
                retrieve_documents(query, selected_collection)
            else:
                st.warning("Please enter a query before retrieving.")
# def embed_documents(files:List[st.runtime.uploaded_file_manager.UploadedFile], collection_name:str):
# endpoint = f"{API_BASE_URL}/embed"
# files_data = [("files", file) for file in files]
# data = {"collection_name": collection_name}
# response = requests.post(endpoint, files=files_data, data=data)
# if response.status_code == 200:
# st.success("Documents embedded successfully!")
# else:
# st.error("Document embedding failed.")
# st.write(response.text)
def get_collection_names():
    """Return a list with the ``name`` attribute of each item from list_of_collections()."""
    return [entry.name for entry in list_of_collections()]
# def retrieve_documents(query: str, collection_name: str):
# endpoint = f"{API_BASE_URL}/retrieve"
# data = {"query": query, "collection_name": collection_name}
# # Modify socket options for the HTTPConnection class
# HTTPConnection.default_socket_options = (
# HTTPConnection.default_socket_options + [
# (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
# (socket.SOL_TCP, socket.TCP_KEEPIDLE, 45),
# (socket.SOL_TCP, socket.TCP_KEEPINTVL, 10),
# (socket.SOL_TCP, socket.TCP_KEEPCNT, 6)
# ]
# )
# response = requests.post(endpoint, params=data)
# if response.status_code == 200:
# result = response.json()
# st.subheader("Results")
# st.text(result["results"])
# st.subheader("Documents")
# for doc in result["docs"]:
# st.text(doc)
# else:
# st.error("Failed to retrieve documents.")
# st.write(response.text)
# Run the Streamlit page only when this file is executed as a script.
if __name__ == "__main__":
    main()