File size: 4,922 Bytes
860760c 1509d22 4bdfa75 1fd7a59 e38cd3d 5167fb6 860760c e38cd3d a568cb6 e38cd3d a568cb6 e38cd3d a568cb6 e38cd3d a568cb6 e38cd3d a568cb6 84abbea a568cb6 4bdfa75 e38cd3d 4bdfa75 a568cb6 4bdfa75 a568cb6 860760c a568cb6 e38cd3d a568cb6 e38cd3d a568cb6 e38cd3d 4bdfa75 a568cb6 1fd7a59 860760c 3a2a66c e38cd3d 3a2a66c e38cd3d 4bdfa75 860760c 4bdfa75 860760c fad0d14 1fd7a59 a568cb6 e38cd3d fad0d14 860760c fad0d14 4bdfa75 fad0d14 a568cb6 e38cd3d a568cb6 4bdfa75 e38cd3d 4bdfa75 e38cd3d 4bdfa75 e38cd3d 4bdfa75 5167fb6 4bdfa75 5420ab6 fad0d14 4bdfa75 fad0d14 4bdfa75 860760c 5167fb6 a568cb6 e38cd3d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import os
import re
import tempfile
from functools import lru_cache
from io import BytesIO

import gradio as gr
from huggingface_hub import InferenceClient
from langchain.agents import create_react_agent, AgentExecutor
from langchain_community.tools import DuckDuckGoSearchResults
from langchain_core.tools import BaseTool
from PIL import Image, ImageDraw, ImageFont
from pydantic import Field
# === Setup Inference Clients ===
# Model IDs can be overridden with environment variables, so the app can be
# pointed at other Hub models without editing the code.
# NOTE(review): "m-ric/text-to-image" looks like a tool/Space repo rather than a
# model served by the Inference API — confirm it actually answers text_to_image.
IMAGE_MODEL_ID = os.environ.get("IMAGE_MODEL_ID", "m-ric/text-to-image")
TEXT_MODEL_ID = os.environ.get("TEXT_MODEL_ID", "Qwen/Qwen2.5-72B-Instruct")

# InferenceClient takes `model=` and `token=` (it has no `repo_id=` argument).
# To authenticate explicitly: InferenceClient(model=..., token="YOUR_HF_TOKEN");
# when no token is given the client falls back to the locally saved token.
image_client = InferenceClient(model=IMAGE_MODEL_ID)
text_client = InferenceClient(model=TEXT_MODEL_ID)
# === LangChain wrapper using InferenceClient for text generation ===
class InferenceClientLLM(BaseTool):
    """LangChain tool that generates text through the HF Inference API.

    NOTE(review): despite its name this is a ``BaseTool``, not a LangChain
    language model, so it cannot serve as the ``llm`` of an agent.
    """

    name: str = "inference_text_generator"
    description: str = "Generate text using HF Inference API."
    # exclude=True keeps the client out of pydantic serialization (model_dump).
    client: InferenceClient = Field(default=text_client, exclude=True)

    def _run(self, prompt: str) -> str:
        """Return the text generated for *prompt*, or "" if none was produced."""
        print(f"[LLM] Generating text for prompt: {prompt}")
        response = self.client.text_generation(prompt)
        # With the default details=False, text_generation returns a plain str.
        # With details=True it returns an object exposing `.generated_text`.
        # A dict is accepted as well for older/alternative backends.
        if isinstance(response, str):
            return response
        if isinstance(response, dict):
            return response.get("generated_text", "")
        return getattr(response, "generated_text", "") or ""

    async def _arun(self, prompt: str) -> str:
        """Async execution is not supported by this tool."""
        raise NotImplementedError("Async not supported.")
# === Image generation tool ===
class TextToImageTool(BaseTool):
    """LangChain tool that renders a text prompt to a PIL image via the HF Inference API."""

    name: str = "text_to_image"
    description: str = "Generate an image from a text prompt."
    # exclude=True keeps the client out of pydantic serialization (model_dump).
    client: InferenceClient = Field(default=image_client, exclude=True)

    def _run(self, prompt: str) -> Image.Image:
        """Generate an image for *prompt* and return it as a ``PIL.Image.Image``."""
        print(f"[Image Tool] Generating image for prompt: {prompt}")
        result = self.client.text_to_image(prompt)
        # InferenceClient.text_to_image already returns a decoded PIL image, and
        # wrapping that in BytesIO raises TypeError. Only raw bytes need decoding.
        if isinstance(result, Image.Image):
            return result
        return Image.open(BytesIO(result))

    async def _arun(self, prompt: str) -> Image.Image:
        """Async execution is not supported by this tool."""
        raise NotImplementedError("Async not supported.")
# === Initialize tools ===
text_to_image_tool = TextToImageTool()
text_gen_tool = InferenceClientLLM()
search_tool = DuckDuckGoSearchResults()

# === Create agent ===
# The Gradio app below never uses the agent; it calls the image tool directly.
# create_react_agent(llm, tools, prompt) requires a ReAct prompt template and a
# language model. This call supplies no prompt and passes a BaseTool as `llm`,
# so it raises instead of building an agent. Construction is therefore guarded:
# a failure leaves `agent`/`agent_executor` as None rather than preventing the
# whole module (and the UI) from loading.
_agent_tools = [text_to_image_tool, search_tool]
try:
    agent = create_react_agent(llm=text_gen_tool, tools=_agent_tools)
    agent_executor = AgentExecutor(agent=agent, tools=_agent_tools, verbose=True)
except (TypeError, ValueError) as err:
    print(f"[Agent] Disabled: could not build ReAct agent ({err})")
    agent = None
    agent_executor = None
# === Image labeling ===
def add_label_to_image(image, label):
    """Stamp *label* in the bottom-right corner of *image* on a dark box.

    Args:
        image: PIL image to label. RGB/RGBA images are drawn on in place;
            images in any other mode are first converted to RGB (a new image).
        label: Text to render.

    Returns:
        The labeled image.
    """
    font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
    try:
        font = ImageFont.truetype(font_path, 30)
    except (OSError, ImportError):
        # Font file missing (OSError) or FreeType support unavailable
        # (ImportError): fall back to Pillow's built-in default font.
        font = ImageFont.load_default()

    if image.mode not in ("RGB", "RGBA"):
        image = image.convert("RGB")
    # Drawing in "RGBA" mode on an RGB image blends the fill colour. Without it,
    # the alpha of 128 below is ignored and the box comes out opaque.
    draw = ImageDraw.Draw(image, "RGBA")

    # ImageDraw.textsize() was removed in Pillow 10; textbbox() replaces it.
    # The bbox can start at a non-zero offset (e.g. space above the glyphs), so
    # measure the visible ink box and shift the draw origin to compensate.
    left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
    text_width, text_height = right - left, bottom - top
    x = image.width - text_width - 20
    y = image.height - text_height - 20
    draw.rectangle(
        [x - 10, y - 10, x + text_width + 10, y + text_height + 10],
        fill=(0, 0, 0, 128),
    )
    draw.text((x - left, y - top), label, fill="white", font=font)
    return image
# === Prompt generation ===
def generate_prompts_for_object(object_name):
    """Return the past/present/future image prompts for *object_name*.

    A fresh dict is built on every call. This function used to be lru_cached,
    which made every caller share (and be able to mutate) one cached dict for
    a computation that is only three f-strings.

    Returns:
        Dict with keys "past", "present", "future" (in that order) mapping to
        prompt strings.
    """
    return {
        "past": f"Show an old version of a {object_name} from its early days.",
        "present": f"Show a {object_name} with current features/design/technology.",
        "future": f"Show a futuristic version of a {object_name}, predicting future features/designs.",
    }
# === Cache generated images ===
@lru_cache(maxsize=64)
def generate_image_for_prompt(prompt, label):
    """Render *prompt* with the image tool and stamp *label* onto the result.

    Memoized per (prompt, label) pair, keeping up to 64 entries. A repeated
    request returns the same PIL image object without calling the image
    client again.
    """
    return add_label_to_image(text_to_image_tool._run(prompt), label)
# === Main generation function ===
def _safe_filename_stem(name):
    """Reduce *name* to characters that are safe inside a file name."""
    stem = re.sub(r"[^A-Za-z0-9_-]+", "_", name).strip("_")
    return stem or "object"


def generate_object_history(object_name: str):
    """Generate past/present/future images of an object plus an animated GIF.

    Args:
        object_name: Free-text object name entered in the UI (e.g. "car").

    Returns:
        A tuple ``(images, gif_path)``. ``images`` is a list of
        ``(file_path, label)`` pairs for the gallery, and ``gif_path`` is a GIF
        cycling through the generated frames (1 s per frame, looping forever).

    Raises:
        gr.Error: If ``object_name`` is empty or only whitespace.
    """
    object_name = (object_name or "").strip()
    if not object_name:
        raise gr.Error("Please enter an object name.")

    # object_name comes straight from the UI, so it is never spliced into a path
    # raw: "/" or ".." would escape the output directory. A fresh per-call
    # directory also stops concurrent requests for the same object from
    # overwriting each other's files.
    stem = _safe_filename_stem(object_name)
    out_dir = tempfile.mkdtemp(prefix="timemetamorphy_")

    images = []
    frames = []
    for period, prompt in generate_prompts_for_object(object_name).items():
        label = f"{object_name} - {period.capitalize()}"
        labeled_image = generate_image_for_prompt(prompt, label)
        file_path = os.path.join(out_dir, f"{stem}_{period}.png")
        labeled_image.save(file_path)
        images.append((file_path, label))
        frames.append(labeled_image)

    # Build the GIF from the in-memory frames. Re-opening the PNGs left their
    # file handles open.
    gif_path = os.path.join(out_dir, f"{stem}_evolution.gif")
    frames[0].save(gif_path, save_all=True, append_images=frames[1:], duration=1000, loop=0)
    return images, gif_path
# === Gradio UI ===
def create_gradio_interface():
    """Build the Gradio app: an object-name box and button that fill an image
    gallery and a GIF viewer via ``generate_object_history``.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# TimeMetamorphy: Evolution Visualizer")
        with gr.Row():
            with gr.Column():
                object_input = gr.Textbox(label="Enter Object (e.g., car, phone)")
                generate_button = gr.Button("Generate Evolution")
            # Component.style() was removed in Gradio 4. Layout options such as
            # the number of gallery columns are now constructor arguments.
            gallery = gr.Gallery(label="Generated Images", columns=3)
            gif_display = gr.Image(label="Generated GIF")
        generate_button.click(fn=generate_object_history, inputs=object_input, outputs=[gallery, gif_display])
    return demo
# === Launch app ===
if __name__ == "__main__":
    # Build the UI and serve it. share=True additionally asks Gradio for a
    # public share link.
    demo = create_gradio_interface()
    demo.launch(share=True)
|