Spaces:
Running
on
Zero
Running
on
Zero
import replicate | |
from PIL import Image | |
import requests | |
import io | |
import os | |
import base64 | |
# Maps the short display name of each supported model to the identifier string
# handed to ``replicate.run``. Most identifiers have the form
# "owner/model:<hex digest>"; "SD-v3.0" is the only entry without the
# ":<hex digest>" suffix (presumably the digest pins a model version -- confirm
# against the Replicate docs).
# NOTE: the name is spelled "MODEl" (lowercase "l"); it is kept as-is because
# other code in this file looks it up under exactly this name.
Replicate_MODEl_NAME_MAP = {
    "SDXL": "stability-ai/sdxl:7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc",
    "SD-v3.0": "stability-ai/stable-diffusion-3",
    "SD-v2.1": "stability-ai/stable-diffusion:ac732df83cea7fff18b8472768c88ad041fa750ff7682a21affe81863cbe77e4",
    "SD-v1.5": "stability-ai/stable-diffusion:b3d14e1cd1f9470bbb0bb68cac48e5f483e5be309551992cc33dc30654a82bb7",
    "SDXL-Lightning": "bytedance/sdxl-lightning-4step:5f24084160c9089501c1b3545d9be3c27883ae2239b6f412990e82d4a6210f8f",
    "Kandinsky-v2.0": "ai-forever/kandinsky-2:3c6374e7a9a17e01afe306a5218cc67de55b19ea536466d6ea2602cfecea40a9",
    "Kandinsky-v2.2": "ai-forever/kandinsky-2.2:ad9d7879fbffa2874e1d909d1d37d9bc682889cc65b31f7bb00d2362619f194a",
    "Proteus-v0.2": "lucataco/proteus-v0.2:06775cd262843edbde5abab958abdbb65a0a6b58ca301c9fd78fa55c775fc019",
    "Playground-v2.0": "playgroundai/playground-v2-1024px-aesthetic:42fe626e41cc811eaf02c94b892774839268ce1994ea778eba97103fe1ef51b8",
    "Playground-v2.5": "playgroundai/playground-v2.5-1024px-aesthetic:a45f82a1382bed5c7aeb861dac7c7d191b0fdf74d8d57c4a0e6ed7d4d0bf7d24",
    "Dreamshaper-xl-turbo": "lucataco/dreamshaper-xl-turbo:0a1710e0187b01a255302738ca0158ff02a22f4638679533e111082f9dd1b615",
    "SDXL-Deepcache": "lucataco/sdxl-deepcache:eaf678fb34006669e9a3c6dd5971e2279bf20ee0adeced464d7b6d95de16dc93",
    # Any name containing "Openjourney" gets special output handling in
    # ReplicateModel.__call__ (its result is iterated instead of indexed).
    "Openjourney-v4": "prompthero/openjourney:ad59ca21177f9e217b9075e7300cf6e14f7e5b4505b87b9689dbd866e9768969",
    "LCM": "fofr/latent-consistency-model:683d19dc312f7a9f0428b04429a9ccefd28dbf7785fef083ad5cf991b65f406f",
    "Realvisxl-v3.0": "fofr/realvisxl-v3:33279060bbbb8858700eb2146350a98d96ef334fcf817f37eb05915e1534aa1c",
    "Realvisxl-v2.0": "lucataco/realvisxl-v2.0:7d6a2f9c4754477b12c14ed2a58f89bb85128edcdd581d24ce58b6926029de08",
    "Pixart-Sigma": "cjwbw/pixart-sigma:5a54352c99d9fef467986bc8f3a20205e8712cbd3df1cbae4975d6254c902de1",
    "SSD-1b": "lucataco/ssd-1b:b19e3639452c59ce8295b82aba70a231404cb062f2eb580ea894b31e8ce5bbb6",
    "Open-Dalle-v1.1": "lucataco/open-dalle-v1.1:1c7d4c8dec39c7306df7794b28419078cb9d18b9213ab1c21fdc46a1deca0144",
    "Deepfloyd-IF": "andreasjansson/deepfloyd-if:fb84d659df149f4515c351e394d22222a94144aa1403870c36025c8b28846c8d",
}
class ReplicateModel:
    """Callable wrapper that generates an image with a Replicate-hosted model.

    Only ``model_type == "text2image"`` is implemented; calling an instance
    configured with any other type raises ``ValueError``.

    Attributes:
        model_name: Key into ``Replicate_MODEl_NAME_MAP`` selecting the model.
        model_type: Task selector, checked on every call.
    """

    # Socket timeout (seconds) for downloading the generated image, so a
    # stalled connection cannot block the caller indefinitely.
    DOWNLOAD_TIMEOUT = 60

    def __init__(self, model_name, model_type):
        self.model_name = model_name
        self.model_type = model_type

    def __call__(self, *args, **kwargs):
        """Generate one 512x512 image for ``kwargs["prompt"]`` and return it.

        Positional arguments are accepted but ignored.

        Returns:
            The image opened with ``PIL.Image.open`` from the downloaded bytes.

        Raises:
            ValueError: ``self.model_type`` is not ``"text2image"``.
            AssertionError: no ``prompt`` keyword argument was given.
            KeyError: ``self.model_name`` is not in ``Replicate_MODEl_NAME_MAP``.
            RuntimeError: the model returned an empty result.
            requests.HTTPError: downloading the result returned an HTTP error.
        """
        if self.model_type != "text2image":
            raise ValueError(
                "model_type must be text2image (image2image is not "
                f"implemented yet), got {self.model_type!r}"
            )
        if "prompt" not in kwargs:
            # Raised explicitly (not via ``assert``) so the check survives
            # ``python -O``; the exception type is kept for existing callers.
            raise AssertionError("prompt is required for text2image model")

        output = replicate.run(
            Replicate_MODEl_NAME_MAP[self.model_name],
            input={
                "width": 512,
                "height": 512,
                "prompt": kwargs["prompt"],
            },
        )
        result_url = self._first_result_url(
            output,
            # Openjourney results are consumed by iteration rather than by
            # indexing, so they need not be a list.
            iterate="Openjourney" in self.model_name,
        )
        print(result_url)

        response = requests.get(result_url, timeout=self.DOWNLOAD_TIMEOUT)
        # Fail with the HTTP error instead of letting PIL choke on an error page.
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content))

    @staticmethod
    def _first_result_url(output, iterate=False):
        """Return the URL of the first result in a ``replicate.run`` output."""
        if iterate or isinstance(output, (list, tuple)):
            missing = object()
            output = next(iter(output), missing)
            if output is missing:
                raise RuntimeError("Replicate returned no output")
        # Plain URL strings pass through unchanged; objects exposing a ``url``
        # attribute (file-like results from newer client versions -- see the
        # Replicate client docs) are reduced to that URL.
        return str(getattr(output, "url", output))
def load_replicate_model(model_name, model_type):
    """Build a ``ReplicateModel`` for the given model name and task type."""
    model = ReplicateModel(model_name=model_name, model_type=model_type)
    return model
def _time_all_models():
    """Run every model in ``Replicate_MODEl_NAME_MAP`` once, printing output and latency.

    Manual smoke check: for each entry it prints the model name, the raw
    ``replicate.run`` output, and the elapsed seconds, framed by banner lines.
    """
    import time  # only needed by this manual check

    # Named ``run_input`` so the builtin ``input`` is not shadowed.
    run_input = {
        "seed": 1,
        "width": 512,
        "height": 512,
        "grid_size": 1,
        "prompt": "anime astronaut riding a horse on mars",
    }
    banner = "*" * 50
    for name, address in Replicate_MODEl_NAME_MAP.items():
        print(banner)
        print(name)
        # perf_counter is monotonic, so the interval cannot be skewed by
        # system clock adjustments.
        started = time.perf_counter()
        output = replicate.run(address, input=run_input)
        print(output)
        print(time.perf_counter() - started)
        print(banner)


if __name__ == "__main__":
    _time_all_models()