File size: 15,423 Bytes
430f6e8
 
fe370a3
 
 
 
 
 
 
 
 
55b4c5a
 
b66e2f4
 
d069333
fe370a3
 
 
 
55b4c5a
fe370a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430f6e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe370a3
9a4c626
79e7580
 
fe370a3
b66e2f4
 
fe370a3
 
 
d069333
fe370a3
 
9a4c626
b66e2f4
42a40dd
9a4c626
fe370a3
 
 
 
9aa8f58
fe370a3
9a4c626
 
fe370a3
 
 
 
 
 
 
 
 
 
 
 
 
 
9aa8f58
 
 
 
 
fe370a3
 
eeaf024
 
fe370a3
 
 
 
 
 
 
 
 
 
 
 
 
9aa8f58
fe370a3
 
 
9aa8f58
fe370a3
 
55b4c5a
 
fe370a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d069333
 
 
 
 
 
 
 
 
 
 
7495086
 
 
 
 
 
 
 
 
d069333
 
 
 
 
 
 
 
 
 
 
5a8a35f
 
d069333
5a8a35f
 
d069333
5a8a35f
d069333
 
 
fe370a3
 
 
 
 
 
 
 
 
 
 
 
 
d069333
fe370a3
 
 
96ecc62
 
d069333
fe370a3
 
fb4fd4c
fe370a3
 
 
 
1bd65db
 
fe370a3
 
 
55b4c5a
 
8698bd5
55b4c5a
 
d069333
 
 
 
 
55b4c5a
fe370a3
9c41670
fe370a3
3186d0c
b66e2f4
679ccf5
 
 
 
 
 
 
 
3186d0c
79e7580
 
 
 
 
 
 
 
 
 
 
 
 
679ccf5
 
 
fe370a3
 
6786331
fe370a3
 
7dc5361
fe370a3
 
d069333
fe370a3
 
 
 
0e043cb
 
 
 
 
 
 
 
b280156
0e043cb
 
 
 
b280156
0e043cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b66e2f4
 
 
 
fe370a3
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
from fastapi import FastAPI, HTTPException, UploadFile, File,Request,Depends,status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel, Json
from uuid import uuid4, UUID
from typing import Optional
import pymupdf
from pinecone import Pinecone, ServerlessSpec
import os
from dotenv import load_dotenv
from rag import *
from fastapi.responses import StreamingResponse
import json
from prompts import *
from typing import Literal
from models import *
from fastapi.middleware.cors import CORSMiddleware

# Module-level bootstrap: runs once at import time and talks to Pinecone.
# Pull configuration from a .env file (if any) into the process environment.
load_dotenv()

pinecone_api_key = os.environ.get("PINECONE_API_KEY")
# Namespace handed to get_retreive_answer alongside the enterprise namespace.
common_namespace = os.environ.get("COMMON_NAMESPACE")

pc = Pinecone(api_key=pinecone_api_key)

import time

# Name of the Pinecone index this service reads and writes; None if INDEX_NAME is unset.
index_name = os.environ.get("INDEX_NAME") # change if desired

existing_indexes = [index_info["name"] for index_info in pc.list_indexes()]

# Create the index on first start-up, then block until Pinecone reports it ready.
if index_name not in existing_indexes:
    pc.create_index(
        name=index_name,
        # NOTE(review): presumably matches the embedding size used in rag.py — confirm.
        dimension=3072,
        metric="cosine",
        spec=ServerlessSpec(cloud="aws", region="us-east-1"),
    )
    # Polls once per second with no upper bound on the wait.
    while not pc.describe_index(index_name).status["ready"]:
        time.sleep(1)

# Shared index handle used by every endpoint below.
index = pc.Index(index_name)



# Tokens accepted by api_key_auth; holds a single None entry when
# FASTAPI_API_KEY is not set in the environment.
api_keys = [os.environ.get("FASTAPI_API_KEY")]

# Dependency that supplies the bearer token to api_key_auth.
# NOTE(review): no "token" route is defined in the visible part of this file.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")  # use token authentication


def api_key_auth(api_key: str = Depends(oauth2_scheme)):
    """Reject the request unless its bearer token is a configured API key.

    Args:
        api_key: Token supplied by the ``oauth2_scheme`` dependency.

    Raises:
        HTTPException: 401 when the token matches none of ``api_keys``.
    """
    import secrets  # local import: only this dependency needs it

    presented = api_key.encode("utf-8")
    authorized = False
    for known_key in api_keys:
        # An unset FASTAPI_API_KEY leaves None in the list; it can never match.
        if known_key is None:
            continue
        # Constant-time comparison so response timing does not reveal how much
        # of a guessed key was correct. Bytes are used because compare_digest
        # rejects non-ASCII str arguments.
        if secrets.compare_digest(presented, known_key.encode("utf-8")):
            authorized = True
    if not authorized:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Forbidden",
            headers={"WWW-Authenticate": "Bearer"},
        )

# Application object; api_key_auth is attached as an app-wide dependency.
app = FastAPI(dependencies=[Depends(api_key_auth)])

# FASTAPI_KEY_NAME = os.environ.get("FASTAPI_KEY_NAME")
# FASTAPI_API_KEY = os.environ.get("FASTAPI_API_KEY")


# @app.middleware("http")
# async def api_key_middleware(request: Request, call_next):
#     if request.url.path not in ["/","/docs","/openapi.json"]:
#         api_key = request.headers.get(FASTAPI_KEY_NAME)
#         if api_key != FASTAPI_API_KEY:
#             raise HTTPException(status_code=403, detail="invalid API key :/")
#     response = await call_next(request)
#     return response

# Optional writing-style hints a client can attach to a generation request.
# Documented with comments rather than a docstring so the generated schema
# description is left unchanged.
class StyleWriter(BaseModel):
    style: Optional[str] = "neutral"  # passed to the prompt helpers as `style`
    tonality: Optional[str] = "formal"  # passed to the prompt helpers as `tonality`

# Model names returned by GET /models.
# NOTE(review): UserInput.model also accepts "o1-preview", which is not listed here — confirm intent.
models = ["gpt-4o","gpt-4o-mini","mistral-large-latest"]

# Request body shared by /generate-answer/ and /v2/generate-answer/.
class UserInput(BaseModel):
    prompt: str  # user text; also used as the retrieval query
    enterprise_id: str  # passed to the retrieval and knowledge-base helpers
    user_id: Optional[str] = None  # forwarded to the retrieval and knowledge-base helpers
    stream: Optional[bool] = False  # when true the endpoint returns a StreamingResponse
    messages: Optional[list[dict]] = []  # forwarded unchanged to generate_response_via_langchain
    style_tonality: Optional[StyleWriter] = None  # when set, its style/tonality reach the prompt helpers
    marque: Optional[str] = None  # forwarded to the helpers as enterprise_name
    model: Literal["gpt-4o","gpt-4o-mini","mistral-large-latest","o1-preview"] = "gpt-4o"  # forwarded as `model`


# Metadata accompanying a file sent to /upload; parsed from a JSON string
# (the endpoint declares it as Json[EnterpriseData]).
class EnterpriseData(BaseModel):
    name: str  # enterprise name; sanitized by upload_file before use
    id: Optional[str] = None  # used as the namespace; generated from the name plus a UUID when omitted
    filename: Optional[str] = None  # overrides the uploaded file's own filename when given



# NOTE(review): not referenced anywhere in the visible part of this file — possibly dead.
tasks = []

@app.get("/")
def greet_json():
    """Return a fixed greeting payload."""
    greeting = {"Hello": "World!"}
    return greeting

@app.post("/upload")
async def upload_file(file: UploadFile, enterprise_data: Json[EnterpriseData]):
    """Extract the text of an uploaded PDF, chunk it and store it for an enterprise.

    Args:
        file: The uploaded PDF.
        enterprise_data: JSON string parsed into ``EnterpriseData``. When ``id``
            is omitted a new one is built from the sanitized name plus a UUID;
            when ``filename`` is omitted the upload's own filename is used.

    Returns:
        A dict with the filename, enterprise id, sanitized enterprise name,
        number of chunks and the ``filename_id`` reported by ``get_vectorstore``.

    Raises:
        HTTPException: 500 when the vector store could not be created or any
            other step fails.
    """
    try:
        contents = await file.read()

        # Map separators that are awkward in identifiers to underscores in one pass.
        enterprise_name = enterprise_data.name.translate(
            str.maketrans(" -./\\", "_____")
        ).strip()

        filename = (
            enterprise_data.filename
            if enterprise_data.filename is not None
            else file.filename
        )

        # Assign a fresh id when the caller did not provide one.
        enterprise_id = enterprise_data.id
        if enterprise_id is None:
            clean_name = remove_non_standard_ascii(enterprise_name)
            enterprise_id = f"{clean_name}_{uuid4()}"

        # Parse the PDF from memory and make sure the document is released.
        pdf_document = pymupdf.open(stream=contents, filetype="pdf")
        try:
            text = "".join(page.get_text() for page in pdf_document)
        finally:
            pdf_document.close()

        text_chunks = get_text_chunks(text)

        vector_store = get_vectorstore(
            text_chunks,
            filename=filename,
            file_type="pdf",
            namespace=enterprise_id,
            index=index,
            enterprise_name=enterprise_name,
        )

        if not vector_store:
            raise HTTPException(status_code=500, detail="Could not create vector store")

        return {
            "file_name": filename,
            "enterprise_id": enterprise_id,
            "number_of_chunks": len(text_chunks),
            "filename_id": vector_store["filename_id"],
            "enterprise_name": enterprise_name,
        }

    except HTTPException:
        # Keep deliberate HTTP errors intact instead of re-wrapping them below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e

    finally:
        await file.close()

@app.get("/documents/{enterprise_id}")
def get_documents(enterprise_id: str):
    """List the distinct document names stored in an enterprise namespace.

    A document name is a vector id with its last ``_``-separated segment
    removed. Names are returned in first-seen order.

    Raises:
        HTTPException: 500 when listing the index fails.
    """
    try:
        # A dict gives ordered de-duplication without a linear scan per id.
        seen_names = {}
        for id_batch in index.list(namespace=enterprise_id):
            for vector_id in id_batch:
                seen_names.setdefault(vector_id.rpartition("_")[0], None)
        return list(seen_names)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e
    
@app.get("/documents/memory/{enterprise_id}/{user_id}")
def get_documents(enterprise_id: str, user_id: str):
    """List every id in the enterprise namespace that starts with ``kb_{user_id}_``.

    Raises:
        HTTPException: 500 when listing the index fails.
    """
    # NOTE: this function reuses the name of the /documents/{enterprise_id}
    # handler; both routes stay registered because the decorator runs per definition.
    memory_prefix = f"kb_{user_id}_"
    try:
        id_batches = index.list(prefix=memory_prefix, namespace=enterprise_id)
        return [entry_id for batch in id_batches for entry_id in batch]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
    
def _delete_by_prefix(prefix: str, namespace: str) -> list:
    """Delete every vector in *namespace* whose id starts with *prefix*; return the deleted ids."""
    deleted_ids = []
    for id_batch in index.list(prefix=prefix, namespace=namespace):
        index.delete(ids=id_batch, namespace=namespace)
        deleted_ids.extend(id_batch)
    return deleted_ids


# Registered before /documents/{enterprise_id}/{filename_id}: routes are matched
# in registration order, and that two-segment pattern would otherwise capture
# "/documents/all/<id>" with enterprise_id="all".
@app.delete("/documents/all/{enterprise_id}")
def delete_all_documents(enterprise_id: str):
    """Delete every vector stored in the enterprise namespace.

    Raises:
        HTTPException: 500 when the index call fails.
    """
    try:
        index.delete(namespace=enterprise_id, delete_all=True)
        return {"message": "All documents deleted"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e


@app.delete("/documents/{enterprise_id}/{filename_id}")
def delete_document(enterprise_id: str, filename_id: str):
    """Delete all chunks whose id starts with ``{filename_id}_``.

    Returns:
        A message plus the full list of deleted ids (empty when nothing matched).

    Raises:
        HTTPException: 500 when the index calls fail.
    """
    try:
        deleted_ids = _delete_by_prefix(f"{filename_id}_", enterprise_id)
        return {"message": "Document deleted", "chunks_deleted": deleted_ids}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e


# NOTE: this path ends with a slash. Without it the request matches the
# /documents/memory/{enterprise_id}/{user_id}/{info_id} route instead.
@app.delete("/documents/memory/all/{enterprise_id}/{user_id}/")
def delete_document(enterprise_id: str, user_id: str):
    """Delete all entries whose id starts with ``kb_{user_id}_``.

    Returns:
        A message plus the full list of deleted ids (empty when nothing matched).

    Raises:
        HTTPException: 500 when the index calls fail.
    """
    try:
        deleted_ids = _delete_by_prefix(f"kb_{user_id}_", enterprise_id)
        return {"message": "Document deleted", "chunks_deleted": deleted_ids}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e


@app.delete("/documents/memory/{enterprise_id}/{user_id}/{info_id}")
def delete_document(enterprise_id: str, user_id: str, info_id: str):
    """Delete the entries whose id starts with ``info_id``.

    ``user_id`` is part of the path but is not used to build the prefix.

    Returns:
        A message plus the list of deleted ids (empty when nothing matched).

    Raises:
        HTTPException: 500 when the index calls fail.
    """
    try:
        deleted_ids = _delete_by_prefix(f"{info_id}", enterprise_id)
        return {"message": "Document deleted", "chunks_deleted": deleted_ids}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e
    
import async_timeout
import asyncio

# Time budget, in seconds, that the stream generators below pass to
# async_timeout.timeout around the whole streaming loop.
GENERATION_TIMEOUT_SEC = 60

async def stream_generator(response, prompt, info_memoire):
    """Yield one JSON document per chunk of an async response stream.

    Args:
        response: Async iterable producing ``str`` or ``bytes`` chunks.
        prompt: Formatted prompt echoed in every emitted document.
        info_memoire: Knowledge-base info echoed in every emitted document.

    Yields:
        JSON strings with the keys ``prompt``, ``content`` and ``info_memoire``.

    Raises:
        HTTPException: 504 when the stream exceeds ``GENERATION_TIMEOUT_SEC``.
    """
    # The handler must wrap the `async with`: async_timeout raises
    # asyncio.TimeoutError when the context manager exits, not inside its body.
    try:
        async with async_timeout.timeout(GENERATION_TIMEOUT_SEC):
            async for chunk in response:
                if isinstance(chunk, bytes):
                    chunk = chunk.decode('utf-8')  # Convert bytes to str if needed
                yield json.dumps({"prompt": prompt, "content": chunk, "info_memoire": info_memoire})
    except asyncio.TimeoutError:
        # NOTE(review): by this point response headers have likely been sent,
        # so the client may see a broken stream rather than a 504 — confirm.
        raise HTTPException(status_code=504, detail="Stream timed out")

    
@app.post("/generate-answer/")
def generate_answer(user_input: UserInput):
    """Answer a prompt using context retrieved for the enterprise.

    Retrieves context, lets the knowledge-base helper record information from
    the prompt, builds the formatted prompt and asks the model for an answer.

    Returns:
        A ``StreamingResponse`` fed by ``stream_generator`` when
        ``user_input.stream`` is true; otherwise a dict with ``prompt``,
        ``answer``, ``context`` and ``info_memoire``.

    Raises:
        HTTPException: 500 on any unexpected failure; HTTP errors raised by
            the helpers are passed through unchanged.
    """
    try:
        print(user_input)

        prompt = user_input.prompt
        enterprise_id = user_input.enterprise_id
        template_prompt = base_template
        # The field always exists on the model, so this is None (not "") when omitted.
        enterprise_name = user_input.marque

        context = get_retreive_answer(enterprise_id, prompt, index, common_namespace, user_id=user_input.user_id)

        infos_added_to_kb = handle_calling_add_to_knowledge_base(
            prompt, enterprise_id, index, enterprise_name, user_id=user_input.user_id
        )
        if infos_added_to_kb:
            prompt = prompt + "l'information a été ajoutée à la base de connaissance: " + infos_added_to_kb['item']
        else:
            infos_added_to_kb = {}

        if not context:
            context = ""

        # Style arguments are only passed when the client supplied them, so the
        # helpers keep their own defaults otherwise.
        style_kwargs = {}
        if user_input.style_tonality is not None:
            style_kwargs = {
                "style": user_input.style_tonality.style,
                "tonality": user_input.style_tonality.tonality,
            }

        prompt_formated = prompt_reformatting(
            template_prompt,
            context,
            prompt,
            enterprise_name=enterprise_name,
            **style_kwargs,
        )
        answer = generate_response_via_langchain(
            prompt,
            model=user_input.model,
            stream=user_input.stream,
            context=context,
            messages=user_input.messages,
            template=template_prompt,
            enterprise_name=enterprise_name,
            enterprise_id=enterprise_id,
            index=index,
            **style_kwargs,
        )

        if user_input.stream:
            return StreamingResponse(
                stream_generator(answer, prompt_formated, infos_added_to_kb),
                media_type="application/json",
            )

        return {
            "prompt": prompt_formated,
            "answer": answer,
            "context": context,
            "info_memoire": infos_added_to_kb
        }

    except HTTPException:
        # Do not turn a deliberate HTTP error into a generic 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e



async def stream_generator2(response, prompt, info_memoire):
    """Forward the raw ``bytes`` chunks of an async response stream.

    Args:
        response: Async iterable of chunks; chunks that are not ``bytes`` are skipped.
        prompt: Unused; kept so the signature matches ``stream_generator``.
        info_memoire: Unused; kept so the signature matches ``stream_generator``.

    Yields:
        Each ``bytes`` chunk unchanged.

    Raises:
        HTTPException: 504 when the stream exceeds ``GENERATION_TIMEOUT_SEC``.
    """
    # The handler must wrap the `async with`: async_timeout raises
    # asyncio.TimeoutError when the context manager exits, not inside its body.
    try:
        async with async_timeout.timeout(GENERATION_TIMEOUT_SEC):
            async for chunk in response:
                # NOTE(review): str chunks are dropped silently — confirm the
                # upstream generator emits bytes.
                if isinstance(chunk, bytes):
                    yield chunk
    except asyncio.TimeoutError:
        raise HTTPException(status_code=504, detail="Stream timed out")
            
@app.post("/v2/generate-answer/")
def generate_answer2(user_input: UserInput):
    """Answer a prompt using context retrieved for the enterprise (v2 streaming).

    Same pipeline as ``/generate-answer/``; the only difference is that a
    streamed answer is fed through ``stream_generator2``.

    Returns:
        A ``StreamingResponse`` when ``user_input.stream`` is true; otherwise a
        dict with ``prompt``, ``answer``, ``context`` and ``info_memoire``.

    Raises:
        HTTPException: 500 on any unexpected failure; HTTP errors raised by
            the helpers are passed through unchanged.
    """
    try:
        print(user_input)

        prompt = user_input.prompt
        enterprise_id = user_input.enterprise_id
        template_prompt = base_template
        # The field always exists on the model, so this is None (not "") when omitted.
        enterprise_name = user_input.marque

        context = get_retreive_answer(enterprise_id, prompt, index, common_namespace, user_id=user_input.user_id)

        infos_added_to_kb = handle_calling_add_to_knowledge_base(
            prompt, enterprise_id, index, enterprise_name, user_id=user_input.user_id
        )
        if infos_added_to_kb:
            prompt = prompt + "l'information a été ajoutée à la base de connaissance: " + infos_added_to_kb['item']
        else:
            infos_added_to_kb = {}

        if not context:
            context = ""

        # Style arguments are only passed when the client supplied them, so the
        # helpers keep their own defaults otherwise.
        style_kwargs = {}
        if user_input.style_tonality is not None:
            style_kwargs = {
                "style": user_input.style_tonality.style,
                "tonality": user_input.style_tonality.tonality,
            }

        prompt_formated = prompt_reformatting(
            template_prompt,
            context,
            prompt,
            enterprise_name=enterprise_name,
            **style_kwargs,
        )
        answer = generate_response_via_langchain(
            prompt,
            model=user_input.model,
            stream=user_input.stream,
            context=context,
            messages=user_input.messages,
            template=template_prompt,
            enterprise_name=enterprise_name,
            enterprise_id=enterprise_id,
            index=index,
            **style_kwargs,
        )

        if user_input.stream:
            return StreamingResponse(
                stream_generator2(answer, prompt_formated, infos_added_to_kb),
                media_type="application/json",
            )

        return {
            "prompt": prompt_formated,
            "answer": answer,
            "context": context,
            "info_memoire": infos_added_to_kb
        }

    except HTTPException:
        # Do not turn a deliberate HTTP error into a generic 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e
    
@app.get("/models")
def get_models():
    """Return the module-level list of model names under the ``models`` key."""
    return dict(models=models)