from dataclasses import dataclass

# transformers is only needed to inspect GPTQ checkpoints. Guard the import so
# the fixed-precision checks keep working without it; the GPTQ branch raises a
# clear ImportError instead.
try:
    from transformers import AutoConfig
except ImportError:
    AutoConfig = None

# Default size ceiling applied by ``ModelSizeChecker.can_evaluate`` when the
# precision factor is 1.
# NOTE(review): presumably expressed in billions of parameters, going by the
# ``model_size_in_b`` field name -- confirm with the caller.
DEFAULT_MAX_MODEL_SIZE_IN_B = 140

# GPTQ bit width -> precision factor (multiplier applied to the size ceiling).
_GPTQ_BITS_TO_PRECISION_FACTOR = {2: 8, 3: 6, 4: 4, 8: 2}


@dataclass
class ModelSizeChecker:
    """Decide whether a model is small enough to evaluate at a given precision.

    The size ceiling is scaled by a precision factor, so more aggressively
    quantized models are allowed a proportionally larger ``model_size_in_b``.
    """

    # Model identifier handed to ``AutoConfig.from_pretrained`` (GPTQ only).
    model: str
    # One of "float16", "bfloat16", "8bit", "4bit" or "GPTQ".
    precision: str
    # Model size, compared against the scaled ceiling in ``can_evaluate``.
    model_size_in_b: float

    def get_precision_factor(self) -> int:
        """Return the multiplier applied to the size ceiling for ``precision``.

        Returns:
            1 for "float16"/"bfloat16", 2 for "8bit", 4 for "4bit". For "GPTQ"
            the factor is derived from the ``bits`` entry of the model's
            ``quantization_config``; bit widths outside {2, 3, 4, 8} fall back
            to 1.

        Raises:
            ValueError: If ``precision`` is not one of the supported values.
            ImportError: If ``precision`` is "GPTQ" and ``transformers`` is
                not installed.
        """
        precision = self.precision
        if precision in ("float16", "bfloat16"):
            return 1
        if precision == "8bit":
            return 2
        if precision == "4bit":
            return 4
        if precision == "GPTQ":
            if AutoConfig is None:
                raise ImportError(
                    "transformers is required to inspect GPTQ models."
                )
            # The bit width is stored in the model's own configuration, so it
            # has to be loaded from the model identifier.
            config = AutoConfig.from_pretrained(self.model)
            num_bits = int(config.quantization_config["bits"])
            return _GPTQ_BITS_TO_PRECISION_FACTOR.get(num_bits, 1)
        # ValueError is still an Exception, so existing handlers keep working.
        raise ValueError(f"Unknown precision {self.precision}.")

    def can_evaluate(
        self, max_size_in_b: float = DEFAULT_MAX_MODEL_SIZE_IN_B
    ) -> bool:
        """Return whether the model fits under the precision-scaled ceiling.

        Args:
            max_size_in_b: Size ceiling for a precision factor of 1. Defaults
                to ``DEFAULT_MAX_MODEL_SIZE_IN_B`` (140).

        Returns:
            True when ``model_size_in_b`` is at most ``max_size_in_b`` times
            the precision factor (the bound is inclusive).

        Raises:
            ValueError: If ``precision`` is not supported.
        """
        precision_factor = self.get_precision_factor()
        return self.model_size_in_b <= max_size_in_b * precision_factor