tgd1115 committed on
Commit 976b948 · verified · 1 Parent(s): 3c2ff6b

manual deployment

.gitignore ADDED
@@ -0,0 +1,166 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ .idea/
+
+ # Mac cache file
+ .DS_Store
+
README.md CHANGED
@@ -1,14 +1,153 @@
- ---
- title: Neuro Orion V1
- emoji: 🏆
- colorFrom: gray
- colorTo: gray
- sdk: streamlit
- sdk_version: 1.41.1
- app_file: app.py
- pinned: false
- license: mit
- short_description: DL Assignment
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: Neuro Orion - NYC Taxi Traffic Time Series Anomaly Detection
+ emoji: 🐨
+ colorFrom: indigo
+ colorTo: yellow
+ sdk: streamlit
+ sdk_version: "1.41.1"
+ app_file: src/app.py
+ pinned: true
+ ---
+
+ [![Sync to Hugging Face hub](https://github.com/gdtan02/NeuroOrion_Time_Series_Anomaly_Detection/actions/workflows/main.yml/badge.svg)](https://github.com/gdtan02/NeuroOrion_Time_Series_Anomaly_Detection/actions/workflows/main.yml)
+
+ # NYC Taxi Traffic - Time Series Anomaly Detection
+
+ ## Project Overview
+
+ This project was developed for the WID3011 Deep Learning Assignment.
+
+ SDG 11: Sustainable Cities & Communities
+
+ This project examines an anomaly detection challenge using the NYC Taxi Traffic dataset, available on Kaggle
+ (https://www.kaggle.com/datasets/julienjta/nyc-taxi-traffic) and provided by the NYC Taxi and Limousine
+ Commission. The dataset presents a univariate time series of total taxi passenger counts from July 2014 to January
+ 2015, aggregated every 30 minutes. It includes five notable anomalies, occurring during the NYC Marathon,
+ Thanksgiving, Christmas, New Year’s Day, and a snowstorm.
+
+ The task involves implementing a complete anomaly detection pipeline: analyzing the NYC Taxi Traffic dataset
+ and developing a Long Short-Term Memory (LSTM) model to detect outliers and anomalies.
+
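For illustration, a rough sketch of how such an LSTM detector could be formulated is shown below. This sketch is not part of the committed code (the committed `src/app.py` currently demonstrates an Isolation Forest); the lookback window, the 3-sigma threshold, and the synthetic stand-in series are assumptions, and the Keras API comes from the `tensorflow` dependency in `requirements.txt`.

```python
# Illustrative sketch (not in this commit): a next-step LSTM forecaster whose
# prediction error is thresholded to flag anomalous 30-minute intervals.
import numpy as np
from tensorflow import keras


def make_windows(series, lookback=48):
    """Turn a 1-D series into (samples, lookback, 1) windows and next-step targets."""
    X = np.stack([series[i : i + lookback] for i in range(len(series) - lookback)])
    y = series[lookback:]
    return X[..., None], y


# Stand-in data; in the real pipeline this would be the scaled passenger counts.
rng = np.random.default_rng(42)
series = np.sin(np.linspace(0, 60, 5000)) + 0.1 * rng.standard_normal(5000)

X, y = make_windows(series)

model = keras.Sequential([
    keras.Input(shape=(X.shape[1], 1)),
    keras.layers.LSTM(64),
    keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")
model.fit(X, y, epochs=5, batch_size=64, validation_split=0.1, verbose=0)

# Flag points whose absolute prediction error exceeds a simple 3-sigma threshold.
errors = np.abs(model.predict(X, verbose=0).ravel() - y)
threshold = errors.mean() + 3 * errors.std()
anomaly_idx = np.where(errors > threshold)[0] + X.shape[1]  # offset by the lookback window
```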
+ **Group Name:**
+ Neuro Orion
+
+ **Group Members:**
+ 1. Poo Wei Chien
+ 2. Tan Guo Dong
+ 3. Tan Zhi Jian
+ 4. Sanjivan A/L Balajawahar
+ 5. Marvin Chin Yi Kai
+
+
+ ---
+ ## Acknowledgements
+
+ We acknowledge the contributors to the following resources:
+ - All the members of Neuro Orion for their contributions to the project.
+ - The NYC Taxi Traffic dataset, provided by the NYC Taxi and Limousine Commission.
+ - Open-source tools and frameworks such as TensorFlow, PyTorch, and Jupyter Notebook.
+
+ ---
+
+ ## Installation Guide
+
+ Follow these steps to set up the project locally:
+
+ ### 1. Clone the repository to your local machine:
+
+ Run the following commands in your terminal:
+
+ ```bash
+ git clone https://github.com/gdtan02/NeuroOrion_Time_Series_Anomaly_Detection.git
+ cd NeuroOrion_Time_Series_Anomaly_Detection
+ ```
+
+ ### 2. Set up a Python Virtual Environment (Optional):
+
+ You can use `venv` or `conda` to create and activate a virtual environment to manage dependencies.
+
+ Using `venv`:
+
+ For Windows users, run the following commands:
+ ```commandline
+ python -m venv venv
+ venv\Scripts\activate
+ ```
+
+ For macOS/Linux users, run the following commands:
+ ```commandline
+ python3 -m venv venv
+ source venv/bin/activate
+ ```
+
+ Using `conda`:
+ ```commandline
+ conda create --name nyc-taxi-env python=3.8 -y
+ conda activate nyc-taxi-env
+ ```
+
+ ### 3. Install dependencies:
+
+ Install all the required dependencies listed in the `requirements.txt` file using `pip`:
+ ```commandline
+ pip install -r requirements.txt
+ ```
+
+ ### 4. Install Jupyter Notebook (Optional):
+
+ If Jupyter Notebook is not already installed, you can install it using `pip`:
+
+ ```commandline
+ pip install notebook
+ ```
+
+ Alternatively, if you are using `conda`, you can install Jupyter Notebook with the following command:
+
+ ```commandline
+ conda install -c conda-forge notebook
+ ```
+
+ ### 5. Start Jupyter Notebook:
+ Launch Jupyter Notebook to execute the project code:
+
+ ```commandline
+ jupyter notebook
+ ```
+ A browser window should open, displaying the Jupyter Notebook interface.
+ If it does not open automatically, copy and paste the link shown in the terminal into your web browser.
+
+ You are now ready to run the project code in Jupyter Notebook.
+
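Separately from the notebooks, the Space front matter above sets `app_file: src/app.py` with the Streamlit SDK, so the dashboard itself can be previewed locally with `streamlit run src/app.py` once the dependencies are installed.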
+ ---
+
+ ## Development Setup
+
+ ### 1. Code Formatting
+ We use Black for code formatting. To set it up:
+
+ 1. Install Black and pre-commit:
+ ```bash
+ pip install black pre-commit
+ ```
+
+ 2. Install the pre-commit hooks:
+ ```bash
+ pre-commit install
+ ```
+
+ 3. Run Black manually:
+ ```bash
+ black .
+ ```
+
+ 4. Configure VS Code (optional):
+ ```json
+ {
+     "python.formatting.provider": "black",
+     "editor.formatOnSave": true
+ }
+ ```
+
+ Refer to the [Black documentation](https://black.readthedocs.io/en/stable/) for more information. This setup is based on the article [here](https://dev.to/emmo00/how-to-setup-black-and-pre-commit-in-python-for-auto-text-formatting-on-commit-4kka).
+
+ ---
requirements.txt ADDED
@@ -0,0 +1,26 @@
+ # Data analytics libraries
+ pandas
+ matplotlib
+ numpy
+ seaborn
+ statsmodels
+
+ # Machine learning and deep learning libraries
+ scikit-learn
+ tensorflow
+ keras
+ torch
+
+ # LLM
+ llama-index
+ llama-index-llms-openai
+ llama-index-llms-nvidia
+ llama-index-llms-openai-like
+
+ # Others
+ tqdm
+ black
+ pre-commit
+ streamlit
+ plotly
+ pyod
src/app.py ADDED
@@ -0,0 +1,236 @@
+ import streamlit as st
+ import pandas as pd
+ import numpy as np
+ import plotly.express as px
+ import plotly.graph_objs as go
+ from sklearn.preprocessing import StandardScaler
+ from pyod.models.iforest import IForest
+ from datetime import datetime, timedelta
+
+
+ class NYCTaxiAnomalyDetector:
+     def __init__(self, data):
+         self.data = data.copy()
+         self.scaler = StandardScaler()
+
+     def filter_by_date_range(self, start_date, end_date):
+         """
+         Filter data by specified date range
+
+         :param start_date: Start date of the range
+         :param end_date: End date of the range
+         :return: Filtered DataFrame
+         """
+         # Ensure date column is datetime
+         if not pd.api.types.is_datetime64_any_dtype(self.data["date"]):
+             self.data["date"] = pd.to_datetime(self.data["date"])
+
+         # Filter data
+         filtered_data = self.data[
+             (self.data["date"] >= start_date) & (self.data["date"] <= end_date)
+         ]
+
+         return filtered_data
+
+     def preprocess_data(self, data, column):
+         """
+         Preprocess data for anomaly detection
+
+         :param data: Filtered DataFrame
+         :param column: Column to detect anomalies in
+         :return: Scaled data and original index
+         """
+         # Ensure the column is numeric
+         data[column] = pd.to_numeric(data[column], errors="coerce")
+
+         # Remove NaN values
+         clean_data = data[column].dropna()
+
+         # Scale the data
+         scaled_data = self.scaler.fit_transform(clean_data.values.reshape(-1, 1))
+
+         return scaled_data, clean_data.index
+
+     def detect_anomalies(self, data, column, contamination=0.05):
+         """
+         Detect anomalies using Isolation Forest
+
+         :param data: Filtered DataFrame
+         :param column: Column to detect anomalies in
+         :param contamination: Expected proportion of outliers
+         :return: DataFrame with anomaly detection results
+         """
+         # Preprocess data
+         scaled_data, original_index = self.preprocess_data(data, column)
+
+         # Apply Isolation Forest
+         clf = IForest(contamination=contamination, random_state=42)
+         y_pred = clf.fit_predict(scaled_data)
+
+         # Create results DataFrame
+         anomaly_results = pd.DataFrame(
+             {
+                 "date": original_index,
+                 "value": data.loc[original_index, column],
+                 "is_anomaly": y_pred == 1,
+             }
+         )
+
+         return anomaly_results
+
+
+ class AIContextGenerator:
+     def generate_context(self, anomaly_date):
+         """
+         Generate potential context for the anomaly
+
+         :param anomaly_date: Date of the anomaly
+         :return: List of contextual insights
+         """
+         # Mock contextual insights - replace with actual data sources
+         contexts = [
+             {
+                 "type": "Weather",
+                 "description": f"Weather conditions on {anomaly_date.date()}",
+                 "severity": "High",
+             },
+             {
+                 "type": "Event",
+                 "description": f"City events around {anomaly_date.date()}",
+                 "severity": "Medium",
+             },
+             {
+                 "type": "Economic",
+                 "description": f"Economic factors on {anomaly_date.date()}",
+                 "severity": "Low",
+             },
+         ]
+         return contexts
+
+
+ def load_nyc_taxi_data():
+     """
+     Load and preprocess NYC Taxi dataset
+
+     :return: DataFrame with synthetic taxi traffic data
+     """
+     # Synthetic data generation
+     dates = pd.date_range(start="2023-01-01", end="2023-12-31", freq="D")
+     base_traffic = np.random.normal(5000, 500, len(dates))
+
+     # Introduce some anomalies
+     base_traffic[50] = 10000  # Extreme spike
+     base_traffic[200] = 500  # Extreme drop
+     base_traffic[300] = 12000  # Another spike
+
+     df = pd.DataFrame({"date": dates, "daily_traffic": base_traffic})
+
+     return df
+
+
+ def main():
+     st.set_page_config(
+         page_title="NYC Taxi Traffic Anomaly Detection", page_icon="🚕", layout="wide"
+     )
+
+     st.title("🚕 NYC Taxi Traffic Anomaly Detection")
+
+     # Load Data
+     taxi_data = load_nyc_taxi_data()
+
+     # Sidebar for Configuration
+     st.sidebar.header("Anomaly Detection Settings")
+
+     # Date Range Selection
+     st.sidebar.subheader("Date Range")
+     min_date = taxi_data["date"].min().date()
+     max_date = taxi_data["date"].max().date()
+
+     col1, col2 = st.sidebar.columns(2)
+     with col1:
+         start_date = st.date_input(
+             "Start Date", min_value=min_date, max_value=max_date, value=min_date
+         )
+
+     with col2:
+         end_date = st.date_input(
+             "End Date", min_value=min_date, max_value=max_date, value=max_date
+         )
+
+     # Anomaly Sensitivity
+     anomaly_threshold = st.sidebar.slider(
+         "Anomaly Sensitivity",
+         min_value=0.01,
+         max_value=0.1,
+         value=0.05,
+         step=0.01,
+         help="Lower values detect fewer but more extreme anomalies",
+     )
+
+     # Instantiate Detector
+     detector = NYCTaxiAnomalyDetector(taxi_data)
+
+     # Filter Data by Date Range
+     filtered_data = detector.filter_by_date_range(
+         pd.to_datetime(start_date), pd.to_datetime(end_date)
+     )
+
+     # Detect Anomalies
+     anomalies = detector.detect_anomalies(
+         filtered_data, "daily_traffic", contamination=anomaly_threshold
+     )
+
+     # Visualization
+     st.header("Daily Taxi Traffic Trend")
+     fig = px.line(
+         filtered_data,
+         x="date",
+         y="daily_traffic",
+         title=f"NYC Taxi Daily Traffic ({start_date} to {end_date})",
+         labels={"daily_traffic": "Number of Taxi Rides"},
+     )
+
+     # Highlight Anomalies
+     anomaly_points = filtered_data[anomalies["is_anomaly"]]
+     fig.add_trace(
+         go.Scatter(
+             x=anomaly_points["date"],
+             y=anomaly_points["daily_traffic"],
+             mode="markers",
+             name="Anomalies",
+             marker=dict(color="red", size=10, symbol="star"),
+         )
+     )
+
+     st.plotly_chart(fig, use_container_width=True)
+
+     # Anomaly Details
+     st.header("Anomaly Insights")
+
+     if not anomaly_points.empty:
+         context_generator = AIContextGenerator()
+
+         for _, anomaly in anomaly_points.iterrows():
+             st.subheader(f"Anomaly on {anomaly['date'].date()}")
+
+             col1, col2 = st.columns(2)
+
+             with col1:
+                 st.metric("Taxi Rides", f"{anomaly['daily_traffic']:.0f}")
+
+             with col2:
+                 contexts = context_generator.generate_context(anomaly["date"])
+                 st.write("### Potential Context")
+                 for context in contexts:
+                     st.markdown(
+                         f"""
+                         - **{context['type']}**: {context['description']}
+                         (Severity: {context['severity']})
+                         """
+                     )
+     else:
+         st.info("No significant anomalies detected with current settings.")
+
+
+ if __name__ == "__main__":
+     main()
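For a quick check outside the Streamlit UI, the detector defined above can also be exercised directly. A minimal sketch, not part of this commit, assuming the repository root is on the path so the module is importable as `src.app`:

```python
# Illustrative smoke test (not in this commit) for the classes in src/app.py.
import pandas as pd

from src.app import NYCTaxiAnomalyDetector, load_nyc_taxi_data

taxi_data = load_nyc_taxi_data()  # synthetic daily traffic with injected spikes/drops
detector = NYCTaxiAnomalyDetector(taxi_data)

window = detector.filter_by_date_range(
    pd.Timestamp("2023-02-01"), pd.Timestamp("2023-04-30")
)
results = detector.detect_anomalies(window, "daily_traffic", contamination=0.05)
print(results[results["is_anomaly"]])  # rows flagged by the Isolation Forest
```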
src/config/llm/nvidia-llama-3.1-nemotron-70b-instruct.yaml ADDED
@@ -0,0 +1,4 @@
+ PROVIDER: nvidia
+ BASE_URL: https://integrate.api.nvidia.com/v1
+ MODEL: nvidia/llama-3.1-nemotron-70b-instruct
+ TEMPERATURE: 0
src/config/llm/openai-gpt-3.5-turbo.yaml ADDED
@@ -0,0 +1,4 @@
+ PROVIDER: openai
+ BASE_URL: default
+ MODEL: gpt-3.5-turbo
+ TEMPERATURE: 0
src/config/llm/openai-gpt-4o-mini.yaml ADDED
@@ -0,0 +1,4 @@
+ PROVIDER: openai
+ BASE_URL: default
+ MODEL: gpt-4o-mini
+ TEMPERATURE: 0
src/llm/base_llm_provider.py ADDED
@@ -0,0 +1,16 @@
+ """Base class for LLM providers"""
+
+ from abc import abstractmethod
+ from typing import Dict, Optional
+
+
+ class BaseLLMProvider:
+     @abstractmethod
+     def __init__(self):
+         """LLM provider initialization"""
+         raise NotImplementedError
+
+     @abstractmethod
+     def complete(self, prompt: str = "") -> str:
+         """LLM chat completion implementation by each provider"""
+         raise NotImplementedError
src/llm/enums.py ADDED
@@ -0,0 +1,3 @@
+ OPENAI_LLM = "openai"
+ NVIDIA_LLM = "nvidia"
+ DEFAULT_LLM_API_BASE = "default"
src/llm/llm.py ADDED
@@ -0,0 +1,32 @@
+ import yaml
+
+ from src.llm.enums import OPENAI_LLM, NVIDIA_LLM
+ from src.llm.base_llm_provider import BaseLLMProvider
+ from src.llm.openai_llm import OpenAILLM
+ from src.llm.nvidia_llm import NvidiaLLM
+
+
+ def get_llm(config_file_path: str = "config.yaml") -> BaseLLMProvider:
+     """
+     Initiates LLM client from config file
+     """
+
+     # load config
+     with open(config_file_path, "r") as f:
+         config = yaml.safe_load(f)
+
+     # init & return llm
+     if config["PROVIDER"] == OPENAI_LLM:
+         return OpenAILLM(
+             model=config["MODEL"],
+             temperature=config["TEMPERATURE"],
+             base_url=config["BASE_URL"],
+         )
+     elif config["PROVIDER"] == NVIDIA_LLM:
+         return NvidiaLLM(
+             model=config["MODEL"],
+             temperature=config["TEMPERATURE"],
+             base_url=config["BASE_URL"],
+         )
+     else:
+         # report the unsupported provider rather than the model name
+         raise ValueError(config["PROVIDER"])
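For reference, a minimal sketch of how this factory might be called with one of the config files committed under `src/config/llm`; this is illustrative only (not part of the commit) and assumes the matching API key, e.g. `OPENAI_API_KEY`, is set in the environment and that the command is run from the repository root:

```python
# Illustrative usage (not in this commit): build an LLM client from a committed
# config file and run a single completion.
from src.llm.llm import get_llm

llm = get_llm("src/config/llm/openai-gpt-4o-mini.yaml")
print(llm.complete("Suggest possible causes for a spike in NYC taxi demand on 2014-11-02."))
```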
src/llm/nvidia_llm.py ADDED
@@ -0,0 +1,29 @@
+ """NVIDIA LLM Implementation"""
+
+ from llama_index.llms.nvidia import NVIDIA
+
+ from src.llm.base_llm_provider import BaseLLMProvider
+ from src.llm.enums import DEFAULT_LLM_API_BASE
+
+
+ class NvidiaLLM(BaseLLMProvider):
+     def __init__(
+         self,
+         model: str = "nvidia/llama-3.1-nemotron-70b-instruct",
+         temperature: float = 0.0,
+         base_url: str = "https://integrate.api.nvidia.com/v1",
+     ):
+         """Initiate NVIDIA client"""
+
+         if base_url == DEFAULT_LLM_API_BASE:
+             self._client = NVIDIA(
+                 model=model,
+                 temperature=temperature,
+             )
+         else:
+             self._client = NVIDIA(
+                 model=model, temperature=temperature, base_url=base_url
+             )
+
+     def complete(self, prompt: str = "") -> str:
+         return str(self._client.complete(prompt))
src/llm/openai_llm.py ADDED
@@ -0,0 +1,29 @@
+ """OpenAI LLM Implementation"""
+
+ from llama_index.llms.openai import OpenAI
+
+ from src.llm.base_llm_provider import BaseLLMProvider
+ from src.llm.enums import DEFAULT_LLM_API_BASE
+
+
+ class OpenAILLM(BaseLLMProvider):
+     def __init__(
+         self,
+         model: str = "gpt-4o-mini",
+         temperature: float = 0.0,
+         base_url: str = DEFAULT_LLM_API_BASE,
+     ):
+         """Initiate OpenAI client"""
+
+         if base_url == DEFAULT_LLM_API_BASE:
+             self._client = OpenAI(
+                 model=model,
+                 temperature=temperature,
+             )
+         else:
+             self._client = OpenAI(
+                 model=model, temperature=temperature, base_url=base_url
+             )
+
+     def complete(self, prompt: str = "") -> str:
+         return str(self._client.complete(prompt))