Commit
·
cd990a6
1
Parent(s):
41b6658
stopped downloading the models when the servers start
Browse files- app.py +0 -2
- on_server_start.py +0 -37
- tests.py +0 -7
app.py
CHANGED
@@ -9,9 +9,7 @@ from torch.cuda import CudaError
|
|
9 |
|
10 |
from available_models import AVAILABLE_MODELS
|
11 |
from hanlde_form_submit import on_form_submit
|
12 |
-
from on_server_start import main as on_server_start_main
|
13 |
|
14 |
-
# on_server_start_main()
|
15 |
|
16 |
st.title("Grouped Sampling Demo")
|
17 |
|
|
|
9 |
|
10 |
from available_models import AVAILABLE_MODELS
|
11 |
from hanlde_form_submit import on_form_submit
|
|
|
12 |
|
|
|
13 |
|
14 |
st.title("Grouped Sampling Demo")
|
15 |
|
on_server_start.py
DELETED
@@ -1,37 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
A script that is run when the server starts.
|
3 |
-
"""
|
4 |
-
from concurrent.futures import ThreadPoolExecutor
|
5 |
-
|
6 |
-
from transformers import logging as transformers_logging
|
7 |
-
from huggingface_hub import logging as huggingface_hub_logging
|
8 |
-
|
9 |
-
from available_models import AVAILABLE_MODELS
|
10 |
-
from download_repo import download_pytorch_model
|
11 |
-
|
12 |
-
|
13 |
-
def disable_progress_bar():
|
14 |
-
"""
|
15 |
-
Disables the progress bar when downloading models.
|
16 |
-
"""
|
17 |
-
transformers_logging.disable_progress_bar()
|
18 |
-
huggingface_hub_logging.disable_propagation()
|
19 |
-
|
20 |
-
|
21 |
-
def download_useful_models():
|
22 |
-
"""
|
23 |
-
Downloads the models that are useful for this project.
|
24 |
-
So that the user doesn't have to wait for the models to download when they first use the app.
|
25 |
-
"""
|
26 |
-
print("Downloading useful models. It might take a while...")
|
27 |
-
with ThreadPoolExecutor() as executor:
|
28 |
-
executor.map(download_pytorch_model, AVAILABLE_MODELS)
|
29 |
-
|
30 |
-
|
31 |
-
def main():
|
32 |
-
# disable_progress_bar()
|
33 |
-
download_useful_models()
|
34 |
-
|
35 |
-
|
36 |
-
if __name__ == "__main__":
|
37 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tests.py
CHANGED
@@ -3,18 +3,11 @@ import os
|
|
3 |
import pytest as pytest
|
4 |
from grouped_sampling import GroupedSamplingPipeLine, get_full_models_list, UnsupportedModelNameException
|
5 |
|
6 |
-
from on_server_start import download_useful_models
|
7 |
from hanlde_form_submit import create_pipeline, on_form_submit
|
8 |
|
9 |
HUGGING_FACE_CACHE_DIR = "/home/yoni/.cache/huggingface/hub"
|
10 |
|
11 |
|
12 |
-
def test_on_server_start():
|
13 |
-
download_useful_models()
|
14 |
-
assert os.path.exists(HUGGING_FACE_CACHE_DIR)
|
15 |
-
assert len(os.listdir(HUGGING_FACE_CACHE_DIR)) > 0
|
16 |
-
|
17 |
-
|
18 |
def test_on_form_submit():
|
19 |
model_name = "gpt2"
|
20 |
output_length = 10
|
|
|
3 |
import pytest as pytest
|
4 |
from grouped_sampling import GroupedSamplingPipeLine, get_full_models_list, UnsupportedModelNameException
|
5 |
|
|
|
6 |
from hanlde_form_submit import create_pipeline, on_form_submit
|
7 |
|
8 |
HUGGING_FACE_CACHE_DIR = "/home/yoni/.cache/huggingface/hub"
|
9 |
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
def test_on_form_submit():
|
12 |
model_name = "gpt2"
|
13 |
output_length = 10
|