|
from typing import Dict, List, Any |
|
import torch |
|
from transformers import BitsAndBytesConfig, pipeline |
|
|
|
# Select the torch device string once at import time: use the GPU when CUDA
# reports itself as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
|
|
class EndpointHandler:
    """Request handler that extracts the ``image`` and ``prompt`` fields of a payload.

    The handler prepares a 4-bit BitsAndBytes quantization config on
    construction and, when called, returns the request's ``image`` and
    ``prompt`` values in a dict.
    """

    def __init__(self, path: str = "") -> None:
        """Create the handler and its quantization settings.

        Args:
            path: Accepted for interface compatibility; it is not read by
                this constructor.
        """
        # Keep the config on the instance. It was previously assigned to a
        # local and discarded as soon as __init__ returned, so nothing could
        # ever use it.
        self.quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Return the ``image`` and ``prompt`` carried by a request payload.

        Top-level ``"image"`` / ``"prompt"`` keys take precedence. When a key
        is absent at the top level and ``data["inputs"]`` is itself a dict,
        the value is looked up there instead. A key found in neither place
        yields ``None``.

        Args:
            data: Request payload. It is only read, never modified.

        Returns:
            A dict of the form ``{"image": ..., "prompt": ...}``.
        """
        print(data)

        # Read with .get() instead of .pop() so the caller's dict is left
        # intact after the call.
        inputs = data.get("inputs")
        nested = inputs if isinstance(inputs, dict) else {}

        # dict.get only uses its default when the key is missing, so an
        # explicit top-level value (even None) still wins over the nested one.
        image = data.get("image", nested.get("image"))
        prompt = data.get("prompt", nested.get("prompt"))

        return {"image": image, "prompt": prompt}
|
|