yonikremer committed on
Commit
15bf463
·
1 Parent(s): 0461bfe

added a scraper for the supported model names

Browse files
Files changed (2) hide show
  1. requirements.txt +3 -2
  2. supported_models.py +48 -0
requirements.txt CHANGED
@@ -2,5 +2,6 @@ grouped-sampling>=1.0.4
2
  streamlit==1.17.0
3
  torch>1.12.1
4
  transformers
5
- pip==23.0
6
- hatchling
 
 
2
  streamlit==1.17.0
3
  torch>1.12.1
4
  transformers
5
+ hatchling
6
+ beautifulsoup4
7
+ urllib3
supported_models.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Generator
2
+
3
+ from bs4 import BeautifulSoup, Tag
4
+ import urllib3
5
+
6
# Base URL of the Hugging Face model listing, filtered to PyTorch text-generation
# models. get_soups appends the page index to it as an extra "p" query parameter.
SUPPORTED_MODEL_NAME_PAGES_FORMAT: str = (
    "https://huggingface.co/models"
    "?pipeline_tag=text-generation&library=pytorch"
)
7
+
8
+
9
def get_model_name(model_card: Tag) -> str:
    """
    Gets the model name from the model card.
    :param model_card: The model card to get the model name from.
    :return: The text of the card's name heading (its h4 element).
    :raises ValueError: If the card has no h4 element with the expected classes.
    """
    h4_class = "text-md truncate font-mono text-black dark:group-hover:text-yellow-500 group-hover:text-indigo-600"
    # A multi-class string passed to class_ is compared against the exact value
    # of the class attribute, so a markup change on the site makes find()
    # return None here.
    h4_tag = model_card.find("h4", class_=h4_class)
    if h4_tag is None:
        # Fail with an explicit message instead of the opaque
        # "'NoneType' object has no attribute 'text'".
        raise ValueError("Could not find the model name heading (h4) in the model card")
    return h4_tag.text
18
+
19
+
20
def get_soups() -> Generator[BeautifulSoup, None, None]:
    """
    Lazily fetches and parses the model listing pages, one page at a time.
    Pages are requested with an increasing "p" query parameter, starting at 0.
    :return: A generator of the parsed pages. It ends at the first response
        whose HTTP status is not 200.
    """
    # One pool manager for all pages so connections to the host can be reused,
    # instead of building a fresh pool for every request.
    http = urllib3.PoolManager()
    curr_page_index = 0
    while True:
        curr_page_url = f"{SUPPORTED_MODEL_NAME_PAGES_FORMAT}&p={curr_page_index}"
        response = http.request("GET", curr_page_url)
        if response.status != 200:
            return
        yield BeautifulSoup(response.data, "html.parser")
        curr_page_index += 1
33
+
34
+
35
def get_supported_model_names() -> Generator[str, None, None]:
    """
    Scrapes the supported model names from the hugging face website.
    :return: A generator of the supported model names, in the order the
        listing pages present them.
    """
    for soup in get_soups():
        model_cards: List[Tag] = soup.find_all("article", class_="overview-card-wrapper group", recursive=True)
        if not model_cards:
            # A page without any model card is treated as the end of the
            # listing. Without this, a server that keeps answering 200 for
            # out-of-range page indexes would make the scrape loop forever.
            return
        for model_card in model_cards:
            yield get_model_name(model_card)
44
+
45
+
46
if __name__ == "__main__":
    # When run as a script, print each scraped model name on its own line,
    # as soon as it is produced.
    for supported_name in get_supported_model_names():
        print(supported_name)