Spaces:
Runtime error
Runtime error
| from dataclasses import dataclass | |
| from enum import IntEnum | |
| import yaml | |
| from typing import Dict, Optional, List | |
| from pydantic import BaseModel, ValidationError | |
| from huggingface_hub import hf_hub_download | |
| from huggingface_hub.utils import EntryNotFoundError | |
| from openai import OpenAI | |
class OAuthProvider(IntEnum):
    """Integer-valued enumeration of OAuth providers."""

    NONE = 0  # presumably "no OAuth provider" -- confirm against callers
    GOOGLE = 1
@dataclass
class User:
    """Plain record describing a user.

    The ``dataclass`` decorator generates ``__init__``, ``__repr__`` and
    ``__eq__`` from the annotated fields. Without it the bare annotations
    below create neither constructor parameters nor attributes.
    """

    oauth: OAuthProvider  # OAuth provider associated with this user
    username: str
    permissions_id: str  # NOTE(review): presumably keys into a permissions store -- confirm
class PileConfig(BaseModel):
    """Schema of the ``pile`` section of the model configuration."""

    # NOTE(review): the key/value meanings noted below are inferred from the
    # field names only -- confirm against the code that consumes them.
    file2persona: Dict[str, str]  # presumably file name -> persona name
    file2prefix: Dict[str, str]  # presumably file name -> prefix text
    persona2system: Dict[str, str]  # persona name -> system prompt (presumed)
    prompt: str
class PermissionsConfig(BaseModel):
    """Schema of the optional ``permissions`` block of the inference config."""

    # Optional list of strings; defaults to None when the key is absent.
    # NOTE(review): presumably Google account domains -- confirm with callers.
    google_domains: Optional[List[str]] = None
class InferenceConfig(BaseModel):
    """Schema of the ``inference`` section of the model configuration."""

    chat_template: str  # required
    # Optional nested permissions block; None when the key is absent.
    permissions: Optional[PermissionsConfig] = None
class RepoConfig(BaseModel):
    """Schema of the ``repo`` section of the model configuration."""

    name: str  # required
class ModelConfig(BaseModel):
    """Top-level model configuration: ``pile``, ``inference`` and ``repo``."""

    pile: PileConfig
    inference: InferenceConfig
    repo: RepoConfig

    @classmethod
    def from_yaml(cls, yaml_file="datasets/config.yaml"):
        """Build a configuration from a YAML file.

        This is an alternate constructor and is meant to be called on the
        class itself, e.g. ``ModelConfig.from_yaml(path)``.

        Args:
            yaml_file: Path of the YAML document to read.

        Returns:
            An instance of ``cls`` populated from the document's top-level
            mapping.

        Raises:
            OSError: If the file cannot be opened.
            yaml.YAMLError: If the file is not valid YAML.
            ValidationError: If the contents do not match the schema.
        """
        with open(yaml_file, "r", encoding="utf-8") as file:
            # An empty document parses to None; treat it as an empty mapping
            # so the failure is a schema ValidationError, not a TypeError
            # from ``cls(**None)``.
            data = yaml.safe_load(file) or {}
        # Validation errors propagate to the caller unchanged.
        return cls(**data)
class Client:
    """Client for an OpenAI-compatible endpoint that also resolves personas.

    Construction creates the OpenAI client, reads the id of the first model
    the endpoint lists, and builds the persona table.

    Attributes set during initialisation:
        openai: ``OpenAI`` client pointed at ``{api_url}/v1``.
        vllm_model_name: Id of the first model returned by the endpoint.
        model_name: Part of that id before the first ``@``.
        revision: Part of that id between the first and second ``@``, or
            ``None`` when the id contains no ``@``.
        config: ``ModelConfig`` loaded from the model repository, or ``None``
            when no configuration was loaded.
        personas: Merged persona mapping; always contains ``"vanilla"``.
    """

    def __init__(
        self,
        api_url: str,
        api_key: str,
        personas: Optional[Dict[str, Optional[str]]] = None,
    ):
        """Store the connection settings and run the full initialisation.

        Args:
            api_url: Base URL of the server, without the ``/v1`` suffix.
            api_key: API key handed to the OpenAI client.
            personas: Extra personas to merge into the persona table.
                ``None`` (the default) means no extra personas.
        """
        self.api_url = api_url
        self.api_key = api_key
        # A fresh dict per instance: a ``{}`` default would be one object
        # shared by every Client constructed without this argument.
        self.input_personas = {} if personas is None else personas
        self.init_all()

    def init_all(self):
        """Create the API client, then load model metadata and personas."""
        self.init_client()
        self.get_metadata()
        self.get_personas()

    def init_client(self):
        """Create the ``OpenAI`` client for ``{api_url}/v1``."""
        self.openai = OpenAI(
            base_url=f"{self.api_url}/v1",
            api_key=self.api_key,
        )

    def get_metadata(self):
        """Read the first listed model id and split it into name and revision.

        The id is split on ``@``: the first segment becomes ``model_name``
        and the second, when present, becomes ``revision``.

        Raises:
            IndexError: If the endpoint lists no models.
        """
        models = self.openai.models.list()
        if not models.data:
            # Same exception type as indexing an empty list, with context.
            raise IndexError(f"no models are served at {self.api_url}")
        vllm_model_name = models.data[0].id
        segments = vllm_model_name.split("@")
        self.vllm_model_name = vllm_model_name
        self.model_name = segments[0]
        self.revision = segments[1] if len(segments) > 1 else None

    def get_personas(self):
        """Build ``self.personas`` from the model config and input personas.

        When a revision is known, ``datasets/config.yaml`` is downloaded from
        the model repository at that revision and its ``persona2system``
        mapping is used. A ``"vanilla"`` persona mapped to ``None`` is always
        added. Entries from the model (including ``"vanilla"``) take
        precedence over ``input_personas`` on key collisions.
        """
        # Always define the attribute so later reads cannot hit an
        # AttributeError when no configuration is available.
        self.config = None
        personas = {}
        if self.revision is not None:
            try:
                config_path = hf_hub_download(
                    self.model_name,
                    "config.yaml",
                    subfolder="datasets",
                    revision=self.revision,
                )
            except EntryNotFoundError:
                # The repository has no config file: keep the defaults.
                pass
            else:
                self.config = ModelConfig.from_yaml(config_path)
                # Copy, so adding "vanilla" below does not write into the
                # loaded configuration object.
                personas = dict(self.config.pile.persona2system)
        personas["vanilla"] = None
        self.personas = self.input_personas | personas