File size: 5,435 Bytes
cbabc9b
 
daad807
 
 
a4f4daa
e082d15
 
 
 
 
 
 
 
 
 
 
ed03664
e082d15
ed03664
 
 
e4f9e43
e082d15
9bccc31
5a15a8f
daad807
e082d15
38c8a95
a4f4daa
daad807
38c8a95
b3e27ff
e082d15
 
d7862ee
 
 
 
 
 
 
 
e082d15
 
d7862ee
 
e082d15
daad807
 
 
 
 
 
 
 
 
 
 
 
 
 
e082d15
daad807
e082d15
 
 
daad807
 
 
 
 
 
a4f4daa
daad807
9bccc31
daad807
9bccc31
 
 
 
daad807
9bccc31
daad807
 
 
 
a4f4daa
 
 
 
daad807
 
e79d37f
e63529f
e79d37f
 
 
 
 
 
 
0a1c169
02018c3
e63529f
2cebf72
e63529f
e79d37f
 
08d999a
0d47a2b
 
 
 
e082d15
 
 
 
 
 
8f48006
8ff6059
 
 
 
 
 
 
 
 
 
 
b3f3051
daad807
e082d15
9bccc31
 
 
 
 
 
 
 
 
daad807
e082d15
9bccc31
daad807
e082d15
daad807
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4f4daa
 
e082d15
 
 
 
 
a4f4daa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import subprocess
import sys
import os
import uuid
import json
from pathlib import Path
# Install dependencies if not already installed
def install_packages():
    """Install the app's runtime dependencies into the running interpreter.

    Invokes ``<sys.executable> -m pip install <package>`` once per package,
    in a fixed order. The first failing install raises
    ``subprocess.CalledProcessError`` and stops the remaining installs.

    NOTE(review): ``gradio`` is imported below but not installed here —
    presumably provided by the hosting environment; confirm.
    """
    pip_install = [sys.executable, "-m", "pip", "install"]
    required = (
        "openai",
        "langchain_community",
        "sentence-transformers",
        "chromadb",
        "huggingface_hub",
        "python-dotenv",
    )
    for package_name in required:
        subprocess.check_call(pip_install + [package_name])


install_packages()

# Import installed modules
from huggingface_hub import login, CommitScheduler
import openai
import gradio as gr
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

# OpenAI API key; make sure OPENAI_API_KEY is set (e.g. in the .env file).
openai.api_key = os.getenv("OPENAI_API_KEY")

# Hugging Face token. "hf_token" is the primary variable; HUGGINGFACE_TOKEN is
# accepted as a fallback so either naming convention works.
hf_token = os.getenv("hf_token") or os.getenv("HUGGINGFACE_TOKEN")
if not hf_token:
    raise ValueError(
        "Hugging Face token is missing. Please set 'hf_token' "
        "(or 'HUGGINGFACE_TOKEN') as an environment variable."
    )

# Authenticate with the Hugging Face Hub (e.g. for the CommitScheduler that
# uploads query logs).
login(token=hf_token)
print("Logged in to Hugging Face successfully.")

# Embedding model for queries. Presumably the same model that built the
# persisted collection — confirm, since mismatched embeddings break retrieval.
embeddings = SentenceTransformerEmbeddings(model_name="thenlper/gte-large")
collection_name = 'report-10k-2024'

# Open the Chroma collection persisted under ./report_10kdb.
vectorstore_persisted = Chroma(
    collection_name=collection_name,
    persist_directory='./report_10kdb',
    embedding_function=embeddings,
)

# Top-5 similarity retriever over the whole collection.
# NOTE(review): predict() queries vectorstore_persisted directly (so it can
# filter by company); this retriever is not used in the code visible here.
retriever = vectorstore_persisted.as_retriever(
    search_type='similarity',
    search_kwargs={'k': 5},
)

# Define Q&A system message
# System prompt: scopes the assistant to 10-K extraction/summarization/analysis
# and instructs it to answer only from the supplied context, replying
# "I don't know." when the context lacks the answer.
qna_system_message = """
You are an AI assistant for Finsights Grey Inc., helping automate extraction, summarization, and analysis of 10-K reports.
Your responses should be based solely on the context provided.
If an answer is not found in the context, respond with "I don't know."
"""

# User-turn template, filled with str.format(context=..., question=...) in
# predict(). Any literal braces added to this text must be doubled ({{ }}).
qna_user_message_template = """
###Context
Here are some documents that are relevant to the question.
{context}
###Question
{question}
"""

# Define the predict function
# Query log: one JSON-lines file per process under logs/. The CommitScheduler
# is created once at import time so every request shares a single background
# uploader that pushes the logs/ folder to the "RAGREPORTS-log" dataset repo
# (every=2 — CommitScheduler's interval unit is minutes).
log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
log_folder = log_file.parent

scheduler = CommitScheduler(
    repo_id="RAGREPORTS-log",
    repo_type="dataset",
    folder_path=log_folder,
    path_in_repo="data",
    every=2,
)


def predict(user_input, company):
    """Answer a question about one company's 10-K from retrieved context.

    Retrieves the 5 chunks most similar to ``user_input`` from the persisted
    Chroma collection, restricted to chunks whose ``source`` metadata equals
    ``dataset/<company>-10-k-2023.pdf``. It then asks the chat model to answer
    from that context only, prints the answer, appends the exchange to the
    local JSON-lines log, and returns the answer.

    Args:
        user_input: The user's question.
        company: File-name slug of the company (e.g. ``"aws"``, ``"msft"``),
            used to build the source-document filter.

    Returns:
        The model's answer text, or the exception message as a string if the
        OpenAI call fails.
    """
    # Only search chunks that came from this company's filing.
    source_path = "dataset/" + company + "-10-k-2023.pdf"
    relevant_document_chunks = vectorstore_persisted.similarity_search(
        user_input, k=5, filter={"source": source_path}
    )

    # Concatenate the retrieved chunk texts into a single context string.
    context_for_query = ".".join(d.page_content for d in relevant_document_chunks)

    prompt = [
        {'role': 'system', 'content': qna_system_message},
        {
            'role': 'user',
            'content': qna_user_message_template.format(
                context=context_for_query, question=user_input
            ),
        },
    ]

    try:
        # Chat models take a `messages` list and are served by the
        # chat-completions endpoint (the legacy completions endpoint takes a
        # plain `prompt` and rejects `messages`).
        response = openai.chat.completions.create(
            model='gpt-3.5-turbo',
            messages=prompt,
            temperature=0,  # minimize sampling randomness
        )
        # The v1 SDK returns typed objects, so use attribute access.
        prediction = response.choices[0].message.content
    except Exception as e:
        # UI-handler boundary: show the failure to the user instead of
        # crashing the request.
        prediction = str(e)

    print(prediction)

    # Append one JSON record per line. Holding the scheduler lock prevents a
    # background commit from uploading a half-written record.
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps(
                {
                    'user_input': user_input,
                    'retrieved_context': context_for_query,
                    'model_response': prediction,
                }
            ))
            f.write("\n")

    return prediction
def get_predict(question, company):
    """Translate the UI company label into a document slug and answer.

    Args:
        question: The user's question text.
        company: The label picked in the UI radio ("AWS", "IBM", "Google",
            "Meta" or "Microsoft"); may be ``None`` if nothing was selected.

    Returns:
        The answer from ``predict``, or "Invalid company selected" when the
        label is not one of the known choices.

    NOTE(review): the slugs feed the ``dataset/<slug>-10-k-2023.pdf`` filter in
    predict(); their mixed casing ("IBM", "Google" vs "aws", "meta") presumably
    mirrors the dataset file names — confirm.
    """
    slug_by_label = {
        "AWS": "aws",
        "IBM": "IBM",
        "Google": "Google",
        "Meta": "meta",
        "Microsoft": "msft",
    }
    try:
        slug = slug_by_label[company]
    except KeyError:
        return "Invalid company selected"
    return predict(question, slug)

# Set-up the Gradio UI
with gr.Blocks(theme="gradio/seafoam@>=0.0.1,<0.1.0") as demo:
    # Inputs side by side: which company's filing to search, and the question.
    with gr.Row():
        company = gr.Radio(
            ["AWS", "IBM", "Google", "Meta", "Microsoft"],
            label="Select a company",
        )
        question = gr.Textbox(label="Enter your question")

    submit = gr.Button("Submit")
    output = gr.Textbox(label="Output")

    # Clicking Submit calls get_predict(question, company) and shows the
    # returned text in the Output box.
    submit.click(get_predict, inputs=[question, company], outputs=output)

# Enable request queuing (queue() returns the Blocks), then start the server.
demo.queue().launch()