import logging
import os

import pandas as pd
from datasets import DatasetDict, load_dataset
from dotenv import load_dotenv

# Load environment variables (HF_TOKEN, DATA_REPO) from a local .env file.
load_dotenv()

logging.basicConfig(encoding="utf-8", level=logging.DEBUG)
logger = logging.getLogger(__name__)

HF_TOKEN = os.environ["HF_TOKEN"]
def load_data_hf(repo_name: str, data_files: str, is_public: bool) -> DatasetDict:
    """Load a dataset from the Hugging Face Hub, authenticating for private repos."""
    if is_public:
        dataset = load_dataset(repo_name, split="train")
    else:
        dataset = load_dataset(repo_name, token=HF_TOKEN, data_files=data_files)
    return dataset
def load_scores(category: str) -> pd.DataFrame | None:
    """Return the precomputed scores for a category, or None for an unknown category."""
    repository = os.environ.get("DATA_REPO")
    data_file = None
    match category:
        case "popularity":
            data_file = "computed/popularity/popularity_scores.csv"
        case "seasonality":
            data_file = "computed/seasonality/seasonality_scores.csv"
        case "emissions":
            data_file = "computed/emissions/emissions_merged.csv"
        case _:
            logger.warning(f"Invalid category: {category}")
    if data_file:  # only for valid categories
        data = load_data_hf(repository, data_file, is_public=False)
        df = pd.DataFrame(data["train"][:])
        return df
    return None
def load_places(data_file: str) -> pd.DataFrame | None:
    """Load a file of places from the private data repository as a DataFrame."""
    repository = os.environ.get("DATA_REPO")
    if data_file:
        data = load_data_hf(repository, data_file, is_public=False)
        df = pd.DataFrame(data["train"][:])
        return df
    return None
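

# Minimal usage sketch, not part of the original module: exercises load_scores and
# load_places from the command line. It assumes HF_TOKEN and DATA_REPO are set in the
# environment (or in .env), and the "computed/places/places.csv" path is a hypothetical
# example file, not a path confirmed to exist in the data repository.
if __name__ == "__main__":
    popularity_df = load_scores("popularity")
    if popularity_df is not None:
        logger.info(f"Loaded {len(popularity_df)} popularity score rows")

    places_df = load_places("computed/places/places.csv")
    if places_df is not None:
        logger.info(f"Loaded {len(places_df)} place rows")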