Thomas (Tom) Gardos commited on
Commit
582abb2
Β·
2 Parent(s): 48f8268 7f989d6

Merge pull request #27 from DL4DS/code_restructure

Browse files
This view is limited to 50 files because it contains too many changes. Β  See raw diff
Files changed (50) hide show
  1. .chainlit/translations/en-US.json +0 -231
  2. .chainlit/translations/pt-BR.json +0 -155
  3. .gitignore +9 -1
  4. Dockerfile.dev +6 -2
  5. README.md +71 -22
  6. {.chainlit β†’ code/.chainlit}/config.toml +16 -22
  7. code/__init__.py +1 -0
  8. chainlit.md β†’ code/chainlit.md +0 -0
  9. code/main.py +87 -39
  10. code/modules/chat/__init__.py +0 -0
  11. code/modules/{chat_model_loader.py β†’ chat/chat_model_loader.py} +2 -3
  12. code/modules/chat/helpers.py +104 -0
  13. code/modules/{llm_tutor.py β†’ chat/llm_tutor.py} +94 -60
  14. code/modules/chat_processor/__init__.py +0 -0
  15. code/modules/chat_processor/base.py +12 -0
  16. code/modules/chat_processor/chat_processor.py +29 -0
  17. code/modules/chat_processor/literal_ai.py +37 -0
  18. code/modules/config/__init__.py +0 -0
  19. code/{config.yml β†’ modules/config/config.yml} +27 -8
  20. code/modules/{constants.py β†’ config/constants.py} +2 -1
  21. code/modules/dataloader/__init__.py +0 -0
  22. code/modules/{data_loader.py β†’ dataloader/data_loader.py} +184 -117
  23. code/modules/dataloader/helpers.py +108 -0
  24. code/modules/{helpers.py β†’ dataloader/webpage_crawler.py} +3 -225
  25. code/modules/retriever/__init__.py +0 -0
  26. code/modules/retriever/base.py +12 -0
  27. code/modules/retriever/chroma_retriever.py +24 -0
  28. code/modules/retriever/colbert_retriever.py +10 -0
  29. code/modules/retriever/faiss_retriever.py +23 -0
  30. code/modules/retriever/helpers.py +39 -0
  31. code/modules/retriever/raptor_retriever.py +16 -0
  32. code/modules/retriever/retriever.py +26 -0
  33. code/modules/vector_db.py +0 -226
  34. code/modules/vectorstore/__init__.py +0 -0
  35. code/modules/vectorstore/base.py +33 -0
  36. code/modules/vectorstore/chroma.py +41 -0
  37. code/modules/vectorstore/colbert.py +39 -0
  38. code/modules/{embedding_model_loader.py β†’ vectorstore/embedding_model_loader.py} +5 -8
  39. code/modules/vectorstore/faiss.py +45 -0
  40. code/modules/vectorstore/helpers.py +0 -0
  41. code/modules/vectorstore/raptor.py +438 -0
  42. code/modules/vectorstore/store_manager.py +163 -0
  43. code/modules/vectorstore/vectorstore.py +57 -0
  44. code/public/acastusphoton-svgrepo-com.svg +2 -0
  45. code/public/adv-screen-recorder-svgrepo-com.svg +2 -0
  46. code/public/alarmy-svgrepo-com.svg +2 -0
  47. public/logo_dark.png β†’ code/public/avatars/ai-tutor.png +0 -0
  48. code/public/calendar-samsung-17-svgrepo-com.svg +36 -0
  49. public/logo_light.png β†’ code/public/logo_dark.png +0 -0
  50. code/public/logo_light.png +0 -0
.chainlit/translations/en-US.json DELETED
@@ -1,231 +0,0 @@
1
- {
2
- "components": {
3
- "atoms": {
4
- "buttons": {
5
- "userButton": {
6
- "menu": {
7
- "settings": "Settings",
8
- "settingsKey": "S",
9
- "APIKeys": "API Keys",
10
- "logout": "Logout"
11
- }
12
- }
13
- }
14
- },
15
- "molecules": {
16
- "newChatButton": {
17
- "newChat": "New Chat"
18
- },
19
- "tasklist": {
20
- "TaskList": {
21
- "title": "\ud83d\uddd2\ufe0f Task List",
22
- "loading": "Loading...",
23
- "error": "An error occured"
24
- }
25
- },
26
- "attachments": {
27
- "cancelUpload": "Cancel upload",
28
- "removeAttachment": "Remove attachment"
29
- },
30
- "newChatDialog": {
31
- "createNewChat": "Create new chat?",
32
- "clearChat": "This will clear the current messages and start a new chat.",
33
- "cancel": "Cancel",
34
- "confirm": "Confirm"
35
- },
36
- "settingsModal": {
37
- "settings": "Settings",
38
- "expandMessages": "Expand Messages",
39
- "hideChainOfThought": "Hide Chain of Thought",
40
- "darkMode": "Dark Mode"
41
- },
42
- "detailsButton": {
43
- "using": "Using",
44
- "running": "Running",
45
- "took_one": "Took {{count}} step",
46
- "took_other": "Took {{count}} steps"
47
- },
48
- "auth": {
49
- "authLogin": {
50
- "title": "Login to access the app.",
51
- "form": {
52
- "email": "Email address",
53
- "password": "Password",
54
- "noAccount": "Don't have an account?",
55
- "alreadyHaveAccount": "Already have an account?",
56
- "signup": "Sign Up",
57
- "signin": "Sign In",
58
- "or": "OR",
59
- "continue": "Continue",
60
- "forgotPassword": "Forgot password?",
61
- "passwordMustContain": "Your password must contain:",
62
- "emailRequired": "email is a required field",
63
- "passwordRequired": "password is a required field"
64
- },
65
- "error": {
66
- "default": "Unable to sign in.",
67
- "signin": "Try signing in with a different account.",
68
- "oauthsignin": "Try signing in with a different account.",
69
- "redirect_uri_mismatch": "The redirect URI is not matching the oauth app configuration.",
70
- "oauthcallbackerror": "Try signing in with a different account.",
71
- "oauthcreateaccount": "Try signing in with a different account.",
72
- "emailcreateaccount": "Try signing in with a different account.",
73
- "callback": "Try signing in with a different account.",
74
- "oauthaccountnotlinked": "To confirm your identity, sign in with the same account you used originally.",
75
- "emailsignin": "The e-mail could not be sent.",
76
- "emailverify": "Please verify your email, a new email has been sent.",
77
- "credentialssignin": "Sign in failed. Check the details you provided are correct.",
78
- "sessionrequired": "Please sign in to access this page."
79
- }
80
- },
81
- "authVerifyEmail": {
82
- "almostThere": "You're almost there! We've sent an email to ",
83
- "verifyEmailLink": "Please click on the link in that email to complete your signup.",
84
- "didNotReceive": "Can't find the email?",
85
- "resendEmail": "Resend email",
86
- "goBack": "Go Back",
87
- "emailSent": "Email sent successfully.",
88
- "verifyEmail": "Verify your email address"
89
- },
90
- "providerButton": {
91
- "continue": "Continue with {{provider}}",
92
- "signup": "Sign up with {{provider}}"
93
- },
94
- "authResetPassword": {
95
- "newPasswordRequired": "New password is a required field",
96
- "passwordsMustMatch": "Passwords must match",
97
- "confirmPasswordRequired": "Confirm password is a required field",
98
- "newPassword": "New password",
99
- "confirmPassword": "Confirm password",
100
- "resetPassword": "Reset Password"
101
- },
102
- "authForgotPassword": {
103
- "email": "Email address",
104
- "emailRequired": "email is a required field",
105
- "emailSent": "Please check the email address {{email}} for instructions to reset your password.",
106
- "enterEmail": "Enter your email address and we will send you instructions to reset your password.",
107
- "resendEmail": "Resend email",
108
- "continue": "Continue",
109
- "goBack": "Go Back"
110
- }
111
- }
112
- },
113
- "organisms": {
114
- "chat": {
115
- "history": {
116
- "index": {
117
- "showHistory": "Show history",
118
- "lastInputs": "Last Inputs",
119
- "noInputs": "Such empty...",
120
- "loading": "Loading..."
121
- }
122
- },
123
- "inputBox": {
124
- "input": {
125
- "placeholder": "Type your message here..."
126
- },
127
- "speechButton": {
128
- "start": "Start recording",
129
- "stop": "Stop recording"
130
- },
131
- "SubmitButton": {
132
- "sendMessage": "Send message",
133
- "stopTask": "Stop Task"
134
- },
135
- "UploadButton": {
136
- "attachFiles": "Attach files"
137
- },
138
- "waterMark": {
139
- "text": "Built with"
140
- }
141
- },
142
- "Messages": {
143
- "index": {
144
- "running": "Running",
145
- "executedSuccessfully": "executed successfully",
146
- "failed": "failed",
147
- "feedbackUpdated": "Feedback updated",
148
- "updating": "Updating"
149
- }
150
- },
151
- "dropScreen": {
152
- "dropYourFilesHere": "Drop your files here"
153
- },
154
- "index": {
155
- "failedToUpload": "Failed to upload",
156
- "cancelledUploadOf": "Cancelled upload of",
157
- "couldNotReachServer": "Could not reach the server",
158
- "continuingChat": "Continuing previous chat"
159
- },
160
- "settings": {
161
- "settingsPanel": "Settings panel",
162
- "reset": "Reset",
163
- "cancel": "Cancel",
164
- "confirm": "Confirm"
165
- }
166
- },
167
- "threadHistory": {
168
- "sidebar": {
169
- "filters": {
170
- "FeedbackSelect": {
171
- "feedbackAll": "Feedback: All",
172
- "feedbackPositive": "Feedback: Positive",
173
- "feedbackNegative": "Feedback: Negative"
174
- },
175
- "SearchBar": {
176
- "search": "Search"
177
- }
178
- },
179
- "DeleteThreadButton": {
180
- "confirmMessage": "This will delete the thread as well as it's messages and elements.",
181
- "cancel": "Cancel",
182
- "confirm": "Confirm",
183
- "deletingChat": "Deleting chat",
184
- "chatDeleted": "Chat deleted"
185
- },
186
- "index": {
187
- "pastChats": "Past Chats"
188
- },
189
- "ThreadList": {
190
- "empty": "Empty...",
191
- "today": "Today",
192
- "yesterday": "Yesterday",
193
- "previous7days": "Previous 7 days",
194
- "previous30days": "Previous 30 days"
195
- },
196
- "TriggerButton": {
197
- "closeSidebar": "Close sidebar",
198
- "openSidebar": "Open sidebar"
199
- }
200
- },
201
- "Thread": {
202
- "backToChat": "Go back to chat",
203
- "chatCreatedOn": "This chat was created on"
204
- }
205
- },
206
- "header": {
207
- "chat": "Chat",
208
- "readme": "Readme"
209
- }
210
- }
211
- },
212
- "hooks": {
213
- "useLLMProviders": {
214
- "failedToFetchProviders": "Failed to fetch providers:"
215
- }
216
- },
217
- "pages": {
218
- "Design": {},
219
- "Env": {
220
- "savedSuccessfully": "Saved successfully",
221
- "requiredApiKeys": "Required API Keys",
222
- "requiredApiKeysInfo": "To use this app, the following API keys are required. The keys are stored on your device's local storage."
223
- },
224
- "Page": {
225
- "notPartOfProject": "You are not part of this project."
226
- },
227
- "ResumeButton": {
228
- "resumeChat": "Resume Chat"
229
- }
230
- }
231
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.chainlit/translations/pt-BR.json DELETED
@@ -1,155 +0,0 @@
1
- {
2
- "components": {
3
- "atoms": {
4
- "buttons": {
5
- "userButton": {
6
- "menu": {
7
- "settings": "Configura\u00e7\u00f5es",
8
- "settingsKey": "S",
9
- "APIKeys": "Chaves de API",
10
- "logout": "Sair"
11
- }
12
- }
13
- }
14
- },
15
- "molecules": {
16
- "newChatButton": {
17
- "newChat": "Nova Conversa"
18
- },
19
- "tasklist": {
20
- "TaskList": {
21
- "title": "\ud83d\uddd2\ufe0f Lista de Tarefas",
22
- "loading": "Carregando...",
23
- "error": "Ocorreu um erro"
24
- }
25
- },
26
- "attachments": {
27
- "cancelUpload": "Cancelar envio",
28
- "removeAttachment": "Remover anexo"
29
- },
30
- "newChatDialog": {
31
- "createNewChat": "Criar novo chat?",
32
- "clearChat": "Isso limpar\u00e1 as mensagens atuais e iniciar\u00e1 uma nova conversa.",
33
- "cancel": "Cancelar",
34
- "confirm": "Confirmar"
35
- },
36
- "settingsModal": {
37
- "expandMessages": "Expandir Mensagens",
38
- "hideChainOfThought": "Esconder Sequ\u00eancia de Pensamento",
39
- "darkMode": "Modo Escuro"
40
- }
41
- },
42
- "organisms": {
43
- "chat": {
44
- "history": {
45
- "index": {
46
- "lastInputs": "\u00daltimas Entradas",
47
- "noInputs": "Vazio...",
48
- "loading": "Carregando..."
49
- }
50
- },
51
- "inputBox": {
52
- "input": {
53
- "placeholder": "Digite sua mensagem aqui..."
54
- },
55
- "speechButton": {
56
- "start": "Iniciar grava\u00e7\u00e3o",
57
- "stop": "Parar grava\u00e7\u00e3o"
58
- },
59
- "SubmitButton": {
60
- "sendMessage": "Enviar mensagem",
61
- "stopTask": "Parar Tarefa"
62
- },
63
- "UploadButton": {
64
- "attachFiles": "Anexar arquivos"
65
- },
66
- "waterMark": {
67
- "text": "Constru\u00eddo com"
68
- }
69
- },
70
- "Messages": {
71
- "index": {
72
- "running": "Executando",
73
- "executedSuccessfully": "executado com sucesso",
74
- "failed": "falhou",
75
- "feedbackUpdated": "Feedback atualizado",
76
- "updating": "Atualizando"
77
- }
78
- },
79
- "dropScreen": {
80
- "dropYourFilesHere": "Solte seus arquivos aqui"
81
- },
82
- "index": {
83
- "failedToUpload": "Falha ao enviar",
84
- "cancelledUploadOf": "Envio cancelado de",
85
- "couldNotReachServer": "N\u00e3o foi poss\u00edvel conectar ao servidor",
86
- "continuingChat": "Continuando o chat anterior"
87
- },
88
- "settings": {
89
- "settingsPanel": "Painel de Configura\u00e7\u00f5es",
90
- "reset": "Redefinir",
91
- "cancel": "Cancelar",
92
- "confirm": "Confirmar"
93
- }
94
- },
95
- "threadHistory": {
96
- "sidebar": {
97
- "filters": {
98
- "FeedbackSelect": {
99
- "feedbackAll": "Feedback: Todos",
100
- "feedbackPositive": "Feedback: Positivo",
101
- "feedbackNegative": "Feedback: Negativo"
102
- },
103
- "SearchBar": {
104
- "search": "Buscar"
105
- }
106
- },
107
- "DeleteThreadButton": {
108
- "confirmMessage": "Isso deletar\u00e1 a conversa, assim como suas mensagens e elementos.",
109
- "cancel": "Cancelar",
110
- "confirm": "Confirmar",
111
- "deletingChat": "Deletando conversa",
112
- "chatDeleted": "Conversa deletada"
113
- },
114
- "index": {
115
- "pastChats": "Conversas Anteriores"
116
- },
117
- "ThreadList": {
118
- "empty": "Vazio..."
119
- },
120
- "TriggerButton": {
121
- "closeSidebar": "Fechar barra lateral",
122
- "openSidebar": "Abrir barra lateral"
123
- }
124
- },
125
- "Thread": {
126
- "backToChat": "Voltar para a conversa",
127
- "chatCreatedOn": "Esta conversa foi criada em"
128
- }
129
- },
130
- "header": {
131
- "chat": "Conversa",
132
- "readme": "Leia-me"
133
- }
134
- },
135
- "hooks": {
136
- "useLLMProviders": {
137
- "failedToFetchProviders": "Falha ao buscar provedores:"
138
- }
139
- },
140
- "pages": {
141
- "Design": {},
142
- "Env": {
143
- "savedSuccessfully": "Salvo com sucesso",
144
- "requiredApiKeys": "Chaves de API necess\u00e1rias",
145
- "requiredApiKeysInfo": "Para usar este aplicativo, as seguintes chaves de API s\u00e3o necess\u00e1rias. As chaves s\u00e3o armazenadas localmente em seu dispositivo."
146
- },
147
- "Page": {
148
- "notPartOfProject": "Voc\u00ea n\u00e3o faz parte deste projeto."
149
- },
150
- "ResumeButton": {
151
- "resumeChat": "Continuar Conversa"
152
- }
153
- }
154
- }
155
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore CHANGED
@@ -160,4 +160,12 @@ cython_debug/
160
  #.idea/
161
 
162
  # log files
163
- *.log
 
 
 
 
 
 
 
 
 
160
  #.idea/
161
 
162
  # log files
163
+ *.log
164
+
165
+ .ragatouille/*
166
+ */__pycache__/*
167
+ */.chainlit/translations/*
168
+ storage/logs/*
169
+ vectorstores/*
170
+
171
+ */.files/*
Dockerfile.dev CHANGED
@@ -10,7 +10,8 @@ RUN pip install --no-cache-dir -r /code/requirements.txt
10
 
11
  COPY . /code
12
 
13
- RUN ls -R
 
14
 
15
  # Change permissions to allow writing to the directory
16
  RUN chmod -R 777 /code
@@ -21,7 +22,10 @@ RUN mkdir /code/logs && chmod 777 /code/logs
21
  # Create a cache directory within the application's working directory
22
  RUN mkdir /.cache && chmod -R 777 /.cache
23
 
 
 
24
  # Expose the port the app runs on
25
  EXPOSE 8051
26
 
27
- CMD python code/modules/vector_db.py && chainlit run code/main.py --port 8051
 
 
10
 
11
  COPY . /code
12
 
13
+ # List the contents of the /code directory to verify files are copied correctly
14
+ RUN ls -R /code
15
 
16
  # Change permissions to allow writing to the directory
17
  RUN chmod -R 777 /code
 
22
  # Create a cache directory within the application's working directory
23
  RUN mkdir /.cache && chmod -R 777 /.cache
24
 
25
+ WORKDIR /code/code
26
+
27
  # Expose the port the app runs on
28
  EXPOSE 8051
29
 
30
+ # Default command to run the application
31
+ CMD ["sh", "-c", "python -m modules.vectorstore.store_manager && chainlit run main.py --host 0.0.0.0 --port 8051"]
README.md CHANGED
@@ -1,35 +1,84 @@
1
- ---
2
- title: Dl4ds Tutor
3
- emoji: πŸƒ
4
- colorFrom: green
5
- colorTo: red
6
- sdk: docker
7
- pinned: false
8
- hf_oauth: true
9
- ---
10
 
11
- DL4DS Tutor
12
- ===========
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
15
 
16
- You can find an implementation of the Tutor at https://dl4ds-dl4ds-tutor.hf.space/, which is hosted on Hugging Face [here](https://huggingface.co/spaces/dl4ds/dl4ds_tutor)
17
 
18
- To run locally,
 
 
 
19
 
20
- Clone the repository from: https://github.com/DL4DS/dl4ds_tutor
 
 
21
 
22
- Put your data under the `storage/data` directory. Note: You can add urls in the urls.txt file, and other pdf files in the `storage/data` directory.
 
 
 
 
23
 
24
- To create the Vector Database, run the following command:
25
- ```python code/modules/vector_db.py```
26
- (Note: You would need to run the above when you add new data to the `storage/data` directory, or if the ``storage/data/urls.txt`` file is updated. Or you can set ``["embedding_options"]["embedd_files"]`` to True in the `code/config.yaml` file, which would embed files from the storage directory everytime you run the below chainlit command.)
 
 
 
 
27
 
28
- To run the chainlit app, run the following command:
29
- ```chainlit run code/main.py```
 
 
30
 
31
  See the [docs](https://github.com/DL4DS/dl4ds_tutor/tree/main/docs) for more information.
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  ## Contributing
34
 
35
- Please create an issue if you have any suggestions or improvements, and start working on it by creating a branch and by making a pull request to the main branch.
 
1
+ # DL4DS Tutor πŸƒ
 
 
 
 
 
 
 
 
2
 
3
+ Check out the configuration reference at [Hugging Face Spaces Config Reference](https://huggingface.co/docs/hub/spaces-config-reference).
 
4
 
5
+ You can find an implementation of the Tutor at [DL4DS Tutor on Hugging Face](https://dl4ds-dl4ds-tutor.hf.space/), which is hosted on Hugging Face [here](https://huggingface.co/spaces/dl4ds/dl4ds_tutor).
6
 
7
+ ## Running Locally
8
 
9
+ 1. **Clone the Repository**
10
+ ```bash
11
+ git clone https://github.com/DL4DS/dl4ds_tutor
12
+ ```
13
 
14
+ 2. **Put your data under the `storage/data` directory**
15
+ - Add URLs in the `urls.txt` file.
16
+ - Add other PDF files in the `storage/data` directory.
17
 
18
+ 3. **To test Data Loading (Optional)**
19
+ ```bash
20
+ cd code
21
+ python -m modules.dataloader.data_loader
22
+ ```
23
 
24
+ 4. **Create the Vector Database**
25
+ ```bash
26
+ cd code
27
+ python -m modules.vectorstore.store_manager
28
+ ```
29
+ - Note: You need to run the above command when you add new data to the `storage/data` directory, or if the `storage/data/urls.txt` file is updated.
30
+ - Alternatively, you can set `["vectorstore"]["embedd_files"]` to `True` in the `code/modules/config/config.yaml` file, which will embed files from the storage directory every time you run the below chainlit command.
31
 
32
+ 5. **Run the Chainlit App**
33
+ ```bash
34
+ chainlit run main.py
35
+ ```
36
 
37
  See the [docs](https://github.com/DL4DS/dl4ds_tutor/tree/main/docs) for more information.
38
 
39
+ ## File Structure
40
+
41
+ ```plaintext
42
+ code/
43
+ β”œβ”€β”€ modules
44
+ β”‚ β”œβ”€β”€ chat # Contains the chatbot implementation
45
+ β”‚ β”œβ”€β”€ chat_processor # Contains the implementation to process and log the conversations
46
+ β”‚ β”œβ”€β”€ config # Contains the configuration files
47
+ β”‚ β”œβ”€β”€ dataloader # Contains the implementation to load the data from the storage directory
48
+ β”‚ β”œβ”€β”€ retriever # Contains the implementation to create the retriever
49
+ β”‚ └── vectorstore # Contains the implementation to create the vector database
50
+ β”œβ”€β”€ public
51
+ β”‚ β”œβ”€β”€ logo_dark.png # Dark theme logo
52
+ β”‚ β”œβ”€β”€ logo_light.png # Light theme logo
53
+ β”‚ └── test.css # Custom CSS file
54
+ └── main.py
55
+
56
+
57
+ docs/ # Contains the documentation to the codebase and methods used
58
+
59
+ storage/
60
+ β”œβ”€β”€ data # Store files and URLs here
61
+ β”œβ”€β”€ logs # Logs directory, includes logs on vector DB creation, tutor logs, and chunks logged in JSON files
62
+ └── models # Local LLMs are loaded from here
63
+
64
+ vectorstores/ # Stores the created vector databases
65
+
66
+ .env # This needs to be created, store the API keys here
67
+ ```
68
+ - `code/modules/vectorstore/vectorstore.py`: Instantiates the `VectorStore` class to create the vector database.
69
+ - `code/modules/vectorstore/store_manager.py`: Instantiates the `VectorStoreManager:` class to manage the vector database, and all associated methods.
70
+ - `code/modules/retriever/retriever.py`: Instantiates the `Retriever` class to create the retriever.
71
+
72
+
73
+ ## Docker
74
+
75
+ The HuggingFace Space is built using the `Dockerfile` in the repository. To run it locally, use the `Dockerfile.dev` file.
76
+
77
+ ```bash
78
+ docker build --tag dev -f Dockerfile.dev .
79
+ docker run -it --rm -p 8051:8051 dev
80
+ ```
81
+
82
  ## Contributing
83
 
84
+ Please create an issue if you have any suggestions or improvements, and start working on it by creating a branch and by making a pull request to the main branch.
{.chainlit β†’ code/.chainlit}/config.toml RENAMED
@@ -19,9 +19,6 @@ allow_origins = ["*"]
19
  # follow_symlink = false
20
 
21
  [features]
22
- # Show the prompt playground
23
- prompt_playground = true
24
-
25
  # Process and display HTML in messages. This can be a security risk (see https://stackoverflow.com/questions/19603097/why-is-it-dangerous-to-render-user-generated-html-or-javascript)
26
  unsafe_allow_html = false
27
 
@@ -53,26 +50,20 @@ auto_tag_thread = true
53
  sample_rate = 44100
54
 
55
  [UI]
56
- # Name of the app and chatbot.
57
  name = "AI Tutor"
58
 
59
- # Show the readme while the thread is empty.
60
- show_readme_as_default = true
61
-
62
- # Description of the app and chatbot. This is used for HTML tags.
63
- # description = "AI Tutor - DS598"
64
 
65
  # Large size content are by default collapsed for a cleaner ui
66
  default_collapse_content = true
67
 
68
- # The default value for the expand messages settings.
69
- default_expand_messages = false
70
-
71
  # Hide the chain of thought details from the user in the UI.
72
- hide_cot = false
73
 
74
  # Link to your github repo. This will add a github button in the UI's header.
75
- # github = ""
76
 
77
  # Specify a CSS file that can be used to customize the user interface.
78
  # The CSS file can be served from the public directory or via an external link.
@@ -86,7 +77,7 @@ custom_css = "/public/test.css"
86
  # custom_font = "https://fonts.googleapis.com/css2?family=Inter:wght@400;500;700&display=swap"
87
 
88
  # Specify a custom meta image url.
89
- # custom_meta_image_url = "https://chainlit-cloud.s3.eu-west-3.amazonaws.com/logo/chainlit_banner.png"
90
 
91
  # Specify a custom build directory for the frontend.
92
  # This can be used to customize the frontend code.
@@ -94,18 +85,21 @@ custom_css = "/public/test.css"
94
  # custom_build = "./public/build"
95
 
96
  [UI.theme]
 
97
  #layout = "wide"
98
  #font_family = "Inter, sans-serif"
99
  # Override default MUI light theme. (Check theme.ts)
100
  [UI.theme.light]
101
- #background = "#FAFAFA"
102
- #paper = "#FFFFFF"
103
 
104
  [UI.theme.light.primary]
105
- #main = "#F80061"
106
- #dark = "#980039"
107
- #light = "#FFE7EB"
108
-
 
 
109
  # Override default MUI dark theme. (Check theme.ts)
110
  [UI.theme.dark]
111
  background = "#1C1C1C" # Slightly lighter dark background color
@@ -118,4 +112,4 @@ custom_css = "/public/test.css"
118
 
119
 
120
  [meta]
121
- generated_by = "1.1.202"
 
19
  # follow_symlink = false
20
 
21
  [features]
 
 
 
22
  # Process and display HTML in messages. This can be a security risk (see https://stackoverflow.com/questions/19603097/why-is-it-dangerous-to-render-user-generated-html-or-javascript)
23
  unsafe_allow_html = false
24
 
 
50
  sample_rate = 44100
51
 
52
  [UI]
53
+ # Name of the assistant.
54
  name = "AI Tutor"
55
 
56
+ # Description of the assistant. This is used for HTML tags.
57
+ # description = ""
 
 
 
58
 
59
  # Large size content are by default collapsed for a cleaner ui
60
  default_collapse_content = true
61
 
 
 
 
62
  # Hide the chain of thought details from the user in the UI.
63
+ hide_cot = true
64
 
65
  # Link to your github repo. This will add a github button in the UI's header.
66
+ # github = "https://github.com/DL4DS/dl4ds_tutor"
67
 
68
  # Specify a CSS file that can be used to customize the user interface.
69
  # The CSS file can be served from the public directory or via an external link.
 
77
  # custom_font = "https://fonts.googleapis.com/css2?family=Inter:wght@400;500;700&display=swap"
78
 
79
  # Specify a custom meta image url.
80
+ custom_meta_image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/f/f5/Boston_University_seal.svg/1200px-Boston_University_seal.svg.png"
81
 
82
  # Specify a custom build directory for the frontend.
83
  # This can be used to customize the frontend code.
 
85
  # custom_build = "./public/build"
86
 
87
  [UI.theme]
88
+ default = "light"
89
  #layout = "wide"
90
  #font_family = "Inter, sans-serif"
91
  # Override default MUI light theme. (Check theme.ts)
92
  [UI.theme.light]
93
+ background = "#FAFAFA"
94
+ paper = "#FFFFFF"
95
 
96
  [UI.theme.light.primary]
97
+ main = "#b22222" # Brighter shade of red
98
+ dark = "#8b0000" # Darker shade of the brighter red
99
+ light = "#ff6347" # Lighter shade of the brighter red
100
+ [UI.theme.light.text]
101
+ primary = "#212121"
102
+ secondary = "#616161"
103
  # Override default MUI dark theme. (Check theme.ts)
104
  [UI.theme.dark]
105
  background = "#1C1C1C" # Slightly lighter dark background color
 
112
 
113
 
114
  [meta]
115
+ generated_by = "1.1.302"
code/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .modules import *
chainlit.md β†’ code/chainlit.md RENAMED
File without changes
code/main.py CHANGED
@@ -1,9 +1,8 @@
1
- from langchain.document_loaders import PyPDFLoader, DirectoryLoader
2
- from langchain import PromptTemplate
3
- from langchain.embeddings import HuggingFaceEmbeddings
4
- from langchain.vectorstores import FAISS
5
  from langchain.chains import RetrievalQA
6
- from langchain.llms import CTransformers
7
  import chainlit as cl
8
  from langchain_community.chat_models import ChatOpenAI
9
  from langchain_community.embeddings import OpenAIEmbeddings
@@ -11,27 +10,48 @@ import yaml
11
  import logging
12
  from dotenv import load_dotenv
13
 
14
- from modules.llm_tutor import LLMTutor
15
- from modules.constants import *
16
- from modules.helpers import get_sources
17
-
18
 
 
 
19
  logger = logging.getLogger(__name__)
20
  logger.setLevel(logging.INFO)
 
21
 
22
  # Console Handler
23
  console_handler = logging.StreamHandler()
24
  console_handler.setLevel(logging.INFO)
25
- formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
26
  console_handler.setFormatter(formatter)
27
  logger.addHandler(console_handler)
28
 
29
- # File Handler
30
- log_file_path = "log_file.log" # Change this to your desired log file path
31
- file_handler = logging.FileHandler(log_file_path)
32
- file_handler.setLevel(logging.INFO)
33
- file_handler.setFormatter(formatter)
34
- logger.addHandler(file_handler)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
 
37
  # Adding option to select the chat profile
@@ -66,12 +86,26 @@ def rename(orig_author: str):
66
  # chainlit code
67
  @cl.on_chat_start
68
  async def start():
69
- with open("code/config.yml", "r") as f:
70
  config = yaml.safe_load(f)
71
- print(config)
72
- logger.info("Config file loaded")
73
- logger.info(f"Config: {config}")
74
- logger.info("Creating llm_tutor instance")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  chat_profile = cl.user_session.get("chat_profile")
77
  if chat_profile is not None:
@@ -93,36 +127,50 @@ async def start():
93
  llm_tutor = LLMTutor(config, logger=logger)
94
 
95
  chain = llm_tutor.qa_bot()
96
- model = config["llm_params"]["local_llm_params"]["model"]
97
- msg = cl.Message(content=f"Starting the bot {model}...")
98
- await msg.send()
99
- msg.content = opening_message
100
- await msg.update()
101
 
 
 
102
  cl.user_session.set("chain", chain)
 
 
 
 
 
 
 
103
 
104
 
105
  @cl.on_message
106
  async def main(message):
 
107
  user = cl.user_session.get("user")
108
  chain = cl.user_session.get("chain")
109
- # cb = cl.AsyncLangchainCallbackHandler(
110
- # stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
111
- # )
112
- # cb.answer_reached = True
113
- # res=await chain.acall(message, callbacks=[cb])
114
- res = await chain.acall(message.content)
115
- print(f"response: {res}")
 
 
 
 
 
 
 
 
 
116
  try:
117
  answer = res["answer"]
118
  except:
119
  answer = res["result"]
120
- print(f"answer: {answer}")
121
-
122
- logger.info(f"Question: {res['question']}")
123
- logger.info(f"History: {res['chat_history']}")
124
- logger.info(f"Answer: {answer}\n")
125
 
126
- answer_with_sources, source_elements = get_sources(res, answer)
 
127
 
128
  await cl.Message(content=answer_with_sources, elements=source_elements).send()
 
1
+ from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
2
+ from langchain_core.prompts import PromptTemplate
3
+ from langchain_community.embeddings import HuggingFaceEmbeddings
4
+ from langchain_community.vectorstores import FAISS
5
  from langchain.chains import RetrievalQA
 
6
  import chainlit as cl
7
  from langchain_community.chat_models import ChatOpenAI
8
  from langchain_community.embeddings import OpenAIEmbeddings
 
10
  import logging
11
  from dotenv import load_dotenv
12
 
13
+ from modules.chat.llm_tutor import LLMTutor
14
+ from modules.config.constants import *
15
+ from modules.chat.helpers import get_sources
16
+ from modules.chat_processor.chat_processor import ChatProcessor
17
 
18
+ global logger
19
+ # Initialize logger
20
  logger = logging.getLogger(__name__)
21
  logger.setLevel(logging.INFO)
22
+ formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
23
 
24
  # Console Handler
25
  console_handler = logging.StreamHandler()
26
  console_handler.setLevel(logging.INFO)
 
27
  console_handler.setFormatter(formatter)
28
  logger.addHandler(console_handler)
29
 
30
+
31
+ @cl.set_starters
32
+ async def set_starters():
33
+ return [
34
+ cl.Starter(
35
+ label="recording on CNNs?",
36
+ message="Where can I find the recording for the lecture on Transfromers?",
37
+ icon="/public/adv-screen-recorder-svgrepo-com.svg",
38
+ ),
39
+ cl.Starter(
40
+ label="where's the slides?",
41
+ message="When are the lectures? I can't find the schedule.",
42
+ icon="/public/alarmy-svgrepo-com.svg",
43
+ ),
44
+ cl.Starter(
45
+ label="Due Date?",
46
+ message="When is the final project due?",
47
+ icon="/public/calendar-samsung-17-svgrepo-com.svg",
48
+ ),
49
+ cl.Starter(
50
+ label="Explain backprop.",
51
+ message="I didnt understand the math behind backprop, could you explain it?",
52
+ icon="/public/acastusphoton-svgrepo-com.svg",
53
+ ),
54
+ ]
55
 
56
 
57
  # Adding option to select the chat profile
 
86
  # chainlit code
87
  @cl.on_chat_start
88
  async def start():
89
+ with open("modules/config/config.yml", "r") as f:
90
  config = yaml.safe_load(f)
91
+
92
+ # Ensure log directory exists
93
+ log_directory = config["log_dir"]
94
+ if not os.path.exists(log_directory):
95
+ os.makedirs(log_directory)
96
+
97
+ # File Handler
98
+ log_file_path = (
99
+ f"{log_directory}/tutor.log" # Change this to your desired log file path
100
+ )
101
+ file_handler = logging.FileHandler(log_file_path, mode="w")
102
+ file_handler.setLevel(logging.INFO)
103
+ file_handler.setFormatter(formatter)
104
+ logger.addHandler(file_handler)
105
+
106
+ logger.info("Config file loaded")
107
+ logger.info(f"Config: {config}")
108
+ logger.info("Creating llm_tutor instance")
109
 
110
  chat_profile = cl.user_session.get("chat_profile")
111
  if chat_profile is not None:
 
127
  llm_tutor = LLMTutor(config, logger=logger)
128
 
129
  chain = llm_tutor.qa_bot()
130
+ # msg = cl.Message(content=f"Starting the bot {chat_profile}...")
131
+ # await msg.send()
132
+ # msg.content = opening_message
133
+ # await msg.update()
 
134
 
135
+ tags = [chat_profile, config["vectorstore"]["db_option"]]
136
+ chat_processor = ChatProcessor(config, tags=tags)
137
  cl.user_session.set("chain", chain)
138
+ cl.user_session.set("counter", 0)
139
+ cl.user_session.set("chat_processor", chat_processor)
140
+
141
+
142
+ @cl.on_chat_end
143
+ async def on_chat_end():
144
+ await cl.Message(content="Sorry, I have to go now. Goodbye!").send()
145
 
146
 
147
  @cl.on_message
148
  async def main(message):
149
+ global logger
150
  user = cl.user_session.get("user")
151
  chain = cl.user_session.get("chain")
152
+
153
+ counter = cl.user_session.get("counter")
154
+ counter += 1
155
+ cl.user_session.set("counter", counter)
156
+
157
+ # if counter >= 3: # Ensure the counter condition is checked
158
+ # await cl.Message(content="Your credits are up!").send()
159
+ # await on_chat_end() # Call the on_chat_end function to handle the end of the chat
160
+ # return # Exit the function to stop further processing
161
+ # else:
162
+
163
+ cb = cl.AsyncLangchainCallbackHandler() # TODO: fix streaming here
164
+ cb.answer_reached = True
165
+
166
+ processor = cl.user_session.get("chat_processor")
167
+ res = await processor.rag(message.content, chain, cb)
168
  try:
169
  answer = res["answer"]
170
  except:
171
  answer = res["result"]
 
 
 
 
 
172
 
173
+ answer_with_sources, source_elements, sources_dict = get_sources(res, answer)
174
+ processor._process(message.content, answer, sources_dict)
175
 
176
  await cl.Message(content=answer_with_sources, elements=source_elements).send()
code/modules/chat/__init__.py ADDED
File without changes
code/modules/{chat_model_loader.py β†’ chat/chat_model_loader.py} RENAMED
@@ -1,8 +1,7 @@
1
  from langchain_community.chat_models import ChatOpenAI
2
- from langchain.llms import CTransformers
3
- from langchain.llms.huggingface_pipeline import HuggingFacePipeline
4
  from transformers import AutoTokenizer, TextStreamer
5
- from langchain.llms import LlamaCpp
6
  import torch
7
  import transformers
8
  import os
 
1
  from langchain_community.chat_models import ChatOpenAI
2
+ from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
 
3
  from transformers import AutoTokenizer, TextStreamer
4
+ from langchain_community.llms import LlamaCpp
5
  import torch
6
  import transformers
7
  import os
code/modules/chat/helpers.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.config.constants import *
2
+ import chainlit as cl
3
+ from langchain_core.prompts import PromptTemplate
4
+
5
+
6
+ def get_sources(res, answer):
7
+ source_elements = []
8
+ source_dict = {} # Dictionary to store URL elements
9
+
10
+ for idx, source in enumerate(res["source_documents"]):
11
+ source_metadata = source.metadata
12
+ url = source_metadata.get("source", "N/A")
13
+ score = source_metadata.get("score", "N/A")
14
+ page = source_metadata.get("page", 1)
15
+
16
+ lecture_tldr = source_metadata.get("tldr", "N/A")
17
+ lecture_recording = source_metadata.get("lecture_recording", "N/A")
18
+ suggested_readings = source_metadata.get("suggested_readings", "N/A")
19
+ date = source_metadata.get("date", "N/A")
20
+
21
+ source_type = source_metadata.get("source_type", "N/A")
22
+
23
+ url_name = f"{url}_{page}"
24
+ if url_name not in source_dict:
25
+ source_dict[url_name] = {
26
+ "text": source.page_content,
27
+ "url": url,
28
+ "score": score,
29
+ "page": page,
30
+ "lecture_tldr": lecture_tldr,
31
+ "lecture_recording": lecture_recording,
32
+ "suggested_readings": suggested_readings,
33
+ "date": date,
34
+ "source_type": source_type,
35
+ }
36
+ else:
37
+ source_dict[url_name]["text"] += f"\n\n{source.page_content}"
38
+
39
+ # First, display the answer
40
+ full_answer = "**Answer:**\n"
41
+ full_answer += answer
42
+
43
+ # Then, display the sources
44
+ full_answer += "\n\n**Sources:**\n"
45
+ for idx, (url_name, source_data) in enumerate(source_dict.items()):
46
+ full_answer += f"\nSource {idx + 1} (Score: {source_data['score']}): {source_data['url']}\n"
47
+
48
+ name = f"Source {idx + 1} Text\n"
49
+ full_answer += name
50
+ source_elements.append(
51
+ cl.Text(name=name, content=source_data["text"], display="side")
52
+ )
53
+
54
+ # Add a PDF element if the source is a PDF file
55
+ if source_data["url"].lower().endswith(".pdf"):
56
+ name = f"Source {idx + 1} PDF\n"
57
+ full_answer += name
58
+ pdf_url = f"{source_data['url']}#page={source_data['page']+1}"
59
+ source_elements.append(cl.Pdf(name=name, url=pdf_url, display="side"))
60
+
61
+ full_answer += "\n**Metadata:**\n"
62
+ for idx, (url_name, source_data) in enumerate(source_dict.items()):
63
+ full_answer += f"\nSource {idx + 1} Metadata:\n"
64
+ source_elements.append(
65
+ cl.Text(
66
+ name=f"Source {idx + 1} Metadata",
67
+ content=f"Source: {source_data['url']}\n"
68
+ f"Page: {source_data['page']}\n"
69
+ f"Type: {source_data['source_type']}\n"
70
+ f"Date: {source_data['date']}\n"
71
+ f"TL;DR: {source_data['lecture_tldr']}\n"
72
+ f"Lecture Recording: {source_data['lecture_recording']}\n"
73
+ f"Suggested Readings: {source_data['suggested_readings']}\n",
74
+ display="side",
75
+ )
76
+ )
77
+
78
+ return full_answer, source_elements, source_dict
79
+
80
+
81
+ def get_prompt(config):
82
+ if config["llm_params"]["use_history"]:
83
+ if config["llm_params"]["llm_loader"] == "local_llm":
84
+ custom_prompt_template = tinyllama_prompt_template_with_history
85
+ elif config["llm_params"]["llm_loader"] == "openai":
86
+ custom_prompt_template = openai_prompt_template_with_history
87
+ # else:
88
+ # custom_prompt_template = tinyllama_prompt_template_with_history # default
89
+ prompt = PromptTemplate(
90
+ template=custom_prompt_template,
91
+ input_variables=["context", "chat_history", "question"],
92
+ )
93
+ else:
94
+ if config["llm_params"]["llm_loader"] == "local_llm":
95
+ custom_prompt_template = tinyllama_prompt_template
96
+ elif config["llm_params"]["llm_loader"] == "openai":
97
+ custom_prompt_template = openai_prompt_template
98
+ # else:
99
+ # custom_prompt_template = tinyllama_prompt_template
100
+ prompt = PromptTemplate(
101
+ template=custom_prompt_template,
102
+ input_variables=["context", "question"],
103
+ )
104
+ return prompt
code/modules/{llm_tutor.py β†’ chat/llm_tutor.py} RENAMED
@@ -1,24 +1,52 @@
1
- from langchain import PromptTemplate
2
- from langchain.embeddings import HuggingFaceEmbeddings
3
- from langchain_community.chat_models import ChatOpenAI
4
- from langchain_community.embeddings import OpenAIEmbeddings
5
- from langchain.vectorstores import FAISS
6
  from langchain.chains import RetrievalQA, ConversationalRetrievalChain
7
- from langchain.llms import CTransformers
8
- from langchain.memory import ConversationBufferWindowMemory, ConversationSummaryBufferMemory
 
 
9
  from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
10
  import os
11
- from modules.constants import *
12
- from modules.helpers import get_prompt
13
- from modules.chat_model_loader import ChatModelLoader
14
- from modules.vector_db import VectorDB, VectorDBScore
15
- from typing import Dict, Any, Optional
 
 
 
16
  from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
17
  import inspect
18
  from langchain.chains.conversational_retrieval.base import _get_chat_history
 
 
 
 
 
 
 
19
 
20
 
21
  class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  async def _acall(
23
  self,
24
  inputs: Dict[str, Any],
@@ -26,13 +54,34 @@ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
26
  ) -> Dict[str, Any]:
27
  _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
28
  question = inputs["question"]
29
- get_chat_history = self.get_chat_history or _get_chat_history
30
  chat_history_str = get_chat_history(inputs["chat_history"])
31
- print(f"chat_history_str: {chat_history_str}")
32
  if chat_history_str:
33
- callbacks = _run_manager.get_child()
34
- new_question = await self.question_generator.arun(
35
- question=question, chat_history=chat_history_str, callbacks=callbacks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  )
37
  else:
38
  new_question = question
@@ -45,6 +94,7 @@ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
45
  docs = await self._aget_docs(new_question, inputs) # type: ignore[call-arg]
46
 
47
  output: Dict[str, Any] = {}
 
48
  if self.response_if_no_docs_found is not None and len(docs) == 0:
49
  output[self.output_key] = self.response_if_no_docs_found
50
  else:
@@ -56,31 +106,25 @@ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
56
  # Prepare the final prompt with metadata
57
  context = "\n\n".join(
58
  [
59
- f"Document content: {doc.page_content}\nMetadata: {doc.metadata}"
60
- for doc in docs
61
  ]
62
  )
63
- final_prompt = f"""
64
- You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Use the following pieces of information to answer the user's question.
65
- If you don't know the answer, just say that you don't knowβ€”don't try to make up an answer.
66
- Use the chat history to answer the question only if it's relevant; otherwise, ignore it. The context for the answer will be under "Document context:".
67
- Use the metadata from each document to guide the user to the correct sources.
68
- The context is ordered by relevance to the question. Give more weight to the most relevant documents.
69
- Talk in a friendly and personalized manner, similar to how you would speak to a friend who needs help. Make the conversation engaging and avoid sounding repetitive or robotic.
70
-
71
- Chat History:
72
- {chat_history_str}
73
-
74
- Context:
75
- {context}
76
-
77
- Question: {new_question}
78
- AI Tutor:
79
- """
80
 
81
- new_inputs["input"] = final_prompt
82
  new_inputs["question"] = final_prompt
83
- output["final_prompt"] = final_prompt
84
 
85
  answer = await self.combine_docs_chain.arun(
86
  input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
@@ -89,8 +133,7 @@ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
89
 
90
  if self.return_source_documents:
91
  output["source_documents"] = docs
92
- if self.return_generated_question:
93
- output["generated_question"] = new_question
94
  return output
95
 
96
 
@@ -98,8 +141,9 @@ class LLMTutor:
98
  def __init__(self, config, logger=None):
99
  self.config = config
100
  self.llm = self.load_llm()
101
- self.vector_db = VectorDB(config, logger=logger)
102
- if self.config["embedding_options"]["embedd_files"]:
 
103
  self.vector_db.create_database()
104
  self.vector_db.save_database()
105
 
@@ -114,24 +158,11 @@ class LLMTutor:
114
 
115
  # Retrieval QA Chain
116
  def retrieval_qa_chain(self, llm, prompt, db):
117
- if self.config["embedding_options"]["db_option"] in ["FAISS", "Chroma"]:
118
- retriever = VectorDBScore(
119
- vectorstore=db,
120
- # search_type="similarity_score_threshold",
121
- # search_kwargs={
122
- # "score_threshold": self.config["embedding_options"][
123
- # "score_threshold"
124
- # ],
125
- # "k": self.config["embedding_options"]["search_top_k"],
126
- # },
127
- )
128
- elif self.config["embedding_options"]["db_option"] == "RAGatouille":
129
- retriever = db.as_langchain_retriever(
130
- k=self.config["embedding_options"]["search_top_k"]
131
- )
132
  if self.config["llm_params"]["use_history"]:
133
- memory = ConversationSummaryBufferMemory(
134
- llm = llm,
135
  k=self.config["llm_params"]["memory_window"],
136
  memory_key="chat_history",
137
  return_messages=True,
@@ -145,6 +176,7 @@ class LLMTutor:
145
  return_source_documents=True,
146
  memory=memory,
147
  combine_docs_chain_kwargs={"prompt": prompt},
 
148
  )
149
  else:
150
  qa_chain = RetrievalQA.from_chain_type(
@@ -166,7 +198,9 @@ class LLMTutor:
166
  def qa_bot(self):
167
  db = self.vector_db.load_database()
168
  qa_prompt = self.set_custom_prompt()
169
- qa = self.retrieval_qa_chain(self.llm, qa_prompt, db)
 
 
170
 
171
  return qa
172
 
 
 
 
 
 
 
1
  from langchain.chains import RetrievalQA, ConversationalRetrievalChain
2
+ from langchain.memory import (
3
+ ConversationBufferWindowMemory,
4
+ ConversationSummaryBufferMemory,
5
+ )
6
  from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
7
  import os
8
+ from modules.config.constants import *
9
+ from modules.chat.helpers import get_prompt
10
+ from modules.chat.chat_model_loader import ChatModelLoader
11
+ from modules.vectorstore.store_manager import VectorStoreManager
12
+
13
+ from modules.retriever.retriever import Retriever
14
+
15
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
16
  from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
17
  import inspect
18
  from langchain.chains.conversational_retrieval.base import _get_chat_history
19
+ from langchain_core.messages import BaseMessage
20
+
21
+ CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]
22
+
23
+ from langchain_core.output_parsers import StrOutputParser
24
+ from langchain_core.prompts import ChatPromptTemplate
25
+ from langchain_community.chat_models import ChatOpenAI
26
 
27
 
28
  class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
29
+
30
+ def _get_chat_history(self, chat_history: List[CHAT_TURN_TYPE]) -> str:
31
+ _ROLE_MAP = {"human": "Student: ", "ai": "AI Tutor: "}
32
+ buffer = ""
33
+ for dialogue_turn in chat_history:
34
+ if isinstance(dialogue_turn, BaseMessage):
35
+ role_prefix = _ROLE_MAP.get(
36
+ dialogue_turn.type, f"{dialogue_turn.type}: "
37
+ )
38
+ buffer += f"\n{role_prefix}{dialogue_turn.content}"
39
+ elif isinstance(dialogue_turn, tuple):
40
+ human = "Student: " + dialogue_turn[0]
41
+ ai = "AI Tutor: " + dialogue_turn[1]
42
+ buffer += "\n" + "\n".join([human, ai])
43
+ else:
44
+ raise ValueError(
45
+ f"Unsupported chat history format: {type(dialogue_turn)}."
46
+ f" Full chat history: {chat_history} "
47
+ )
48
+ return buffer
49
+
50
  async def _acall(
51
  self,
52
  inputs: Dict[str, Any],
 
54
  ) -> Dict[str, Any]:
55
  _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
56
  question = inputs["question"]
57
+ get_chat_history = self._get_chat_history
58
  chat_history_str = get_chat_history(inputs["chat_history"])
 
59
  if chat_history_str:
60
+ # callbacks = _run_manager.get_child()
61
+ # new_question = await self.question_generator.arun(
62
+ # question=question, chat_history=chat_history_str, callbacks=callbacks
63
+ # )
64
+ system = (
65
+ "You are someone that rephrases statements. Rephrase the student's question to add context from their chat history if relevant, ensuring it remains from the student's point of view. "
66
+ "Incorporate relevant details from the chat history to make the question clearer and more specific."
67
+ "Do not change the meaning of the original statement, and maintain the student's tone and perspective. "
68
+ "If the question is conversational and doesn't require context, do not rephrase it. "
69
+ "Example: If the student previously asked about backpropagation in the context of deep learning and now asks 'what is it', rephrase to 'What is backprogatation.'. "
70
+ "Example: Do not rephrase if the user is asking something specific like 'cool, suggest a project with transformers to use as my final project'"
71
+ "Chat history: \n{chat_history_str}\n"
72
+ "Rephrase the following question only if necessary: '{question}'"
73
+ )
74
+
75
+ prompt = ChatPromptTemplate.from_messages(
76
+ [
77
+ ("system", system),
78
+ ("human", "{question}, {chat_history_str}"),
79
+ ]
80
+ )
81
+ llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
82
+ step_back = prompt | llm | StrOutputParser()
83
+ new_question = step_back.invoke(
84
+ {"question": question, "chat_history_str": chat_history_str}
85
  )
86
  else:
87
  new_question = question
 
94
  docs = await self._aget_docs(new_question, inputs) # type: ignore[call-arg]
95
 
96
  output: Dict[str, Any] = {}
97
+ output["original_question"] = question
98
  if self.response_if_no_docs_found is not None and len(docs) == 0:
99
  output[self.output_key] = self.response_if_no_docs_found
100
  else:
 
106
  # Prepare the final prompt with metadata
107
  context = "\n\n".join(
108
  [
109
+ f"Context {idx+1}: \n(Document content: {doc.page_content}\nMetadata: (source_file: {doc.metadata['source'] if 'source' in doc.metadata else 'unknown'}))"
110
+ for idx, doc in enumerate(docs)
111
  ]
112
  )
113
+ final_prompt = (
114
+ "You are an AI Tutor for the course DS598, taught by Prof. Thomas Gardos. Answer the user's question using the provided context. Only use the context if it is relevant. The context is ordered by relevance."
115
+ "If you don't know the answer, do your best without making things up. Keep the conversation flowing naturally. "
116
+ "Use chat history and context as guides but avoid repeating past responses. Provide links from the source_file metadata. Use the source context that is most relevent."
117
+ "Speak in a friendly and engaging manner, like talking to a friend. Avoid sounding repetitive or robotic.\n\n"
118
+ f"Chat History:\n{chat_history_str}\n\n"
119
+ f"Context:\n{context}\n\n"
120
+ "Answer the student's question below in a friendly, concise, and engaging manner. Use the context and history only if relevant, otherwise, engage in a free-flowing conversation.\n"
121
+ f"Student: {question}\n"
122
+ "AI Tutor:"
123
+ )
 
 
 
 
 
 
124
 
125
+ # new_inputs["input"] = final_prompt
126
  new_inputs["question"] = final_prompt
127
+ # output["final_prompt"] = final_prompt
128
 
129
  answer = await self.combine_docs_chain.arun(
130
  input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
 
133
 
134
  if self.return_source_documents:
135
  output["source_documents"] = docs
136
+ output["rephrased_question"] = new_question
 
137
  return output
138
 
139
 
 
141
  def __init__(self, config, logger=None):
142
  self.config = config
143
  self.llm = self.load_llm()
144
+ self.logger = logger
145
+ self.vector_db = VectorStoreManager(config, logger=self.logger)
146
+ if self.config["vectorstore"]["embedd_files"]:
147
  self.vector_db.create_database()
148
  self.vector_db.save_database()
149
 
 
158
 
159
  # Retrieval QA Chain
160
  def retrieval_qa_chain(self, llm, prompt, db):
161
+
162
+ retriever = Retriever(self.config)._return_retriever(db)
163
+
 
 
 
 
 
 
 
 
 
 
 
 
164
  if self.config["llm_params"]["use_history"]:
165
+ memory = ConversationBufferWindowMemory(
 
166
  k=self.config["llm_params"]["memory_window"],
167
  memory_key="chat_history",
168
  return_messages=True,
 
176
  return_source_documents=True,
177
  memory=memory,
178
  combine_docs_chain_kwargs={"prompt": prompt},
179
+ response_if_no_docs_found="No context found",
180
  )
181
  else:
182
  qa_chain = RetrievalQA.from_chain_type(
 
198
  def qa_bot(self):
199
  db = self.vector_db.load_database()
200
  qa_prompt = self.set_custom_prompt()
201
+ qa = self.retrieval_qa_chain(
202
+ self.llm, qa_prompt, db
203
+ ) # TODO: PROMPT is overwritten in CustomConversationalRetrievalChain
204
 
205
  return qa
206
 
code/modules/chat_processor/__init__.py ADDED
File without changes
code/modules/chat_processor/base.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Template for chat processor classes
2
+
3
+
4
+ class ChatProcessorBase:
5
+ def __init__(self, config):
6
+ self.config = config
7
+
8
+ def process(self, message):
9
+ """
10
+ Processes and Logs the message
11
+ """
12
+ raise NotImplementedError("process method not implemented")
code/modules/chat_processor/chat_processor.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.chat_processor.literal_ai import LiteralaiChatProcessor
2
+
3
+
4
+ class ChatProcessor:
5
+ def __init__(self, config, tags=None):
6
+ self.chat_processor_type = config["chat_logging"]["platform"]
7
+ self.logging = config["chat_logging"]["log_chat"]
8
+ self.tags = tags
9
+ self._init_processor()
10
+
11
+ def _init_processor(self):
12
+ if self.chat_processor_type == "literalai":
13
+ self.processor = LiteralaiChatProcessor(self.tags)
14
+ else:
15
+ raise ValueError(
16
+ f"Chat processor type {self.chat_processor_type} not supported"
17
+ )
18
+
19
+ def _process(self, user_message, assistant_message, source_dict):
20
+ if self.logging:
21
+ return self.processor.process(user_message, assistant_message, source_dict)
22
+ else:
23
+ pass
24
+
25
+ async def rag(self, user_query: str, chain, cb):
26
+ if self.logging:
27
+ return await self.processor.rag(user_query, chain, cb)
28
+ else:
29
+ return await chain.acall(user_query, callbacks=[cb])
code/modules/chat_processor/literal_ai.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from literalai import LiteralClient
2
+ import os
3
+ from .base import ChatProcessorBase
4
+
5
+
6
+ class LiteralaiChatProcessor(ChatProcessorBase):
7
+ def __init__(self, tags=None):
8
+ self.literal_client = LiteralClient(api_key=os.getenv("LITERAL_API_KEY"))
9
+ self.literal_client.reset_context()
10
+ with self.literal_client.thread(name="TEST") as thread:
11
+ self.thread_id = thread.id
12
+ self.thread = thread
13
+ if tags is not None and type(tags) == list:
14
+ self.thread.tags = tags
15
+ print(f"Thread ID: {self.thread}")
16
+
17
+ def process(self, user_message, assistant_message, source_dict):
18
+ with self.literal_client.thread(thread_id=self.thread_id) as thread:
19
+ self.literal_client.message(
20
+ content=user_message,
21
+ type="user_message",
22
+ name="User",
23
+ )
24
+ self.literal_client.message(
25
+ content=assistant_message,
26
+ type="assistant_message",
27
+ name="AI_Tutor",
28
+ )
29
+
30
+ async def rag(self, user_query: str, chain, cb):
31
+ with self.literal_client.step(
32
+ type="retrieval", name="RAG", thread_id=self.thread_id
33
+ ) as step:
34
+ step.input = {"question": user_query}
35
+ res = await chain.acall(user_query, callbacks=[cb])
36
+ step.output = res
37
+ return res
code/modules/config/__init__.py ADDED
File without changes
code/{config.yml β†’ modules/config/config.yml} RENAMED
@@ -1,13 +1,28 @@
1
- embedding_options:
 
 
 
 
2
  embedd_files: False # bool
3
- data_path: 'storage/data' # str
4
- url_file_path: 'storage/data/urls.txt' # str
5
  expand_urls: True # bool
6
- db_option : 'RAGatouille' # str [FAISS, Chroma, RAGatouille]
7
- db_path : 'vectorstores' # str
8
  model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002']
9
  search_top_k : 3 # int
10
  score_threshold : 0.2 # float
 
 
 
 
 
 
 
 
 
 
 
11
  llm_params:
12
  use_history: True # bool
13
  memory_window: 3 # int
@@ -15,9 +30,13 @@ llm_params:
15
  openai_params:
16
  model: 'gpt-3.5-turbo-1106' # str [gpt-3.5-turbo-1106, gpt-4]
17
  local_llm_params:
18
- model: "storage/models/llama-2-7b-chat.Q4_0.gguf"
19
- model_type: "llama"
20
- temperature: 0.2
 
 
 
 
21
  splitter_options:
22
  use_splitter: True # bool
23
  split_by_token : True # bool
 
1
+ log_dir: '../storage/logs' # str
2
+ log_chunk_dir: '../storage/logs/chunks' # str
3
+ device: 'cpu' # str [cuda, cpu]
4
+
5
+ vectorstore:
6
  embedd_files: False # bool
7
+ data_path: '../storage/data' # str
8
+ url_file_path: '../storage/data/urls.txt' # str
9
  expand_urls: True # bool
10
+ db_option : 'FAISS' # str [FAISS, Chroma, RAGatouille, RAPTOR]
11
+ db_path : '../vectorstores' # str
12
  model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002']
13
  search_top_k : 3 # int
14
  score_threshold : 0.2 # float
15
+
16
+ faiss_params: # Not used as of now
17
+ index_path: '../vectorstores/faiss.index' # str
18
+ index_type: 'Flat' # str [Flat, HNSW, IVF]
19
+ index_dimension: 384 # int
20
+ index_nlist: 100 # int
21
+ index_nprobe: 10 # int
22
+
23
+ colbert_params:
24
+ index_name: "new_idx" # str
25
+
26
  llm_params:
27
  use_history: True # bool
28
  memory_window: 3 # int
 
30
  openai_params:
31
  model: 'gpt-3.5-turbo-1106' # str [gpt-3.5-turbo-1106, gpt-4]
32
  local_llm_params:
33
+ model: 'tiny-llama'
34
+ temperature: 0.7
35
+
36
+ chat_logging:
37
+ log_chat: False # bool
38
+ platform: 'literalai'
39
+
40
  splitter_options:
41
  use_splitter: True # bool
42
  split_by_token : True # bool
code/modules/{constants.py β†’ config/constants.py} RENAMED
@@ -7,6 +7,7 @@ load_dotenv()
7
 
8
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
9
  HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
 
10
 
11
  opening_message = f"Hey, What Can I Help You With?\n\nYou can me ask me questions about the course logistics, course content, about the final project, or anything else!"
12
 
@@ -77,5 +78,5 @@ Question: {question}
77
 
78
  # Model Paths
79
 
80
- LLAMA_PATH = "storage/models/tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf"
81
  MISTRAL_PATH = "storage/models/mistral-7b-v0.1.Q4_K_M.gguf"
 
7
 
8
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
9
  HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
10
+ LITERAL_API_KEY = os.getenv("LITERAL_API_KEY")
11
 
12
  opening_message = f"Hey, What Can I Help You With?\n\nYou can me ask me questions about the course logistics, course content, about the final project, or anything else!"
13
 
 
78
 
79
  # Model Paths
80
 
81
+ LLAMA_PATH = "../storage/models/tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf"
82
  MISTRAL_PATH = "storage/models/mistral-7b-v0.1.Q4_K_M.gguf"
code/modules/dataloader/__init__.py ADDED
File without changes
code/modules/{data_loader.py β†’ dataloader/data_loader.py} RENAMED
@@ -16,15 +16,12 @@ import logging
16
  from langchain.text_splitter import RecursiveCharacterTextSplitter
17
  from ragatouille import RAGPretrainedModel
18
  from langchain.chains import LLMChain
19
- from langchain.llms import OpenAI
20
  from langchain import PromptTemplate
 
 
21
 
22
- try:
23
- from modules.helpers import get_metadata
24
- except:
25
- from helpers import get_metadata
26
-
27
- logger = logging.getLogger(__name__)
28
 
29
 
30
  class PDFReader:
@@ -40,8 +37,9 @@ class PDFReader:
40
 
41
 
42
  class FileReader:
43
- def __init__(self):
44
  self.pdf_reader = PDFReader()
 
45
 
46
  def extract_text_from_pdf(self, pdf_path):
47
  text = ""
@@ -61,7 +59,7 @@ class FileReader:
61
  temp_file_path = temp_file.name
62
  return temp_file_path
63
  else:
64
- print("Failed to download PDF from URL:", pdf_url)
65
  return None
66
 
67
  def read_pdf(self, temp_file_path: str):
@@ -99,13 +97,18 @@ class FileReader:
99
  if response.status_code == 200:
100
  return [Document(page_content=response.text)]
101
  else:
102
- print("Failed to fetch .tex file from URL:", tex_url)
103
  return None
104
 
105
 
106
  class ChunkProcessor:
107
- def __init__(self, config):
108
  self.config = config
 
 
 
 
 
109
 
110
  if config["splitter_options"]["use_splitter"]:
111
  if config["splitter_options"]["split_by_token"]:
@@ -124,7 +127,7 @@ class ChunkProcessor:
124
  )
125
  else:
126
  self.splitter = None
127
- logger.info("ChunkProcessor instance created")
128
 
129
  def remove_delimiters(self, document_chunks: list):
130
  for chunk in document_chunks:
@@ -139,7 +142,6 @@ class ChunkProcessor:
139
  del document_chunks[0]
140
  for _ in range(end):
141
  document_chunks.pop()
142
- logger.info(f"\tNumber of pages after skipping: {len(document_chunks)}")
143
  return document_chunks
144
 
145
  def process_chunks(
@@ -172,122 +174,187 @@ class ChunkProcessor:
172
 
173
  return document_chunks
174
 
175
- def get_chunks(self, file_reader, uploaded_files, weblinks):
176
- self.document_chunks_full = []
177
- self.parent_document_names = []
178
- self.child_document_names = []
179
- self.documents = []
180
- self.document_metadata = []
181
-
182
  addl_metadata = get_metadata(
183
  "https://dl4ds.github.io/sp2024/lectures/",
184
  "https://dl4ds.github.io/sp2024/schedule/",
185
  ) # For any additional metadata
186
 
187
- for file_index, file_path in enumerate(uploaded_files):
188
- file_name = os.path.basename(file_path)
189
- if file_name not in self.parent_document_names:
190
- file_type = file_name.split(".")[-1].lower()
191
-
192
- # try:
193
- if file_type == "pdf":
194
- documents = file_reader.read_pdf(file_path)
195
- elif file_type == "txt":
196
- documents = file_reader.read_txt(file_path)
197
- elif file_type == "docx":
198
- documents = file_reader.read_docx(file_path)
199
- elif file_type == "srt":
200
- documents = file_reader.read_srt(file_path)
201
- elif file_type == "tex":
202
- documents = file_reader.read_tex_from_url(file_path)
203
- else:
204
- logger.warning(f"Unsupported file type: {file_type}")
205
- continue
206
-
207
- for doc in documents:
208
- page_num = doc.metadata.get("page", 0)
209
- self.documents.append(doc.page_content)
210
- self.document_metadata.append(
211
- {"source": file_path, "page": page_num}
212
- )
213
- metadata = addl_metadata.get(file_path, {})
214
- self.document_metadata[-1].update(metadata)
215
-
216
- self.child_document_names.append(f"{file_name}_{page_num}")
217
-
218
- self.parent_document_names.append(file_name)
219
- if self.config["embedding_options"]["db_option"] not in [
220
- "RAGatouille"
221
- ]:
222
- document_chunks = self.process_chunks(
223
- self.documents[-1],
224
- file_type,
225
- source=file_path,
226
- page=page_num,
227
- metadata=metadata,
228
- )
229
- self.document_chunks_full.extend(document_chunks)
230
-
231
- # except Exception as e:
232
- # logger.error(f"Error processing file {file_name}: {str(e)}")
233
-
234
- self.process_weblinks(file_reader, weblinks)
235
-
236
- logger.info(
237
  f"Total document chunks extracted: {len(self.document_chunks_full)}"
238
  )
239
- return (
240
- self.document_chunks_full,
241
- self.child_document_names,
242
- self.documents,
243
- self.document_metadata,
244
- )
245
 
246
- def process_weblinks(self, file_reader, weblinks):
247
- if weblinks[0] != "":
248
- logger.info(f"Splitting weblinks: total of {len(weblinks)}")
249
-
250
- for link_index, link in enumerate(weblinks):
251
- if link not in self.parent_document_names:
252
- try:
253
- logger.info(f"\tSplitting link {link_index+1} : {link}")
254
- if "youtube" in link:
255
- documents = file_reader.read_youtube_transcript(link)
256
- else:
257
- documents = file_reader.read_html(link)
258
-
259
- for doc in documents:
260
- page_num = doc.metadata.get("page", 0)
261
- self.documents.append(doc.page_content)
262
- self.document_metadata.append(
263
- {"source": link, "page": page_num}
264
- )
265
- self.child_document_names.append(f"{link}")
266
-
267
- self.parent_document_names.append(link)
268
- if self.config["embedding_options"]["db_option"] not in [
269
- "RAGatouille"
270
- ]:
271
- document_chunks = self.process_chunks(
272
- self.documents[-1],
273
- "txt",
274
- source=link,
275
- page=0,
276
- metadata={"source_type": "webpage"},
277
- )
278
- self.document_chunks_full.extend(document_chunks)
279
- except Exception as e:
280
- logger.error(
281
- f"Error splitting link {link_index+1} : {link}: {str(e)}"
282
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
 
285
  class DataLoader:
286
- def __init__(self, config):
287
- self.file_reader = FileReader()
288
- self.chunk_processor = ChunkProcessor(config)
289
 
290
  def get_chunks(self, uploaded_files, weblinks):
291
- return self.chunk_processor.get_chunks(
292
  self.file_reader, uploaded_files, weblinks
293
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  from langchain.text_splitter import RecursiveCharacterTextSplitter
17
  from ragatouille import RAGPretrainedModel
18
  from langchain.chains import LLMChain
19
+ from langchain_community.llms import OpenAI
20
  from langchain import PromptTemplate
21
+ import json
22
+ from concurrent.futures import ThreadPoolExecutor
23
 
24
+ from modules.dataloader.helpers import get_metadata
 
 
 
 
 
25
 
26
 
27
  class PDFReader:
 
37
 
38
 
39
  class FileReader:
40
+ def __init__(self, logger):
41
  self.pdf_reader = PDFReader()
42
+ self.logger = logger
43
 
44
  def extract_text_from_pdf(self, pdf_path):
45
  text = ""
 
59
  temp_file_path = temp_file.name
60
  return temp_file_path
61
  else:
62
+ self.logger.error(f"Failed to download PDF from URL: {pdf_url}")
63
  return None
64
 
65
  def read_pdf(self, temp_file_path: str):
 
97
  if response.status_code == 200:
98
  return [Document(page_content=response.text)]
99
  else:
100
+ self.logger.error(f"Failed to fetch .tex file from URL: {tex_url}")
101
  return None
102
 
103
 
104
  class ChunkProcessor:
105
+ def __init__(self, config, logger):
106
  self.config = config
107
+ self.logger = logger
108
+
109
+ self.document_data = {}
110
+ self.document_metadata = {}
111
+ self.document_chunks_full = []
112
 
113
  if config["splitter_options"]["use_splitter"]:
114
  if config["splitter_options"]["split_by_token"]:
 
127
  )
128
  else:
129
  self.splitter = None
130
+ self.logger.info("ChunkProcessor instance created")
131
 
132
  def remove_delimiters(self, document_chunks: list):
133
  for chunk in document_chunks:
 
142
  del document_chunks[0]
143
  for _ in range(end):
144
  document_chunks.pop()
 
145
  return document_chunks
146
 
147
  def process_chunks(
 
174
 
175
  return document_chunks
176
 
177
+ def chunk_docs(self, file_reader, uploaded_files, weblinks):
 
 
 
 
 
 
178
  addl_metadata = get_metadata(
179
  "https://dl4ds.github.io/sp2024/lectures/",
180
  "https://dl4ds.github.io/sp2024/schedule/",
181
  ) # For any additional metadata
182
 
183
+ with ThreadPoolExecutor() as executor:
184
+ executor.map(
185
+ self.process_file,
186
+ uploaded_files,
187
+ range(len(uploaded_files)),
188
+ [file_reader] * len(uploaded_files),
189
+ [addl_metadata] * len(uploaded_files),
190
+ )
191
+ executor.map(
192
+ self.process_weblink,
193
+ weblinks,
194
+ range(len(weblinks)),
195
+ [file_reader] * len(weblinks),
196
+ [addl_metadata] * len(weblinks),
197
+ )
198
+
199
+ document_names = [
200
+ f"{file_name}_{page_num}"
201
+ for file_name, pages in self.document_data.items()
202
+ for page_num in pages.keys()
203
+ ]
204
+ documents = [
205
+ page for doc in self.document_data.values() for page in doc.values()
206
+ ]
207
+ document_metadata = [
208
+ page for doc in self.document_metadata.values() for page in doc.values()
209
+ ]
210
+
211
+ self.save_document_data()
212
+
213
+ self.logger.info(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  f"Total document chunks extracted: {len(self.document_chunks_full)}"
215
  )
 
 
 
 
 
 
216
 
217
+ return self.document_chunks_full, document_names, documents, document_metadata
218
+
219
+ def process_documents(
220
+ self, documents, file_path, file_type, metadata_source, addl_metadata
221
+ ):
222
+ file_data = {}
223
+ file_metadata = {}
224
+
225
+ for doc in documents:
226
+ # if len(doc.page_content) <= 400: # better approach to filter out non-informative documents
227
+ # continue
228
+
229
+ page_num = doc.metadata.get("page", 0)
230
+ file_data[page_num] = doc.page_content
231
+ metadata = (
232
+ addl_metadata.get(file_path, {})
233
+ if metadata_source == "file"
234
+ else {"source": file_path, "page": page_num}
235
+ )
236
+ file_metadata[page_num] = metadata
237
+
238
+ if self.config["vectorstore"]["db_option"] not in ["RAGatouille"]:
239
+ document_chunks = self.process_chunks(
240
+ doc.page_content,
241
+ file_type,
242
+ source=file_path,
243
+ page=page_num,
244
+ metadata=metadata,
245
+ )
246
+ self.document_chunks_full.extend(document_chunks)
247
+
248
+ self.document_data[file_path] = file_data
249
+ self.document_metadata[file_path] = file_metadata
250
+
251
+ def process_file(self, file_path, file_index, file_reader, addl_metadata):
252
+ file_name = os.path.basename(file_path)
253
+ if file_name in self.document_data:
254
+ return
255
+
256
+ file_type = file_name.split(".")[-1].lower()
257
+ self.logger.info(f"Reading file {file_index + 1}: {file_path}")
258
+
259
+ read_methods = {
260
+ "pdf": file_reader.read_pdf,
261
+ "txt": file_reader.read_txt,
262
+ "docx": file_reader.read_docx,
263
+ "srt": file_reader.read_srt,
264
+ "tex": file_reader.read_tex_from_url,
265
+ }
266
+ if file_type not in read_methods:
267
+ self.logger.warning(f"Unsupported file type: {file_type}")
268
+ return
269
+
270
+ try:
271
+ documents = read_methods[file_type](file_path)
272
+ self.process_documents(
273
+ documents, file_path, file_type, "file", addl_metadata
274
+ )
275
+ except Exception as e:
276
+ self.logger.error(f"Error processing file {file_name}: {str(e)}")
277
+
278
+ def process_weblink(self, link, link_index, file_reader, addl_metadata):
279
+ if link in self.document_data:
280
+ return
281
+
282
+ self.logger.info(f"Reading link {link_index + 1} : {link}")
283
+
284
+ try:
285
+ if "youtube" in link:
286
+ documents = file_reader.read_youtube_transcript(link)
287
+ else:
288
+ documents = file_reader.read_html(link)
289
+
290
+ self.process_documents(documents, link, "txt", "link", addl_metadata)
291
+ except Exception as e:
292
+ self.logger.error(f"Error Reading link {link_index + 1} : {link}: {str(e)}")
293
+
294
+ def save_document_data(self):
295
+ if not os.path.exists(f"{self.config['log_chunk_dir']}/docs"):
296
+ os.makedirs(f"{self.config['log_chunk_dir']}/docs")
297
+ self.logger.info(
298
+ f"Creating directory {self.config['log_chunk_dir']}/docs for document data"
299
+ )
300
+ self.logger.info(
301
+ f"Saving document content to {self.config['log_chunk_dir']}/docs/doc_content.json"
302
+ )
303
+ if not os.path.exists(f"{self.config['log_chunk_dir']}/metadata"):
304
+ os.makedirs(f"{self.config['log_chunk_dir']}/metadata")
305
+ self.logger.info(
306
+ f"Creating directory {self.config['log_chunk_dir']}/metadata for document metadata"
307
+ )
308
+ self.logger.info(
309
+ f"Saving document metadata to {self.config['log_chunk_dir']}/metadata/doc_metadata.json"
310
+ )
311
+ with open(
312
+ f"{self.config['log_chunk_dir']}/docs/doc_content.json", "w"
313
+ ) as json_file:
314
+ json.dump(self.document_data, json_file, indent=4)
315
+ with open(
316
+ f"{self.config['log_chunk_dir']}/metadata/doc_metadata.json", "w"
317
+ ) as json_file:
318
+ json.dump(self.document_metadata, json_file, indent=4)
319
+
320
+ def load_document_data(self):
321
+ with open(
322
+ f"{self.config['log_chunk_dir']}/docs/doc_content.json", "r"
323
+ ) as json_file:
324
+ self.document_data = json.load(json_file)
325
+ with open(
326
+ f"{self.config['log_chunk_dir']}/metadata/doc_metadata.json", "r"
327
+ ) as json_file:
328
+ self.document_metadata = json.load(json_file)
329
 
330
 
331
  class DataLoader:
332
+ def __init__(self, config, logger=None):
333
+ self.file_reader = FileReader(logger=logger)
334
+ self.chunk_processor = ChunkProcessor(config, logger=logger)
335
 
336
  def get_chunks(self, uploaded_files, weblinks):
337
+ return self.chunk_processor.chunk_docs(
338
  self.file_reader, uploaded_files, weblinks
339
  )
340
+
341
+
342
+ if __name__ == "__main__":
343
+ import yaml
344
+
345
+ logger = logging.getLogger(__name__)
346
+ logger.setLevel(logging.INFO)
347
+
348
+ with open("../code/modules/config/config.yml", "r") as f:
349
+ config = yaml.safe_load(f)
350
+
351
+ data_loader = DataLoader(config, logger=logger)
352
+ document_chunks, document_names, documents, document_metadata = (
353
+ data_loader.get_chunks(
354
+ [],
355
+ ["https://dl4ds.github.io/sp2024/"],
356
+ )
357
+ )
358
+
359
+ print(document_names)
360
+ print(len(document_chunks))
code/modules/dataloader/helpers.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from bs4 import BeautifulSoup
3
+ from tqdm import tqdm
4
+
5
+
6
+ def get_urls_from_file(file_path: str):
7
+ """
8
+ Function to get urls from a file
9
+ """
10
+ with open(file_path, "r") as f:
11
+ urls = f.readlines()
12
+ urls = [url.strip() for url in urls]
13
+ return urls
14
+
15
+
16
+ def get_base_url(url):
17
+ parsed_url = urlparse(url)
18
+ base_url = f"{parsed_url.scheme}://{parsed_url.netloc}/"
19
+ return base_url
20
+
21
+
22
+ def get_metadata(lectures_url, schedule_url):
23
+ """
24
+ Function to get the lecture metadata from the lectures and schedule URLs.
25
+ """
26
+ lecture_metadata = {}
27
+
28
+ # Get the main lectures page content
29
+ r_lectures = requests.get(lectures_url)
30
+ soup_lectures = BeautifulSoup(r_lectures.text, "html.parser")
31
+
32
+ # Get the main schedule page content
33
+ r_schedule = requests.get(schedule_url)
34
+ soup_schedule = BeautifulSoup(r_schedule.text, "html.parser")
35
+
36
+ # Find all lecture blocks
37
+ lecture_blocks = soup_lectures.find_all("div", class_="lecture-container")
38
+
39
+ # Create a mapping from slides link to date
40
+ date_mapping = {}
41
+ schedule_rows = soup_schedule.find_all("li", class_="table-row-lecture")
42
+ for row in schedule_rows:
43
+ try:
44
+ date = (
45
+ row.find("div", {"data-label": "Date"}).get_text(separator=" ").strip()
46
+ )
47
+ description_div = row.find("div", {"data-label": "Description"})
48
+ slides_link_tag = description_div.find("a", title="Download slides")
49
+ slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
50
+ slides_link = (
51
+ f"https://dl4ds.github.io{slides_link}" if slides_link else None
52
+ )
53
+ if slides_link:
54
+ date_mapping[slides_link] = date
55
+ except Exception as e:
56
+ print(f"Error processing schedule row: {e}")
57
+ continue
58
+
59
+ for block in lecture_blocks:
60
+ try:
61
+ # Extract the lecture title
62
+ title = block.find("span", style="font-weight: bold;").text.strip()
63
+
64
+ # Extract the TL;DR
65
+ tldr = block.find("strong", text="tl;dr:").next_sibling.strip()
66
+
67
+ # Extract the link to the slides
68
+ slides_link_tag = block.find("a", title="Download slides")
69
+ slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
70
+ slides_link = (
71
+ f"https://dl4ds.github.io{slides_link}" if slides_link else None
72
+ )
73
+
74
+ # Extract the link to the lecture recording
75
+ recording_link_tag = block.find("a", title="Download lecture recording")
76
+ recording_link = (
77
+ recording_link_tag["href"].strip() if recording_link_tag else None
78
+ )
79
+
80
+ # Extract suggested readings or summary if available
81
+ suggested_readings_tag = block.find("p", text="Suggested Readings:")
82
+ if suggested_readings_tag:
83
+ suggested_readings = suggested_readings_tag.find_next_sibling("ul")
84
+ if suggested_readings:
85
+ suggested_readings = suggested_readings.get_text(
86
+ separator="\n"
87
+ ).strip()
88
+ else:
89
+ suggested_readings = "No specific readings provided."
90
+ else:
91
+ suggested_readings = "No specific readings provided."
92
+
93
+ # Get the date from the schedule
94
+ date = date_mapping.get(slides_link, "No date available")
95
+
96
+ # Add to the dictionary
97
+ lecture_metadata[slides_link] = {
98
+ "date": date,
99
+ "tldr": tldr,
100
+ "title": title,
101
+ "lecture_recording": recording_link,
102
+ "suggested_readings": suggested_readings,
103
+ }
104
+ except Exception as e:
105
+ print(f"Error processing block: {e}")
106
+ continue
107
+
108
+ return lecture_metadata
code/modules/{helpers.py β†’ dataloader/webpage_crawler.py} RENAMED
@@ -1,25 +1,9 @@
1
- import requests
2
- from bs4 import BeautifulSoup
3
- from tqdm import tqdm
4
- import chainlit as cl
5
- from langchain import PromptTemplate
6
  import requests
7
  from bs4 import BeautifulSoup
8
  from urllib.parse import urlparse, urljoin, urldefrag
9
- import asyncio
10
- import aiohttp
11
- from aiohttp import ClientSession
12
- from typing import Dict, Any, List
13
-
14
- try:
15
- from modules.constants import *
16
- except:
17
- from constants import *
18
-
19
- """
20
- Ref: https://python.plainenglish.io/scraping-the-subpages-on-a-website-ea2d4e3db113
21
- """
22
-
23
 
24
  class WebpageCrawler:
25
  def __init__(self):
@@ -129,209 +113,3 @@ class WebpageCrawler:
129
  # Strip the fragment identifier
130
  defragged_url, _ = urldefrag(url)
131
  return defragged_url
132
-
133
-
134
- def get_urls_from_file(file_path: str):
135
- """
136
- Function to get urls from a file
137
- """
138
- with open(file_path, "r") as f:
139
- urls = f.readlines()
140
- urls = [url.strip() for url in urls]
141
- return urls
142
-
143
-
144
- def get_base_url(url):
145
- parsed_url = urlparse(url)
146
- base_url = f"{parsed_url.scheme}://{parsed_url.netloc}/"
147
- return base_url
148
-
149
-
150
- def get_prompt(config):
151
- if config["llm_params"]["use_history"]:
152
- if config["llm_params"]["llm_loader"] == "local_llm":
153
- custom_prompt_template = tinyllama_prompt_template_with_history
154
- elif config["llm_params"]["llm_loader"] == "openai":
155
- custom_prompt_template = openai_prompt_template_with_history
156
- # else:
157
- # custom_prompt_template = tinyllama_prompt_template_with_history # default
158
- prompt = PromptTemplate(
159
- template=custom_prompt_template,
160
- input_variables=["context", "chat_history", "question"],
161
- )
162
- else:
163
- if config["llm_params"]["llm_loader"] == "local_llm":
164
- custom_prompt_template = tinyllama_prompt_template
165
- elif config["llm_params"]["llm_loader"] == "openai":
166
- custom_prompt_template = openai_prompt_template
167
- # else:
168
- # custom_prompt_template = tinyllama_prompt_template
169
- prompt = PromptTemplate(
170
- template=custom_prompt_template,
171
- input_variables=["context", "question"],
172
- )
173
- return prompt
174
-
175
-
176
- def get_sources(res, answer):
177
- source_elements = []
178
- source_dict = {} # Dictionary to store URL elements
179
-
180
- for idx, source in enumerate(res["source_documents"]):
181
- source_metadata = source.metadata
182
- url = source_metadata["source"]
183
- score = source_metadata.get("score", "N/A")
184
- page = source_metadata.get("page", 1)
185
-
186
- lecture_tldr = source_metadata.get("tldr", "N/A")
187
- lecture_recording = source_metadata.get("lecture_recording", "N/A")
188
- suggested_readings = source_metadata.get("suggested_readings", "N/A")
189
- date = source_metadata.get("date", "N/A")
190
-
191
- source_type = source_metadata.get("source_type", "N/A")
192
-
193
- url_name = f"{url}_{page}"
194
- if url_name not in source_dict:
195
- source_dict[url_name] = {
196
- "text": source.page_content,
197
- "url": url,
198
- "score": score,
199
- "page": page,
200
- "lecture_tldr": lecture_tldr,
201
- "lecture_recording": lecture_recording,
202
- "suggested_readings": suggested_readings,
203
- "date": date,
204
- "source_type": source_type,
205
- }
206
- else:
207
- source_dict[url_name]["text"] += f"\n\n{source.page_content}"
208
-
209
- # First, display the answer
210
- full_answer = "**Answer:**\n"
211
- full_answer += answer
212
-
213
- # Then, display the sources
214
- full_answer += "\n\n**Sources:**\n"
215
- for idx, (url_name, source_data) in enumerate(source_dict.items()):
216
- full_answer += f"\nSource {idx + 1} (Score: {source_data['score']}): {source_data['url']}\n"
217
-
218
- name = f"Source {idx + 1} Text\n"
219
- full_answer += name
220
- source_elements.append(
221
- cl.Text(name=name, content=source_data["text"], display="side")
222
- )
223
-
224
- # Add a PDF element if the source is a PDF file
225
- if source_data["url"].lower().endswith(".pdf"):
226
- name = f"Source {idx + 1} PDF\n"
227
- full_answer += name
228
- pdf_url = f"{source_data['url']}#page={source_data['page']+1}"
229
- source_elements.append(cl.Pdf(name=name, url=pdf_url, display="side"))
230
-
231
- full_answer += "\n**Metadata:**\n"
232
- for idx, (url_name, source_data) in enumerate(source_dict.items()):
233
- full_answer += f"\nSource {idx + 1} Metadata:\n"
234
- source_elements.append(
235
- cl.Text(
236
- name=f"Source {idx + 1} Metadata",
237
- content=f"Source: {source_data['url']}\n"
238
- f"Page: {source_data['page']}\n"
239
- f"Type: {source_data['source_type']}\n"
240
- f"Date: {source_data['date']}\n"
241
- f"TL;DR: {source_data['lecture_tldr']}\n"
242
- f"Lecture Recording: {source_data['lecture_recording']}\n"
243
- f"Suggested Readings: {source_data['suggested_readings']}\n",
244
- display="side",
245
- )
246
- )
247
-
248
- return full_answer, source_elements
249
-
250
-
251
- def get_metadata(lectures_url, schedule_url):
252
- """
253
- Function to get the lecture metadata from the lectures and schedule URLs.
254
- """
255
- lecture_metadata = {}
256
-
257
- # Get the main lectures page content
258
- r_lectures = requests.get(lectures_url)
259
- soup_lectures = BeautifulSoup(r_lectures.text, "html.parser")
260
-
261
- # Get the main schedule page content
262
- r_schedule = requests.get(schedule_url)
263
- soup_schedule = BeautifulSoup(r_schedule.text, "html.parser")
264
-
265
- # Find all lecture blocks
266
- lecture_blocks = soup_lectures.find_all("div", class_="lecture-container")
267
-
268
- # Create a mapping from slides link to date
269
- date_mapping = {}
270
- schedule_rows = soup_schedule.find_all("li", class_="table-row-lecture")
271
- for row in schedule_rows:
272
- try:
273
- date = (
274
- row.find("div", {"data-label": "Date"}).get_text(separator=" ").strip()
275
- )
276
- description_div = row.find("div", {"data-label": "Description"})
277
- slides_link_tag = description_div.find("a", title="Download slides")
278
- slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
279
- slides_link = (
280
- f"https://dl4ds.github.io{slides_link}" if slides_link else None
281
- )
282
- if slides_link:
283
- date_mapping[slides_link] = date
284
- except Exception as e:
285
- print(f"Error processing schedule row: {e}")
286
- continue
287
-
288
- for block in lecture_blocks:
289
- try:
290
- # Extract the lecture title
291
- title = block.find("span", style="font-weight: bold;").text.strip()
292
-
293
- # Extract the TL;DR
294
- tldr = block.find("strong", text="tl;dr:").next_sibling.strip()
295
-
296
- # Extract the link to the slides
297
- slides_link_tag = block.find("a", title="Download slides")
298
- slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
299
- slides_link = (
300
- f"https://dl4ds.github.io{slides_link}" if slides_link else None
301
- )
302
-
303
- # Extract the link to the lecture recording
304
- recording_link_tag = block.find("a", title="Download lecture recording")
305
- recording_link = (
306
- recording_link_tag["href"].strip() if recording_link_tag else None
307
- )
308
-
309
- # Extract suggested readings or summary if available
310
- suggested_readings_tag = block.find("p", text="Suggested Readings:")
311
- if suggested_readings_tag:
312
- suggested_readings = suggested_readings_tag.find_next_sibling("ul")
313
- if suggested_readings:
314
- suggested_readings = suggested_readings.get_text(
315
- separator="\n"
316
- ).strip()
317
- else:
318
- suggested_readings = "No specific readings provided."
319
- else:
320
- suggested_readings = "No specific readings provided."
321
-
322
- # Get the date from the schedule
323
- date = date_mapping.get(slides_link, "No date available")
324
-
325
- # Add to the dictionary
326
- lecture_metadata[slides_link] = {
327
- "date": date,
328
- "tldr": tldr,
329
- "title": title,
330
- "lecture_recording": recording_link,
331
- "suggested_readings": suggested_readings,
332
- }
333
- except Exception as e:
334
- print(f"Error processing block: {e}")
335
- continue
336
-
337
- return lecture_metadata
 
1
+ import aiohttp
2
+ from aiohttp import ClientSession
3
+ import asyncio
 
 
4
  import requests
5
  from bs4 import BeautifulSoup
6
  from urllib.parse import urlparse, urljoin, urldefrag
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  class WebpageCrawler:
9
  def __init__(self):
 
113
  # Strip the fragment identifier
114
  defragged_url, _ = urldefrag(url)
115
  return defragged_url
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
code/modules/retriever/__init__.py ADDED
File without changes
code/modules/retriever/base.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # template for retriever classes
2
+
3
+
4
+ class BaseRetriever:
5
+ def __init__(self, config):
6
+ self.config = config
7
+
8
+ def return_retriever(self):
9
+ """
10
+ Returns the retriever object
11
+ """
12
+ raise NotImplementedError
code/modules/retriever/chroma_retriever.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .helpers import VectorStoreRetrieverScore
2
+ from .base import BaseRetriever
3
+
4
+
5
+ class ChromaRetriever(BaseRetriever):
6
+ def __init__(self):
7
+ pass
8
+
9
+ def return_retriever(self, db, config):
10
+ retriever = VectorStoreRetrieverScore(
11
+ vectorstore=db,
12
+ # search_type="similarity_score_threshold",
13
+ # search_kwargs={
14
+ # "score_threshold": self.config["vectorstore"][
15
+ # "score_threshold"
16
+ # ],
17
+ # "k": self.config["vectorstore"]["search_top_k"],
18
+ # },
19
+ search_kwargs={
20
+ "k": config["vectorstore"]["search_top_k"],
21
+ },
22
+ )
23
+
24
+ return retriever
code/modules/retriever/colbert_retriever.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from .base import BaseRetriever
2
+
3
+
4
+ class ColbertRetriever(BaseRetriever):
5
+ def __init__(self):
6
+ pass
7
+
8
+ def return_retriever(self, db, config):
9
+ retriever = db.as_langchain_retriever(k=config["vectorstore"]["search_top_k"])
10
+ return retriever
code/modules/retriever/faiss_retriever.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .helpers import VectorStoreRetrieverScore
2
+ from .base import BaseRetriever
3
+
4
+
5
+ class FaissRetriever(BaseRetriever):
6
+ def __init__(self):
7
+ pass
8
+
9
+ def return_retriever(self, db, config):
10
+ retriever = VectorStoreRetrieverScore(
11
+ vectorstore=db,
12
+ # search_type="similarity_score_threshold",
13
+ # search_kwargs={
14
+ # "score_threshold": self.config["vectorstore"][
15
+ # "score_threshold"
16
+ # ],
17
+ # "k": self.config["vectorstore"]["search_top_k"],
18
+ # },
19
+ search_kwargs={
20
+ "k": config["vectorstore"]["search_top_k"],
21
+ },
22
+ )
23
+ return retriever
code/modules/retriever/helpers.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.schema.vectorstore import VectorStoreRetriever
2
+ from langchain.callbacks.manager import CallbackManagerForRetrieverRun
3
+ from langchain.schema.document import Document
4
+ from langchain_core.callbacks import AsyncCallbackManagerForRetrieverRun
5
+ from typing import List
6
+
7
+
8
+ class VectorStoreRetrieverScore(VectorStoreRetriever):
9
+
10
+ # See https://github.com/langchain-ai/langchain/blob/61dd92f8215daef3d9cf1734b0d1f8c70c1571c3/libs/langchain/langchain/vectorstores/base.py#L500
11
+ def _get_relevant_documents(
12
+ self, query: str, *, run_manager: CallbackManagerForRetrieverRun
13
+ ) -> List[Document]:
14
+ docs_and_similarities = (
15
+ self.vectorstore.similarity_search_with_relevance_scores(
16
+ query, **self.search_kwargs
17
+ )
18
+ )
19
+ # Make the score part of the document metadata
20
+ for doc, similarity in docs_and_similarities:
21
+ doc.metadata["score"] = similarity
22
+
23
+ docs = [doc for doc, _ in docs_and_similarities]
24
+ return docs
25
+
26
+ async def _aget_relevant_documents(
27
+ self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
28
+ ) -> List[Document]:
29
+ docs_and_similarities = (
30
+ self.vectorstore.similarity_search_with_relevance_scores(
31
+ query, **self.search_kwargs
32
+ )
33
+ )
34
+ # Make the score part of the document metadata
35
+ for doc, similarity in docs_and_similarities:
36
+ doc.metadata["score"] = similarity
37
+
38
+ docs = [doc for doc, _ in docs_and_similarities]
39
+ return docs
code/modules/retriever/raptor_retriever.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .helpers import VectorStoreRetrieverScore
2
+ from .base import BaseRetriever
3
+
4
+
5
+ class RaptorRetriever(BaseRetriever):
6
+ def __init__(self):
7
+ pass
8
+
9
+ def return_retriever(self, db, config):
10
+ retriever = VectorStoreRetrieverScore(
11
+ vectorstore=db,
12
+ search_kwargs={
13
+ "k": config["vectorstore"]["search_top_k"],
14
+ },
15
+ )
16
+ return retriever
code/modules/retriever/retriever.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.retriever.faiss_retriever import FaissRetriever
2
+ from modules.retriever.chroma_retriever import ChromaRetriever
3
+ from modules.retriever.colbert_retriever import ColbertRetriever
4
+ from modules.retriever.raptor_retriever import RaptorRetriever
5
+
6
+
7
+ class Retriever:
8
+ def __init__(self, config):
9
+ self.config = config
10
+ self.retriever_classes = {
11
+ "FAISS": FaissRetriever,
12
+ "Chroma": ChromaRetriever,
13
+ "RAGatouille": ColbertRetriever,
14
+ "RAPTOR": RaptorRetriever,
15
+ }
16
+ self._create_retriever()
17
+
18
+ def _create_retriever(self):
19
+ db_option = self.config["vectorstore"]["db_option"]
20
+ retriever_class = self.retriever_classes.get(db_option)
21
+ if not retriever_class:
22
+ raise ValueError(f"Invalid db_option: {db_option}")
23
+ self.retriever = retriever_class()
24
+
25
+ def _return_retriever(self, db):
26
+ return self.retriever.return_retriever(db, self.config)
code/modules/vector_db.py DELETED
@@ -1,226 +0,0 @@
1
- import logging
2
- import os
3
- import yaml
4
- from langchain_community.vectorstores import FAISS, Chroma
5
- from langchain.schema.vectorstore import VectorStoreRetriever
6
- from langchain.callbacks.manager import CallbackManagerForRetrieverRun
7
- from langchain.schema.document import Document
8
- from langchain_core.callbacks import AsyncCallbackManagerForRetrieverRun
9
- from ragatouille import RAGPretrainedModel
10
-
11
- try:
12
- from modules.embedding_model_loader import EmbeddingModelLoader
13
- from modules.data_loader import DataLoader
14
- from modules.constants import *
15
- from modules.helpers import *
16
- except:
17
- from embedding_model_loader import EmbeddingModelLoader
18
- from data_loader import DataLoader
19
- from constants import *
20
- from helpers import *
21
-
22
- from typing import List
23
-
24
-
25
- class VectorDBScore(VectorStoreRetriever):
26
-
27
- # See https://github.com/langchain-ai/langchain/blob/61dd92f8215daef3d9cf1734b0d1f8c70c1571c3/libs/langchain/langchain/vectorstores/base.py#L500
28
- def _get_relevant_documents(
29
- self, query: str, *, run_manager: CallbackManagerForRetrieverRun
30
- ) -> List[Document]:
31
- docs_and_similarities = (
32
- self.vectorstore.similarity_search_with_relevance_scores(
33
- query, **self.search_kwargs
34
- )
35
- )
36
- # Make the score part of the document metadata
37
- for doc, similarity in docs_and_similarities:
38
- doc.metadata["score"] = similarity
39
-
40
- docs = [doc for doc, _ in docs_and_similarities]
41
- return docs
42
-
43
- async def _aget_relevant_documents(
44
- self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
45
- ) -> List[Document]:
46
- docs_and_similarities = (
47
- self.vectorstore.similarity_search_with_relevance_scores(
48
- query, **self.search_kwargs
49
- )
50
- )
51
- # Make the score part of the document metadata
52
- for doc, similarity in docs_and_similarities:
53
- doc.metadata["score"] = similarity
54
-
55
- docs = [doc for doc, _ in docs_and_similarities]
56
- return docs
57
-
58
-
59
- class VectorDB:
60
- def __init__(self, config, logger=None):
61
- self.config = config
62
- self.db_option = config["embedding_options"]["db_option"]
63
- self.document_names = None
64
- self.webpage_crawler = WebpageCrawler()
65
-
66
- # Set up logging to both console and a file
67
- if logger is None:
68
- self.logger = logging.getLogger(__name__)
69
- self.logger.setLevel(logging.INFO)
70
-
71
- # Console Handler
72
- console_handler = logging.StreamHandler()
73
- console_handler.setLevel(logging.INFO)
74
- formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
75
- console_handler.setFormatter(formatter)
76
- self.logger.addHandler(console_handler)
77
-
78
- # File Handler
79
- log_file_path = "vector_db.log" # Change this to your desired log file path
80
- file_handler = logging.FileHandler(log_file_path, mode="w")
81
- file_handler.setLevel(logging.INFO)
82
- file_handler.setFormatter(formatter)
83
- self.logger.addHandler(file_handler)
84
- else:
85
- self.logger = logger
86
-
87
- self.logger.info("VectorDB instance instantiated")
88
-
89
- def load_files(self):
90
- files = os.listdir(self.config["embedding_options"]["data_path"])
91
- files = [
92
- os.path.join(self.config["embedding_options"]["data_path"], file)
93
- for file in files
94
- ]
95
- urls = get_urls_from_file(self.config["embedding_options"]["url_file_path"])
96
- if self.config["embedding_options"]["expand_urls"]:
97
- all_urls = []
98
- for url in urls:
99
- loop = asyncio.get_event_loop()
100
- all_urls.extend(
101
- loop.run_until_complete(
102
- self.webpage_crawler.get_all_pages(
103
- url, url
104
- ) # only get child urls, if you want to get all urls, replace the second argument with the base url
105
- )
106
- )
107
- urls = all_urls
108
- return files, urls
109
-
110
- def create_embedding_model(self):
111
- self.logger.info("Creating embedding function")
112
- self.embedding_model_loader = EmbeddingModelLoader(self.config)
113
- self.embedding_model = self.embedding_model_loader.load_embedding_model()
114
-
115
- def initialize_database(
116
- self,
117
- document_chunks: list,
118
- document_names: list,
119
- documents: list,
120
- document_metadata: list,
121
- ):
122
- if self.db_option in ["FAISS", "Chroma"]:
123
- self.create_embedding_model()
124
- # Track token usage
125
- self.logger.info("Initializing vector_db")
126
- self.logger.info("\tUsing {} as db_option".format(self.db_option))
127
- if self.db_option == "FAISS":
128
- self.vector_db = FAISS.from_documents(
129
- documents=document_chunks, embedding=self.embedding_model
130
- )
131
- elif self.db_option == "Chroma":
132
- self.vector_db = Chroma.from_documents(
133
- documents=document_chunks,
134
- embedding=self.embedding_model,
135
- persist_directory=os.path.join(
136
- self.config["embedding_options"]["db_path"],
137
- "db_"
138
- + self.config["embedding_options"]["db_option"]
139
- + "_"
140
- + self.config["embedding_options"]["model"],
141
- ),
142
- )
143
- elif self.db_option == "RAGatouille":
144
- self.RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
145
- index_path = self.RAG.index(
146
- index_name="new_idx",
147
- collection=documents,
148
- document_ids=document_names,
149
- document_metadatas=document_metadata,
150
- )
151
- self.logger.info("Completed initializing vector_db")
152
-
153
- def create_database(self):
154
- data_loader = DataLoader(self.config)
155
- self.logger.info("Loading data")
156
- files, urls = self.load_files()
157
- files, webpages = self.webpage_crawler.clean_url_list(urls)
158
- if "storage/data/urls.txt" in files:
159
- files.remove("storage/data/urls.txt")
160
- document_chunks, document_names, documents, document_metadata = (
161
- data_loader.get_chunks(files, webpages)
162
- )
163
- self.logger.info("Completed loading data")
164
- self.initialize_database(
165
- document_chunks, document_names, documents, document_metadata
166
- )
167
-
168
- def save_database(self):
169
- if self.db_option == "FAISS":
170
- self.vector_db.save_local(
171
- os.path.join(
172
- self.config["embedding_options"]["db_path"],
173
- "db_"
174
- + self.config["embedding_options"]["db_option"]
175
- + "_"
176
- + self.config["embedding_options"]["model"],
177
- )
178
- )
179
- elif self.db_option == "Chroma":
180
- # db is saved in the persist directory during initialization
181
- pass
182
- elif self.db_option == "RAGatouille":
183
- # index is saved during initialization
184
- pass
185
- self.logger.info("Saved database")
186
-
187
- def load_database(self):
188
- self.create_embedding_model()
189
- if self.db_option == "FAISS":
190
- self.vector_db = FAISS.load_local(
191
- os.path.join(
192
- self.config["embedding_options"]["db_path"],
193
- "db_"
194
- + self.config["embedding_options"]["db_option"]
195
- + "_"
196
- + self.config["embedding_options"]["model"],
197
- ),
198
- self.embedding_model,
199
- allow_dangerous_deserialization=True,
200
- )
201
- elif self.db_option == "Chroma":
202
- self.vector_db = Chroma(
203
- persist_directory=os.path.join(
204
- self.config["embedding_options"]["db_path"],
205
- "db_"
206
- + self.config["embedding_options"]["db_option"]
207
- + "_"
208
- + self.config["embedding_options"]["model"],
209
- ),
210
- embedding_function=self.embedding_model,
211
- )
212
- elif self.db_option == "RAGatouille":
213
- self.vector_db = RAGPretrainedModel.from_index(
214
- ".ragatouille/colbert/indexes/new_idx"
215
- )
216
- self.logger.info("Loaded database")
217
- return self.vector_db
218
-
219
-
220
- if __name__ == "__main__":
221
- with open("code/config.yml", "r") as f:
222
- config = yaml.safe_load(f)
223
- print(config)
224
- vector_db = VectorDB(config)
225
- vector_db.create_database()
226
- vector_db.save_database()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
code/modules/vectorstore/__init__.py ADDED
File without changes
code/modules/vectorstore/base.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # template for vector store classes
2
+
3
+
4
+ class VectorStoreBase:
5
+ def __init__(self, config):
6
+ self.config = config
7
+
8
+ def _init_vector_db(self):
9
+ """
10
+ Creates a vector store object
11
+ """
12
+ raise NotImplementedError
13
+
14
+ def create_database(self):
15
+ """
16
+ Populates the vector store with documents
17
+ """
18
+ raise NotImplementedError
19
+
20
+ def load_database(self):
21
+ """
22
+ Loads the vector store from disk
23
+ """
24
+ raise NotImplementedError
25
+
26
+ def as_retriever(self):
27
+ """
28
+ Returns the vector store as a retriever
29
+ """
30
+ raise NotImplementedError
31
+
32
+ def __str__(self):
33
+ return self.__class__.__name__
code/modules/vectorstore/chroma.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.vectorstores import Chroma
2
+ from modules.vectorstore.base import VectorStoreBase
3
+ import os
4
+
5
+
6
+ class ChromaVectorStore(VectorStoreBase):
7
+ def __init__(self, config):
8
+ self.config = config
9
+ self._init_vector_db()
10
+
11
+ def _init_vector_db(self):
12
+ self.chroma = Chroma()
13
+
14
+ def create_database(self, document_chunks, embedding_model):
15
+ self.vectorstore = self.chroma.from_documents(
16
+ documents=document_chunks,
17
+ embedding=embedding_model,
18
+ persist_directory=os.path.join(
19
+ self.config["vectorstore"]["db_path"],
20
+ "db_"
21
+ + self.config["vectorstore"]["db_option"]
22
+ + "_"
23
+ + self.config["vectorstore"]["model"],
24
+ ),
25
+ )
26
+
27
+ def load_database(self, embedding_model):
28
+ self.vectorstore = Chroma(
29
+ persist_directory=os.path.join(
30
+ self.config["vectorstore"]["db_path"],
31
+ "db_"
32
+ + self.config["vectorstore"]["db_option"]
33
+ + "_"
34
+ + self.config["vectorstore"]["model"],
35
+ ),
36
+ embedding_function=embedding_model,
37
+ )
38
+ return self.vectorstore
39
+
40
+ def as_retriever(self):
41
+ return self.vectorstore.as_retriever()
code/modules/vectorstore/colbert.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ragatouille import RAGPretrainedModel
2
+ from modules.vectorstore.base import VectorStoreBase
3
+ import os
4
+
5
+
6
+ class ColbertVectorStore(VectorStoreBase):
7
+ def __init__(self, config):
8
+ self.config = config
9
+ self._init_vector_db()
10
+
11
+ def _init_vector_db(self):
12
+ self.colbert = RAGPretrainedModel.from_pretrained(
13
+ "colbert-ir/colbertv2.0",
14
+ index_root=os.path.join(
15
+ self.config["vectorstore"]["db_path"],
16
+ "db_" + self.config["vectorstore"]["db_option"],
17
+ ),
18
+ )
19
+
20
+ def create_database(self, documents, document_names, document_metadata):
21
+ index_path = self.colbert.index(
22
+ index_name="new_idx",
23
+ collection=documents,
24
+ document_ids=document_names,
25
+ document_metadatas=document_metadata,
26
+ )
27
+
28
+ def load_database(self):
29
+ path = os.path.join(
30
+ self.config["vectorstore"]["db_path"],
31
+ "db_" + self.config["vectorstore"]["db_option"],
32
+ )
33
+ self.vectorstore = RAGPretrainedModel.from_index(
34
+ f"{path}/colbert/indexes/new_idx"
35
+ )
36
+ return self.vectorstore
37
+
38
+ def as_retriever(self):
39
+ return self.vectorstore.as_retriever()
code/modules/{embedding_model_loader.py β†’ vectorstore/embedding_model_loader.py} RENAMED
@@ -2,10 +2,7 @@ from langchain_community.embeddings import OpenAIEmbeddings
2
  from langchain_community.embeddings import HuggingFaceEmbeddings
3
  from langchain_community.embeddings import LlamaCppEmbeddings
4
 
5
- try:
6
- from modules.constants import *
7
- except:
8
- from constants import *
9
  import os
10
 
11
 
@@ -14,19 +11,19 @@ class EmbeddingModelLoader:
14
  self.config = config
15
 
16
  def load_embedding_model(self):
17
- if self.config["embedding_options"]["model"] in ["text-embedding-ada-002"]:
18
  embedding_model = OpenAIEmbeddings(
19
  deployment="SL-document_embedder",
20
- model=self.config["embedding_options"]["model"],
21
  show_progress_bar=True,
22
  openai_api_key=OPENAI_API_KEY,
23
  disallowed_special=(),
24
  )
25
  else:
26
  embedding_model = HuggingFaceEmbeddings(
27
- model_name=self.config["embedding_options"]["model"],
28
  model_kwargs={
29
- "device": "cpu",
30
  "token": f"{HUGGINGFACE_TOKEN}",
31
  "trust_remote_code": True,
32
  },
 
2
  from langchain_community.embeddings import HuggingFaceEmbeddings
3
  from langchain_community.embeddings import LlamaCppEmbeddings
4
 
5
+ from modules.config.constants import *
 
 
 
6
  import os
7
 
8
 
 
11
  self.config = config
12
 
13
  def load_embedding_model(self):
14
+ if self.config["vectorstore"]["model"] in ["text-embedding-ada-002"]:
15
  embedding_model = OpenAIEmbeddings(
16
  deployment="SL-document_embedder",
17
+ model=self.config["vectorestore"]["model"],
18
  show_progress_bar=True,
19
  openai_api_key=OPENAI_API_KEY,
20
  disallowed_special=(),
21
  )
22
  else:
23
  embedding_model = HuggingFaceEmbeddings(
24
+ model_name=self.config["vectorstore"]["model"],
25
  model_kwargs={
26
+ "device": f"{self.config['device']}",
27
  "token": f"{HUGGINGFACE_TOKEN}",
28
  "trust_remote_code": True,
29
  },
code/modules/vectorstore/faiss.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.vectorstores import FAISS
2
+ from modules.vectorstore.base import VectorStoreBase
3
+ import os
4
+
5
+
6
+ class FaissVectorStore(VectorStoreBase):
7
+ def __init__(self, config):
8
+ self.config = config
9
+ self._init_vector_db()
10
+
11
+ def _init_vector_db(self):
12
+ self.faiss = FAISS(
13
+ embedding_function=None, index=0, index_to_docstore_id={}, docstore={}
14
+ )
15
+
16
+ def create_database(self, document_chunks, embedding_model):
17
+ self.vectorstore = self.faiss.from_documents(
18
+ documents=document_chunks, embedding=embedding_model
19
+ )
20
+ self.vectorstore.save_local(
21
+ os.path.join(
22
+ self.config["vectorstore"]["db_path"],
23
+ "db_"
24
+ + self.config["vectorstore"]["db_option"]
25
+ + "_"
26
+ + self.config["vectorstore"]["model"],
27
+ )
28
+ )
29
+
30
+ def load_database(self, embedding_model):
31
+ self.vectorstore = self.faiss.load_local(
32
+ os.path.join(
33
+ self.config["vectorstore"]["db_path"],
34
+ "db_"
35
+ + self.config["vectorstore"]["db_option"]
36
+ + "_"
37
+ + self.config["vectorstore"]["model"],
38
+ ),
39
+ embedding_model,
40
+ allow_dangerous_deserialization=True,
41
+ )
42
+ return self.vectorstore
43
+
44
+ def as_retriever(self):
45
+ return self.vectorstore.as_retriever()
code/modules/vectorstore/helpers.py ADDED
File without changes
code/modules/vectorstore/raptor.py ADDED
@@ -0,0 +1,438 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # code modified from https://github.com/langchain-ai/langchain/blob/master/cookbook/RAPTOR.ipynb
2
+
3
+ from typing import Dict, List, Optional, Tuple
4
+ import os
5
+ import numpy as np
6
+ import pandas as pd
7
+ import umap
8
+ from langchain_core.prompts import ChatPromptTemplate
9
+ from langchain_core.output_parsers import StrOutputParser
10
+ from sklearn.mixture import GaussianMixture
11
+ from langchain_community.chat_models import ChatOpenAI
12
+ from langchain_community.vectorstores import FAISS
13
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
14
+ from modules.vectorstore.base import VectorStoreBase
15
+
16
+ RANDOM_SEED = 42
17
+
18
+
19
+ class RAPTORVectoreStore(VectorStoreBase):
20
+ def __init__(self, config, documents=[], text_splitter=None, embedding_model=None):
21
+ self.documents = documents
22
+ self.config = config
23
+ self.text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
24
+ chunk_size=self.config["splitter_options"]["chunk_size"],
25
+ chunk_overlap=self.config["splitter_options"]["chunk_overlap"],
26
+ separators=self.config["splitter_options"]["chunk_separators"],
27
+ disallowed_special=(),
28
+ )
29
+ self.embd = embedding_model
30
+ self.model = ChatOpenAI(
31
+ model="gpt-3.5-turbo",
32
+ )
33
+
34
+ def concat_documents(self, documents):
35
+ d_sorted = sorted(documents, key=lambda x: x.metadata["source"])
36
+ d_reversed = list(reversed(d_sorted))
37
+ concatenated_content = "\n\n\n --- \n\n\n".join(
38
+ [doc.page_content for doc in d_reversed]
39
+ )
40
+ return concatenated_content
41
+
42
+ def split_documents(self, documents):
43
+ concatenated_content = self.concat_documents(documents)
44
+ texts_split = self.text_splitter.split_text(concatenated_content)
45
+ return texts_split
46
+
47
+ def add_documents(self, documents):
48
+ self.documents.extend(documents)
49
+
50
+ def global_cluster_embeddings(
51
+ self,
52
+ embeddings: np.ndarray,
53
+ dim: int,
54
+ n_neighbors: Optional[int] = None,
55
+ metric: str = "cosine",
56
+ ) -> np.ndarray:
57
+ """
58
+ Perform global dimensionality reduction on the embeddings using UMAP.
59
+
60
+ Parameters:
61
+ - embeddings: The input embeddings as a numpy array.
62
+ - dim: The target dimensionality for the reduced space.
63
+ - n_neighbors: Optional; the number of neighbors to consider for each point.
64
+ If not provided, it defaults to the square root of the number of embeddings.
65
+ - metric: The distance metric to use for UMAP.
66
+
67
+ Returns:
68
+ - A numpy array of the embeddings reduced to the specified dimensionality.
69
+ """
70
+ if n_neighbors is None:
71
+ n_neighbors = int((len(embeddings) - 1) ** 0.5)
72
+ return umap.UMAP(
73
+ n_neighbors=n_neighbors, n_components=dim, metric=metric
74
+ ).fit_transform(embeddings)
75
+
76
+ def local_cluster_embeddings(
77
+ self,
78
+ embeddings: np.ndarray,
79
+ dim: int,
80
+ num_neighbors: int = 10,
81
+ metric: str = "cosine",
82
+ ) -> np.ndarray:
83
+ """
84
+ Perform local dimensionality reduction on the embeddings using UMAP, typically after global clustering.
85
+
86
+ Parameters:
87
+ - embeddings: The input embeddings as a numpy array.
88
+ - dim: The target dimensionality for the reduced space.
89
+ - num_neighbors: The number of neighbors to consider for each point.
90
+ - metric: The distance metric to use for UMAP.
91
+
92
+ Returns:
93
+ - A numpy array of the embeddings reduced to the specified dimensionality.
94
+ """
95
+ return umap.UMAP(
96
+ n_neighbors=num_neighbors, n_components=dim, metric=metric
97
+ ).fit_transform(embeddings)
98
+
99
+ def get_optimal_clusters(
100
+ self,
101
+ embeddings: np.ndarray,
102
+ max_clusters: int = 50,
103
+ random_state: int = RANDOM_SEED,
104
+ ) -> int:
105
+ """
106
+ Determine the optimal number of clusters using the Bayesian Information Criterion (BIC) with a Gaussian Mixture Model.
107
+
108
+ Parameters:
109
+ - embeddings: The input embeddings as a numpy array.
110
+ - max_clusters: The maximum number of clusters to consider.
111
+ - random_state: Seed for reproducibility.
112
+
113
+ Returns:
114
+ - An integer representing the optimal number of clusters found.
115
+ """
116
+ max_clusters = min(max_clusters, len(embeddings))
117
+ n_clusters = np.arange(1, max_clusters)
118
+ bics = []
119
+ for n in n_clusters:
120
+ gm = GaussianMixture(n_components=n, random_state=random_state)
121
+ gm.fit(embeddings)
122
+ bics.append(gm.bic(embeddings))
123
+ return n_clusters[np.argmin(bics)]
124
+
125
+ def GMM_cluster(
126
+ self, embeddings: np.ndarray, threshold: float, random_state: int = 0
127
+ ):
128
+ """
129
+ Cluster embeddings using a Gaussian Mixture Model (GMM) based on a probability threshold.
130
+
131
+ Parameters:
132
+ - embeddings: The input embeddings as a numpy array.
133
+ - threshold: The probability threshold for assigning an embedding to a cluster.
134
+ - random_state: Seed for reproducibility.
135
+
136
+ Returns:
137
+ - A tuple containing the cluster labels and the number of clusters determined.
138
+ """
139
+ n_clusters = self.get_optimal_clusters(embeddings)
140
+ gm = GaussianMixture(n_components=n_clusters, random_state=random_state)
141
+ gm.fit(embeddings)
142
+ probs = gm.predict_proba(embeddings)
143
+ labels = [np.where(prob > threshold)[0] for prob in probs]
144
+ return labels, n_clusters
145
+
146
+ def perform_clustering(
147
+ self,
148
+ embeddings: np.ndarray,
149
+ dim: int,
150
+ threshold: float,
151
+ ) -> List[np.ndarray]:
152
+ """
153
+ Perform clustering on the embeddings by first reducing their dimensionality globally, then clustering
154
+ using a Gaussian Mixture Model, and finally performing local clustering within each global cluster.
155
+
156
+ Parameters:
157
+ - embeddings: The input embeddings as a numpy array.
158
+ - dim: The target dimensionality for UMAP reduction.
159
+ - threshold: The probability threshold for assigning an embedding to a cluster in GMM.
160
+
161
+ Returns:
162
+ - A list of numpy arrays, where each array contains the cluster IDs for each embedding.
163
+ """
164
+ if len(embeddings) <= dim + 1:
165
+ # Avoid clustering when there's insufficient data
166
+ return [np.array([0]) for _ in range(len(embeddings))]
167
+
168
+ # Global dimensionality reduction
169
+ reduced_embeddings_global = self.global_cluster_embeddings(embeddings, dim)
170
+ # Global clustering
171
+ global_clusters, n_global_clusters = self.GMM_cluster(
172
+ reduced_embeddings_global, threshold
173
+ )
174
+
175
+ all_local_clusters = [np.array([]) for _ in range(len(embeddings))]
176
+ total_clusters = 0
177
+
178
+ # Iterate through each global cluster to perform local clustering
179
+ for i in range(n_global_clusters):
180
+ # Extract embeddings belonging to the current global cluster
181
+ global_cluster_embeddings_ = embeddings[
182
+ np.array([i in gc for gc in global_clusters])
183
+ ]
184
+
185
+ if len(global_cluster_embeddings_) == 0:
186
+ continue
187
+ if len(global_cluster_embeddings_) <= dim + 1:
188
+ # Handle small clusters with direct assignment
189
+ local_clusters = [np.array([0]) for _ in global_cluster_embeddings_]
190
+ n_local_clusters = 1
191
+ else:
192
+ # Local dimensionality reduction and clustering
193
+ reduced_embeddings_local = self.local_cluster_embeddings(
194
+ global_cluster_embeddings_, dim
195
+ )
196
+ local_clusters, n_local_clusters = self.GMM_cluster(
197
+ reduced_embeddings_local, threshold
198
+ )
199
+
200
+ # Assign local cluster IDs, adjusting for total clusters already processed
201
+ for j in range(n_local_clusters):
202
+ local_cluster_embeddings_ = global_cluster_embeddings_[
203
+ np.array([j in lc for lc in local_clusters])
204
+ ]
205
+ indices = np.where(
206
+ (embeddings == local_cluster_embeddings_[:, None]).all(-1)
207
+ )[1]
208
+ for idx in indices:
209
+ all_local_clusters[idx] = np.append(
210
+ all_local_clusters[idx], j + total_clusters
211
+ )
212
+
213
+ total_clusters += n_local_clusters
214
+
215
+ return all_local_clusters
216
+
217
+ def embed(self, texts):
218
+ """
219
+ Generate embeddings for a list of text documents.
220
+
221
+ This function assumes the existence of an `embd` object with a method `embed_documents`
222
+ that takes a list of texts and returns their embeddings.
223
+
224
+ Parameters:
225
+ - texts: List[str], a list of text documents to be embedded.
226
+
227
+ Returns:
228
+ - numpy.ndarray: An array of embeddings for the given text documents.
229
+ """
230
+ text_embeddings = self.embd.embed_documents(texts)
231
+ text_embeddings_np = np.array(text_embeddings)
232
+ return text_embeddings_np
233
+
234
+ def embed_cluster_texts(self, texts):
235
+ """
236
+ Embeds a list of texts and clusters them, returning a DataFrame with texts, their embeddings, and cluster labels.
237
+
238
+ This function combines embedding generation and clustering into a single step. It assumes the existence
239
+ of a previously defined `perform_clustering` function that performs clustering on the embeddings.
240
+
241
+ Parameters:
242
+ - texts: List[str], a list of text documents to be processed.
243
+
244
+ Returns:
245
+ - pandas.DataFrame: A DataFrame containing the original texts, their embeddings, and the assigned cluster labels.
246
+ """
247
+ text_embeddings_np = self.embed(texts) # Generate embeddings
248
+ cluster_labels = self.perform_clustering(
249
+ text_embeddings_np, 10, 0.1
250
+ ) # Perform clustering on the embeddings
251
+ df = pd.DataFrame() # Initialize a DataFrame to store the results
252
+ df["text"] = texts # Store original texts
253
+ df["embd"] = list(
254
+ text_embeddings_np
255
+ ) # Store embeddings as a list in the DataFrame
256
+ df["cluster"] = cluster_labels # Store cluster labels
257
+ return df
258
+
259
+ def fmt_txt(self, df: pd.DataFrame) -> str:
260
+ """
261
+ Formats the text documents in a DataFrame into a single string.
262
+
263
+ Parameters:
264
+ - df: DataFrame containing the 'text' column with text documents to format.
265
+
266
+ Returns:
267
+ - A single string where all text documents are joined by a specific delimiter.
268
+ """
269
+ unique_txt = df["text"].tolist()
270
+ return "--- --- \n --- --- ".join(unique_txt)
271
+
272
+ def embed_cluster_summarize_texts(
273
+ self, texts: List[str], level: int
274
+ ) -> Tuple[pd.DataFrame, pd.DataFrame]:
275
+ """
276
+ Embeds, clusters, and summarizes a list of texts. This function first generates embeddings for the texts,
277
+ clusters them based on similarity, expands the cluster assignments for easier processing, and then summarizes
278
+ the content within each cluster.
279
+
280
+ Parameters:
281
+ - texts: A list of text documents to be processed.
282
+ - level: An integer parameter that could define the depth or detail of processing.
283
+
284
+ Returns:
285
+ - Tuple containing two DataFrames:
286
+ 1. The first DataFrame (`df_clusters`) includes the original texts, their embeddings, and cluster assignments.
287
+ 2. The second DataFrame (`df_summary`) contains summaries for each cluster, the specified level of detail,
288
+ and the cluster identifiers.
289
+ """
290
+
291
+ # Embed and cluster the texts, resulting in a DataFrame with 'text', 'embd', and 'cluster' columns
292
+ df_clusters = self.embed_cluster_texts(texts)
293
+
294
+ # Prepare to expand the DataFrame for easier manipulation of clusters
295
+ expanded_list = []
296
+
297
+ # Expand DataFrame entries to document-cluster pairings for straightforward processing
298
+ for index, row in df_clusters.iterrows():
299
+ for cluster in row["cluster"]:
300
+ expanded_list.append(
301
+ {"text": row["text"], "embd": row["embd"], "cluster": cluster}
302
+ )
303
+
304
+ # Create a new DataFrame from the expanded list
305
+ expanded_df = pd.DataFrame(expanded_list)
306
+
307
+ # Retrieve unique cluster identifiers for processing
308
+ all_clusters = expanded_df["cluster"].unique()
309
+
310
+ print(f"--Generated {len(all_clusters)} clusters--")
311
+
312
+ # Summarization
313
+ template = """Here is content from the course DS598: Deep Learning for Data Science.
314
+
315
+ The content may be form webapge about the course, or lecture content, or any other relevant information.
316
+ If the content is in bullet points (from pdf lectre slides), you can summarize the bullet points.
317
+
318
+ Give a detailed summary of the content below.
319
+
320
+ Documentation:
321
+ {context}
322
+ """
323
+ prompt = ChatPromptTemplate.from_template(template)
324
+ chain = prompt | self.model | StrOutputParser()
325
+
326
+ # Format text within each cluster for summarization
327
+ summaries = []
328
+ for i in all_clusters:
329
+ df_cluster = expanded_df[expanded_df["cluster"] == i]
330
+ formatted_txt = self.fmt_txt(df_cluster)
331
+ summaries.append(chain.invoke({"context": formatted_txt}))
332
+
333
+ # Create a DataFrame to store summaries with their corresponding cluster and level
334
+ df_summary = pd.DataFrame(
335
+ {
336
+ "summaries": summaries,
337
+ "level": [level] * len(summaries),
338
+ "cluster": list(all_clusters),
339
+ }
340
+ )
341
+
342
+ return df_clusters, df_summary
343
+
344
+ def recursive_embed_cluster_summarize(
345
+ self, texts: List[str], level: int = 1, n_levels: int = 3
346
+ ) -> Dict[int, Tuple[pd.DataFrame, pd.DataFrame]]:
347
+ """
348
+ Recursively embeds, clusters, and summarizes texts up to a specified level or until
349
+ the number of unique clusters becomes 1, storing the results at each level.
350
+
351
+ Parameters:
352
+ - texts: List[str], texts to be processed.
353
+ - level: int, current recursion level (starts at 1).
354
+ - n_levels: int, maximum depth of recursion.
355
+
356
+ Returns:
357
+ - Dict[int, Tuple[pd.DataFrame, pd.DataFrame]], a dictionary where keys are the recursion
358
+ levels and values are tuples containing the clusters DataFrame and summaries DataFrame at that level.
359
+ """
360
+ results = {} # Dictionary to store results at each level
361
+
362
+ # Perform embedding, clustering, and summarization for the current level
363
+ df_clusters, df_summary = self.embed_cluster_summarize_texts(texts, level)
364
+
365
+ # Store the results of the current level
366
+ results[level] = (df_clusters, df_summary)
367
+
368
+ # Determine if further recursion is possible and meaningful
369
+ unique_clusters = df_summary["cluster"].nunique()
370
+ if level < n_levels and unique_clusters > 1:
371
+ # Use summaries as the input texts for the next level of recursion
372
+ new_texts = df_summary["summaries"].tolist()
373
+ next_level_results = self.recursive_embed_cluster_summarize(
374
+ new_texts, level + 1, n_levels
375
+ )
376
+
377
+ # Merge the results from the next level into the current results dictionary
378
+ results.update(next_level_results)
379
+
380
+ return results
381
+
382
+ def get_vector_db(self):
383
+ """
384
+ Generate a retriever object from a list of documents.
385
+
386
+ Parameters:
387
+ - documents: List of document objects.
388
+
389
+ Returns:
390
+ - A retriever object.
391
+ """
392
+ leaf_texts = self.split_documents(self.documents)
393
+ results = self.recursive_embed_cluster_summarize(
394
+ leaf_texts, level=1, n_levels=10
395
+ )
396
+
397
+ all_texts = leaf_texts.copy()
398
+ # Iterate through the results to extract summaries from each level and add them to all_texts
399
+ for level in sorted(results.keys()):
400
+ # Extract summaries from the current level's DataFrame
401
+ summaries = results[level][1]["summaries"].tolist()
402
+ # Extend all_texts with the summaries from the current level
403
+ all_texts.extend(summaries)
404
+
405
+ # Now, use all_texts to build the vectorstore
406
+ vectorstore = FAISS.from_texts(texts=all_texts, embedding=self.embd)
407
+ return vectorstore
408
+
409
+ def create_database(self, documents, embedding_model):
410
+ self.documents = documents
411
+ self.embd = embedding_model
412
+ self.vectorstore = self.get_vector_db()
413
+ self.vectorstore.save_local(
414
+ os.path.join(
415
+ self.config["vectorstore"]["db_path"],
416
+ "db_"
417
+ + self.config["vectorstore"]["db_option"]
418
+ + "_"
419
+ + self.config["vectorstore"]["model"],
420
+ )
421
+ )
422
+
423
+ def load_database(self, embedding_model):
424
+ self.vectorstore = FAISS.load_local(
425
+ os.path.join(
426
+ self.config["vectorstore"]["db_path"],
427
+ "db_"
428
+ + self.config["vectorstore"]["db_option"]
429
+ + "_"
430
+ + self.config["vectorstore"]["model"],
431
+ ),
432
+ embedding_model,
433
+ allow_dangerous_deserialization=True,
434
+ )
435
+ return self.vectorstore
436
+
437
+ def as_retriever(self):
438
+ return self.vectorstore.as_retriever()
code/modules/vectorstore/store_manager.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.vectorstore.vectorstore import VectorStore
2
+ from modules.vectorstore.helpers import *
3
+ from modules.dataloader.webpage_crawler import WebpageCrawler
4
+ from modules.dataloader.data_loader import DataLoader
5
+ from modules.dataloader.helpers import *
6
+ from modules.vectorstore.embedding_model_loader import EmbeddingModelLoader
7
+ import logging
8
+ import os
9
+ import time
10
+ import asyncio
11
+
12
+
13
+ class VectorStoreManager:
14
+ def __init__(self, config, logger=None):
15
+ self.config = config
16
+ self.document_names = None
17
+
18
+ # Set up logging to both console and a file
19
+ self.logger = logger or self._setup_logging()
20
+ self.webpage_crawler = WebpageCrawler()
21
+ self.vector_db = VectorStore(self.config)
22
+
23
+ self.logger.info("VectorDB instance instantiated")
24
+
25
+ def _setup_logging(self):
26
+ logger = logging.getLogger(__name__)
27
+ if not logger.hasHandlers():
28
+ logger.setLevel(logging.INFO)
29
+ formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
30
+
31
+ # Console Handler
32
+ console_handler = logging.StreamHandler()
33
+ console_handler.setLevel(logging.INFO)
34
+ console_handler.setFormatter(formatter)
35
+ logger.addHandler(console_handler)
36
+
37
+ # Ensure log directory exists
38
+ log_directory = self.config["log_dir"]
39
+ os.makedirs(log_directory, exist_ok=True)
40
+
41
+ # File Handler
42
+ log_file_path = os.path.join(log_directory, "vector_db.log")
43
+ file_handler = logging.FileHandler(log_file_path, mode="w")
44
+ file_handler.setLevel(logging.INFO)
45
+ file_handler.setFormatter(formatter)
46
+ logger.addHandler(file_handler)
47
+
48
+ return logger
49
+
50
+ def load_files(self):
51
+
52
+ files = os.listdir(self.config["vectorstore"]["data_path"])
53
+ files = [
54
+ os.path.join(self.config["vectorstore"]["data_path"], file)
55
+ for file in files
56
+ ]
57
+ urls = get_urls_from_file(self.config["vectorstore"]["url_file_path"])
58
+ if self.config["vectorstore"]["expand_urls"]:
59
+ all_urls = []
60
+ for url in urls:
61
+ loop = asyncio.get_event_loop()
62
+ all_urls.extend(
63
+ loop.run_until_complete(
64
+ self.webpage_crawler.get_all_pages(
65
+ url, url
66
+ ) # only get child urls, if you want to get all urls, replace the second argument with the base url
67
+ )
68
+ )
69
+ urls = all_urls
70
+ return files, urls
71
+
72
+ def create_embedding_model(self):
73
+
74
+ self.logger.info("Creating embedding function")
75
+ embedding_model_loader = EmbeddingModelLoader(self.config)
76
+ embedding_model = embedding_model_loader.load_embedding_model()
77
+ return embedding_model
78
+
79
+ def initialize_database(
80
+ self,
81
+ document_chunks: list,
82
+ document_names: list,
83
+ documents: list,
84
+ document_metadata: list,
85
+ ):
86
+ if self.config["vectorstore"]["db_option"] in ["FAISS", "Chroma", "RAPTOR"]:
87
+ self.embedding_model = self.create_embedding_model()
88
+ else:
89
+ self.embedding_model = None
90
+
91
+ self.logger.info("Initializing vector_db")
92
+ self.logger.info(
93
+ "\tUsing {} as db_option".format(self.config["vectorstore"]["db_option"])
94
+ )
95
+ self.vector_db._create_database(
96
+ document_chunks,
97
+ document_names,
98
+ documents,
99
+ document_metadata,
100
+ self.embedding_model,
101
+ )
102
+
103
+ def create_database(self):
104
+
105
+ start_time = time.time() # Start time for creating database
106
+ data_loader = DataLoader(self.config, self.logger)
107
+ self.logger.info("Loading data")
108
+ files, urls = self.load_files()
109
+ files, webpages = self.webpage_crawler.clean_url_list(urls)
110
+ self.logger.info(f"Number of files: {len(files)}")
111
+ self.logger.info(f"Number of webpages: {len(webpages)}")
112
+ if f"{self.config['vectorstore']['url_file_path']}" in files:
113
+ files.remove(f"{self.config['vectorstores']['url_file_path']}") # cleanup
114
+ document_chunks, document_names, documents, document_metadata = (
115
+ data_loader.get_chunks(files, webpages)
116
+ )
117
+ num_documents = len(document_chunks)
118
+ self.logger.info(f"Number of documents in the DB: {num_documents}")
119
+ metadata_keys = list(document_metadata[0].keys())
120
+ self.logger.info(f"Metadata keys: {metadata_keys}")
121
+ self.logger.info("Completed loading data")
122
+ self.initialize_database(
123
+ document_chunks, document_names, documents, document_metadata
124
+ )
125
+ end_time = time.time() # End time for creating database
126
+ self.logger.info("Created database")
127
+ self.logger.info(
128
+ f"Time taken to create database: {end_time - start_time} seconds"
129
+ )
130
+
131
+ def load_database(self):
132
+
133
+ start_time = time.time() # Start time for loading database
134
+ if self.config["vectorstore"]["db_option"] in ["FAISS", "Chroma", "RAPTOR"]:
135
+ self.embedding_model = self.create_embedding_model()
136
+ else:
137
+ self.embedding_model = None
138
+ self.loaded_vector_db = self.vector_db._load_database(self.embedding_model)
139
+ end_time = time.time() # End time for loading database
140
+ self.logger.info(
141
+ f"Time taken to load database: {end_time - start_time} seconds"
142
+ )
143
+ self.logger.info("Loaded database")
144
+ return self.loaded_vector_db
145
+
146
+
147
+ if __name__ == "__main__":
148
+ import yaml
149
+
150
+ with open("modules/config/config.yml", "r") as f:
151
+ config = yaml.safe_load(f)
152
+ print(config)
153
+ print(f"Trying to create database with config: {config}")
154
+ vector_db = VectorStoreManager(config)
155
+ vector_db.create_database()
156
+ print("Created database")
157
+
158
+ print(f"Trying to load the database")
159
+ vector_db = VectorStoreManager(config)
160
+ vector_db.load_database()
161
+ print("Loaded database")
162
+
163
+ print(f"View the logs at {config['log_dir']}/vector_db.log")
code/modules/vectorstore/vectorstore.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.vectorstore.faiss import FaissVectorStore
2
+ from modules.vectorstore.chroma import ChromaVectorStore
3
+ from modules.vectorstore.colbert import ColbertVectorStore
4
+ from modules.vectorstore.raptor import RAPTORVectoreStore
5
+
6
+
7
+ class VectorStore:
8
+ def __init__(self, config):
9
+ self.config = config
10
+ self.vectorstore = None
11
+ self.vectorstore_classes = {
12
+ "FAISS": FaissVectorStore,
13
+ "Chroma": ChromaVectorStore,
14
+ "RAGatouille": ColbertVectorStore,
15
+ "RAPTOR": RAPTORVectoreStore,
16
+ }
17
+
18
+ def _create_database(
19
+ self,
20
+ document_chunks,
21
+ document_names,
22
+ documents,
23
+ document_metadata,
24
+ embedding_model,
25
+ ):
26
+ db_option = self.config["vectorstore"]["db_option"]
27
+ vectorstore_class = self.vectorstore_classes.get(db_option)
28
+ if not vectorstore_class:
29
+ raise ValueError(f"Invalid db_option: {db_option}")
30
+
31
+ self.vectorstore = vectorstore_class(self.config)
32
+
33
+ if db_option == "RAGatouille":
34
+ self.vectorstore.create_database(
35
+ documents, document_names, document_metadata
36
+ )
37
+ else:
38
+ self.vectorstore.create_database(document_chunks, embedding_model)
39
+
40
+ def _load_database(self, embedding_model):
41
+ db_option = self.config["vectorstore"]["db_option"]
42
+ vectorstore_class = self.vectorstore_classes.get(db_option)
43
+ if not vectorstore_class:
44
+ raise ValueError(f"Invalid db_option: {db_option}")
45
+
46
+ self.vectorstore = vectorstore_class(self.config)
47
+
48
+ if db_option == "RAGatouille":
49
+ return self.vectorstore.load_database()
50
+ else:
51
+ return self.vectorstore.load_database(embedding_model)
52
+
53
+ def _as_retriever(self):
54
+ return self.vectorstore.as_retriever()
55
+
56
+ def _get_vectorstore(self):
57
+ return self.vectorstore
code/public/acastusphoton-svgrepo-com.svg ADDED
code/public/adv-screen-recorder-svgrepo-com.svg ADDED
code/public/alarmy-svgrepo-com.svg ADDED
public/logo_dark.png β†’ code/public/avatars/ai-tutor.png RENAMED
File without changes
code/public/calendar-samsung-17-svgrepo-com.svg ADDED
public/logo_light.png β†’ code/public/logo_dark.png RENAMED
File without changes
code/public/logo_light.png ADDED