Spaces:
Sleeping
Sleeping
import re | |
import iso639 | |
from tqdm import tqdm | |
from functools import lru_cache | |
from huggingface_hub import list_models, hf_hub_download, snapshot_download | |
from collections import defaultdict | |
# Substrings looked for in a model id; a model is listed when its id contains
# at least one of them (see the substring check in get_models_list).
ARCH_TO_INCLUDE = [
    'parakeet',
    'conformer',
    'fastconformer',
]
def get_models_list(sort_alphabetically=True):
    """List NVIDIA speech-recognition models from the Hub, grouped by language.

    Queries ``list_models`` for models authored by ``nvidia`` with the
    ``automatic-speech-recognition`` task and keeps only those whose id
    contains one of the substrings in ``ARCH_TO_INCLUDE``.

    Args:
        sort_alphabetically: When true, each model's language names are
            sorted and the language -> models mapping is rebuilt as a plain
            ``dict`` ordered by language name.

    Returns:
        A ``(lang_models, models_langs)`` tuple where ``lang_models`` maps a
        language name to the list of model ids (without the ``nvidia/``
        prefix) tagged with it, and ``models_langs`` maps each model id to
        the list of language names that could be resolved for it.
    """
    lang_models = defaultdict(list)
    models_langs = {}
    models = list_models(author='nvidia',
                         task="automatic-speech-recognition",
                         sort='downloads',
                         cardData=True)
    for model in models:
        model_id = model.modelId.replace('nvidia/', '')
        if not any(arch in model_id for arch in ARCH_TO_INCLUDE):
            continue

        # Card metadata may be missing entirely, or the 'language' entry may
        # be absent/None; treat all of these as a single 'Unknown' tag instead
        # of crashing on ``None.get`` or on iterating ``None``.
        card_data = model.cardData
        if card_data is None:
            card_data = {}
        language_tags = card_data.get('language') or ['Unknown']
        if isinstance(language_tags, str):
            language_tags = [language_tags]

        lang_names = []
        for language in language_tags:
            try:
                lang_name = iso639.Language.match(language).name
            except Exception:
                # Unrecognised tag: file the model under 'Unknown' (dropped
                # below) but do not let KeyboardInterrupt/SystemExit be
                # swallowed the way a bare ``except`` would.
                lang_name = 'Unknown'
            else:
                lang_names.append(lang_name)
            lang_models[lang_name].append(model_id)

        if sort_alphabetically:
            lang_names = sorted(lang_names)
        models_langs[model_id] = lang_names

    # Models whose tags could not be resolved are not exposed by language.
    lang_models.pop('Unknown', None)
    if sort_alphabetically:
        lang_models = dict(sorted(lang_models.items()))
    return lang_models, models_langs
def extract_section_from_readme(content):
    """Return the introductory description found in a README, or None.

    The text that follows the "Model architecture" badge paragraph (up to the
    next ``##``) is tried first. When that badge is not present, the text
    between the first ``# `` heading and the following ``## `` heading is
    used instead. The captured text is then cleaned up before being returned.
    """
    # Primary pattern: skip the badges up to the first blank line, then
    # capture everything until the next "##".
    start_marker = r"\[\!\[Model architecture\]\(https://img\.shields\.io.*?\)\].*?\n\n"
    end_marker = r"##"
    found = re.search(f"{start_marker}(.*?){end_marker}", content, re.DOTALL)
    if found is None:
        # Fallback: body located between a "# " heading and the next "## ".
        found = re.search(r"# .+?\n\n(.*?)(?=\n## )", content, re.DOTALL)
    if found is None:
        return None

    text = found.group(1).strip()
    cleanup_steps = (
        # Drop everything from "See" to the end of its line, and any whole
        # line that mentions "Riva".
        (r"(See.*$|.*\bRiva\b.*$)", re.MULTILINE),
        # Drop bracketed numeric citations such as [1] or [2].
        (r"\[\d+\]", 0),
    )
    for pattern, flags in cleanup_steps:
        text = re.sub(pattern, "", text, flags=flags).strip()
    return text
def get_model_description(model_name):
    """Fetch a model's README from the Hub and return its description.

    ``model_name`` gets the ``nvidia/`` prefix added when the name does not
    already contain it. Returns a ``(extracted_section, more_info)`` tuple:
    the section pulled out of the README by ``extract_section_from_readme``
    (which may be None) and a markdown sentence linking to the model page.
    """
    model_name = model_name if 'nvidia/' in model_name else 'nvidia/' + model_name

    local_readme = hf_hub_download(repo_id=model_name, filename="README.md")
    with open(local_readme, "r", encoding="utf-8") as handle:
        extracted_section = extract_section_from_readme(handle.read())

    more_info = f"See more on the selected model on [{model_name}](https://huggingface.co/{model_name})."
    return extracted_section, more_info
def predownload_models(models, top=None):
    """Download snapshots of the given NVIDIA models ahead of time.

    ``models`` holds model names without the ``nvidia/`` prefix. When ``top``
    is truthy only the first ``top`` entries are downloaded; otherwise every
    entry is. Progress is reported through ``tqdm``.
    """
    selected = models[:top] if top else models
    for name in tqdm(selected):
        snapshot_download('nvidia/' + name)