File size: 2,967 Bytes
0ca7ed3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from dotenv import load_dotenv
import os
import streamlit as st
import requests
from typing import List
import json
import socket
from urllib3.connection import HTTPConnection
from app import embed_documents, retrieve_documents

# Load variables from a local .env file BEFORE reading the environment, so a
# value that is defined only in .env is visible to the lookup below.
load_dotenv()

# Base URL of the backend API; None when the variable is not set anywhere.
API_BASE_URL = os.environ.get("API_BASE_URL")

embeddings_model_name = "all-MiniLM-L6-v2"
persist_directory = "db"
model = "tiiuae/falcon-7b-instruct"

from constants import CHROMA_SETTINGS
import chromadb

def list_of_collections():
    """Return whatever a Chroma client built from CHROMA_SETTINGS lists as its collections."""
    chroma_client = chromadb.Client(CHROMA_SETTINGS)
    collections = chroma_client.list_collections()
    return collections
    
def main():
    """Render the Streamlit page: a document upload section and a retrieval section.

    Uploaded files are handed to ``embed_documents`` when "Embed" is clicked.
    A query entered for the selected collection is handed to
    ``retrieve_documents`` when "Retrieve" is clicked. Both actions are skipped
    with a warning when there is nothing to act on (no upload / blank query).
    """
    st.title("PrivateGPT App: Document Embedding and Retrieval")

    # Document upload section
    st.header("Document Upload")
    files = st.file_uploader("Upload document", accept_multiple_files=True)
    # TODO: let the user choose the collection name. A
    # st.text_input("Collection Name") was tried here and noted as not
    # working, so a fixed placeholder name is passed for now.
    if st.button("Embed"):
        if files:
            embed_documents(files, "collection_name")
        else:
            # Nothing has been uploaded: skip the embedding call rather than
            # handing it an empty upload.
            st.warning("Please upload at least one document before embedding.")

    # Query section
    st.header("Document Retrieval")
    collection_names = get_collection_names()
    selected_collection = st.selectbox("Select a document", collection_names)
    if selected_collection:
        query = st.text_input("Query")
        if st.button("Retrieve"):
            if query.strip():
                # The query is forwarded unmodified; strip() is only used to
                # detect blank input.
                retrieve_documents(query, selected_collection)
            else:
                st.warning("Please enter a query before retrieving.")

# def embed_documents(files:List[st.runtime.uploaded_file_manager.UploadedFile], collection_name:str):
#     endpoint = f"{API_BASE_URL}/embed"
#     files_data = [("files", file) for file in files]
#     data = {"collection_name": collection_name}

#     response = requests.post(endpoint, files=files_data, data=data)
#     if response.status_code == 200:
#         st.success("Documents embedded successfully!")
#     else:
#         st.error("Document embedding failed.")
#         st.write(response.text)


def get_collection_names():
    """Return a list with the ``name`` attribute of every entry from ``list_of_collections()``."""
    names = []
    for entry in list_of_collections():
        names.append(entry.name)
    return names



# def retrieve_documents(query: str, collection_name: str):
#     endpoint = f"{API_BASE_URL}/retrieve"
#     data = {"query": query, "collection_name": collection_name}

#     # Modify socket options for the HTTPConnection class
#     HTTPConnection.default_socket_options = (
#         HTTPConnection.default_socket_options + [
#             (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1),
#             (socket.SOL_TCP, socket.TCP_KEEPIDLE, 45),
#             (socket.SOL_TCP, socket.TCP_KEEPINTVL, 10),
#             (socket.SOL_TCP, socket.TCP_KEEPCNT, 6)
#         ]
#     )
    
#     response = requests.post(endpoint, params=data)
#     if response.status_code == 200:
#         result = response.json()
#         st.subheader("Results")
#         st.text(result["results"])
        
#         st.subheader("Documents")
#         for doc in result["docs"]:
#             st.text(doc)
#     else:
#         st.error("Failed to retrieve documents.")
#         st.write(response.text)


# Build the page only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()