# grouped-sampling-demo / supported_models.py
# Last change: commit af40811 by yonikremer ("added black listed models").
import math
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Generator, List, Set, Union

import requests
from bs4 import BeautifulSoup, NavigableString, PageElement, Tag
# Hugging Face listing of PyTorch text-generation models.
# get_page() appends "&p=<page index>" to this URL to select a page.
SUPPORTED_MODEL_NAME_PAGES_FORMAT = "https://huggingface.co/models?pipeline_tag=text-generation&library=pytorch"

# Upper bound on the number of listing pages downloaded concurrently.
MAX_WORKERS = 10

# Default popularity thresholds used by get_supported_model_names().
MIN_NUMBER_OF_DOWNLOADS = 100
MIN_NUMBER_OF_LIKES = 20

# Models that card_filter() always rejects, regardless of popularity.
BLACKLISTED_MODEL_NAMES = {
    "ykilcher/gpt-4chan",
    "bigscience/mt0-xxl",
    "bigscience/mt0-xl",
    "bigscience/mt0-large",
    "bigscience/mt0-base",
    "bigscience/mt0-small",
}
def get_model_name(model_card: "Tag") -> str:
    """Return the model id shown in the title of a listing card.

    Args:
        model_card: One model-card element from a listing page.

    Returns:
        The text of the card's title ``<h4>`` with surrounding whitespace
        removed, so it compares cleanly against ``BLACKLISTED_MODEL_NAMES``.

    Raises:
        ValueError: If the card has no ``<h4>`` with the expected classes
            (for example after a change to the site's markup).
    """
    h4_class = "text-md truncate font-mono text-black dark:group-hover:text-yellow-500 group-hover:text-indigo-600"
    h4_tag = model_card.find("h4", class_=h4_class)
    if h4_tag is None:
        # Fail with a descriptive message instead of "'NoneType' object has no attribute 'text'".
        raise ValueError(f"model card has no <h4 class={h4_class!r}> title")
    return h4_tag.text.strip()
def is_a_number(s: "PageElement") -> bool:
    """Return True if the text of *s* looks like a count such as ``345``, ``1.2k`` or ``3.4M``.

    The text is stripped and lower-cased, and ``,`` separators are removed.
    At most ONE trailing ``k``/``m``/``b`` suffix is then dropped. The
    remainder must be accepted by ``float()`` and be finite.

    Args:
        s: A page element; only its ``.text`` is read.

    Returns:
        True if the text matches that format, False otherwise.
    """
    text = s.text.strip().lower().replace(",", "")
    # Only a single trailing suffix is a multiplier. Removing k/m/b/"." anywhere
    # would accept strings like "1.2.3" or "k5" that cannot be converted later.
    if text.endswith(("k", "m", "b")):
        text = text[:-1]
    try:
        value = float(text)
    except ValueError:
        return False
    # float() also parses "nan" and "inf", which are not usable counts.
    return math.isfinite(value)
def get_numeric_contents(model_card: Tag) -> List[PageElement]:
    """Return the numeric text entries from a model card's gray metadata ``div``.

    Only the direct, non-tag children of that ``div`` are considered, and of
    those only the ones ``is_a_number`` accepts, in document order.

    Args:
        model_card: One model-card element from a listing page.

    Returns:
        The numeric text nodes. The list is empty if the card has no such
        ``div``; ``card_filter`` then rejects the card instead of crashing.
    """
    div: Union[Tag, NavigableString, None] = model_card.find(
        "div",
        class_="mr-1 flex items-center overflow-hidden whitespace-nowrap text-sm leading-tight text-gray-400",
        recursive=True,
    )
    if div is None:
        return []
    # Child tags (icons, separators, etc.) are skipped; only bare text nodes can hold the counts.
    return [
        content
        for content in div.contents
        if not isinstance(content, Tag) and is_a_number(content)
    ]
def convert_to_int(element: "PageElement") -> int:
    """Convert a count such as ``"345"``, ``"1,234"``, ``"1.2k"`` or ``"3.4M"`` to an int.

    The text is stripped and lower-cased, and ``,`` separators are removed. An
    optional trailing ``k``/``m``/``b`` multiplies the value by one
    thousand / million / billion.

    Args:
        element: A page element; only its ``.text`` is read.

    Returns:
        The count, rounded to the nearest integer.

    Raises:
        ValueError: If the text is not a number in that format (or is NaN).
        OverflowError: If the text denotes an infinite value.
    """
    text = element.text.strip().lower().replace(",", "")
    multipliers = {"k": 1_000, "m": 1_000_000, "b": 1_000_000_000}
    suffix = text[-1:]
    if suffix in multipliers:
        # Round instead of truncating. Float products can land just below the
        # exact value, e.g. 8.2 * 1_000_000 == 8199999.999999999.
        return round(float(text[:-1]) * multipliers[suffix])
    try:
        return int(text)
    except ValueError:
        # Plain decimals such as "12.0" still convert; garbage raises ValueError.
        return round(float(text))
def get_page(page_index: int, timeout: float = 30.0) -> Union[BeautifulSoup, None]:
    """Download and parse one page of the model listing.

    Args:
        page_index: Value sent as the ``p`` query parameter of the listing URL.
        timeout: Seconds ``requests`` may wait on the server (connect and read)
            before giving up. Without a timeout a stalled connection can block
            the calling worker thread indefinitely.

    Returns:
        The parsed page. Returns None if the request failed (network error or
        timeout) or the server did not answer with HTTP 200.
    """
    curr_page_url = f"{SUPPORTED_MODEL_NAME_PAGES_FORMAT}&p={page_index}"
    try:
        response = requests.get(curr_page_url, timeout=timeout)
    except requests.RequestException:
        # Treat transport failures like a non-200 answer: this page is skipped
        # instead of aborting the whole scrape.
        return None
    if response.status_code != 200:
        return None
    return BeautifulSoup(response.content, "html.parser")
def card_filter(
    model_card: Tag,
    model_name: str,
    min_number_of_downloads: int,
    min_number_of_likes: int,
) -> bool:
    """Decide whether a listing card describes a model worth reporting.

    A card is kept only if its name is not blacklisted, it shows at least two
    numbers, the first number (downloads) is at least
    ``min_number_of_downloads``, and the second (likes) is at least
    ``min_number_of_likes``.
    """
    if model_name in BLACKLISTED_MODEL_NAMES:
        return False
    numbers = get_numeric_contents(model_card)
    # Fewer than two numbers means the card lacks a download and/or like count.
    if len(numbers) < 2:
        return False
    downloads_element, likes_element = numbers[0], numbers[1]
    # `and` short-circuits, so the like count is only parsed when downloads pass.
    return (
        convert_to_int(downloads_element) >= min_number_of_downloads
        and convert_to_int(likes_element) >= min_number_of_likes
    )
def get_model_names(
    soup: BeautifulSoup,
    min_number_of_downloads: int,
    min_number_of_likes: int,
) -> Generator[str, None, None]:
    """Yield, in document order, the names of the cards on *soup* that pass ``card_filter``.

    Cards are the ``<article class="overview-card-wrapper group">`` elements of
    one listing page.
    """
    for card in soup.find_all("article", class_="overview-card-wrapper group", recursive=True):
        name = get_model_name(card)
        keep = card_filter(
            model_card=card,
            model_name=name,
            min_number_of_downloads=min_number_of_downloads,
            min_number_of_likes=min_number_of_likes,
        )
        if keep:
            yield name
def generate_supported_model_names(
    min_number_of_downloads: int,
    min_number_of_likes: int,
    number_of_pages: int = 100,
) -> Generator[str, None, None]:
    """Yield the names of listed models that pass ``card_filter``.

    Listing pages ``0 .. number_of_pages - 1`` are fetched concurrently by at
    most ``MAX_WORKERS`` threads. Names are yielded page by page in completion
    order, so the order is not deterministic. Names are not de-duplicated here.

    Args:
        min_number_of_downloads: Minimum download count a model must show.
        min_number_of_likes: Minimum like count a model must show.
        number_of_pages: How many listing pages to fetch (default 100).

    Yields:
        Model names from every page that was fetched successfully.
    """
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = [executor.submit(get_page, page_index) for page_index in range(number_of_pages)]
        for future in as_completed(futures):
            soup = future.result()
            # get_page returns None for pages it could not fetch; skip those.
            if soup is None:
                continue
            yield from get_model_names(
                soup=soup,
                min_number_of_downloads=min_number_of_downloads,
                min_number_of_likes=min_number_of_likes,
            )
def get_supported_model_names(
    min_number_of_downloads: int = MIN_NUMBER_OF_DOWNLOADS,
    min_number_of_likes: int = MIN_NUMBER_OF_LIKES,
) -> Set[str]:
    """Return the distinct names yielded by ``generate_supported_model_names``."""
    names = generate_supported_model_names(
        min_number_of_downloads=min_number_of_downloads,
        min_number_of_likes=min_number_of_likes,
    )
    supported: Set[str] = set()
    supported.update(names)
    return supported
if __name__ == "__main__":
    # Scrape once, then report the count followed by the full set, one per line.
    model_names = get_supported_model_names()
    print(
        f"Number of supported model names: {len(model_names)}",
        f"Supported model names: {model_names}",
        sep="\n",
    )