from huggingface_hub import InferenceClient
from langchain import hub
from langchain_community.tools import DuckDuckGoSearchResults
from langchain.agents import create_react_agent, AgentExecutor
from langchain_core.language_models.llms import LLM
from langchain_core.tools import BaseTool
from pydantic import Field
from PIL import Image, ImageDraw, ImageFont
from functools import lru_cache
import gradio as gr
import os

# === Setup Inference Clients ===
# Pass a Hugging Face token if the models require authentication, e.g.:
# client = InferenceClient(model="model-id", token=os.environ["HF_TOKEN"])

image_client = InferenceClient("m-ric/text-to-image")
text_client = InferenceClient("Qwen/Qwen2.5-72B-Instruct")

# === LangChain LLM wrapper using InferenceClient for text generation ===
# create_react_agent expects a language model, not a tool, so this subclasses
# langchain_core's LLM base class.
class InferenceClientLLM(LLM):
    client: InferenceClient = Field(default=text_client, exclude=True)

    @property
    def _llm_type(self) -> str:
        return "hf_inference_client"

    def _call(self, prompt: str, stop=None, run_manager=None, **kwargs) -> str:
        print(f"[LLM] Generating text for prompt: {prompt}")
        # text_generation() returns the generated string directly (not a dict).
        # Stop sequences are ignored here for simplicity.
        return self.client.text_generation(prompt)
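
# Illustrative usage (not part of the app flow; the prompt text is made up):
#     llm = InferenceClientLLM()
#     print(llm.invoke("Describe a 1990s mobile phone in one sentence."))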

# === Image generation tool ===
class TextToImageTool(BaseTool):
    name: str = "text_to_image"
    description: str = "Generate an image from a text prompt."
    client: InferenceClient = Field(default=image_client, exclude=True)

    def _run(self, prompt: str) -> Image.Image:
        print(f"[Image Tool] Generating image for prompt: {prompt}")
        # InferenceClient.text_to_image already returns a PIL Image.
        return self.client.text_to_image(prompt)

    def _arun(self, prompt: str):
        raise NotImplementedError("Async not supported.")
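
# Illustrative usage: the tool can be exercised standalone, e.g.
#     img = TextToImageTool()._run("a steam locomotive, sepia photograph")
#     img.save("locomotive.png")  # hypothetical output path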

# === Initialize tools and LLM ===
text_to_image_tool = TextToImageTool()
text_gen_llm = InferenceClientLLM()
search_tool = DuckDuckGoSearchResults()

# === Create agent ===
# create_react_agent also requires a ReAct prompt; pull the standard one
# from the LangChain hub (needs the `langchainhub` package installed).
react_prompt = hub.pull("hwchase17/react")
agent = create_react_agent(llm=text_gen_llm, tools=[text_to_image_tool, search_tool], prompt=react_prompt)
agent_executor = AgentExecutor(agent=agent, tools=[text_to_image_tool, search_tool], verbose=True)
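
# Illustrative usage: the executor takes a dict with an "input" key, e.g.
#     agent_executor.invoke({"input": "Find when the first bicycle was built."})
# Note: the pipeline below calls the image tool directly rather than
# routing requests through the agent.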

# === Image labeling ===
def add_label_to_image(image, label):
    # Draw in RGBA mode so the translucent background rectangle blends in.
    draw = ImageDraw.Draw(image, "RGBA")
    font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
    try:
        font = ImageFont.truetype(font_path, 30)
    except OSError:
        font = ImageFont.load_default()

    # textsize() was removed in Pillow 10; textbbox() is the replacement.
    bbox = draw.textbbox((0, 0), label, font=font)
    text_width, text_height = bbox[2] - bbox[0], bbox[3] - bbox[1]
    position = (image.width - text_width - 20, image.height - text_height - 20)
    rect_position = [position[0] - 10, position[1] - 10, position[0] + text_width + 10, position[1] + text_height + 10]
    draw.rectangle(rect_position, fill=(0, 0, 0, 128))
    draw.text(position, label, fill="white", font=font)
    return image
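
# Illustrative usage (hypothetical 512x512 canvas):
#     labeled = add_label_to_image(Image.new("RGB", (512, 512)), "Phone - Past")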

# === Prompt generation with caching ===
@lru_cache(maxsize=128)
def generate_prompts_for_object(object_name):
    return {
        "past": f"Show an old version of a {object_name} from its early days.",
        "present": f"Show a {object_name} with current features/design/technology.",
        "future": f"Show a futuristic version of a {object_name}, predicting future features/designs.",
    }

# === Cache generated images ===
@lru_cache(maxsize=64)
def generate_image_for_prompt(prompt, label):
    img = text_to_image_tool._run(prompt)
    return add_label_to_image(img, label)
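
# Note: lru_cache keys on the exact (prompt, label) strings, so repeated
# requests for the same object skip the Inference API round-trip entirely.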

# === Main generation function ===
def generate_object_history(object_name: str):
    prompts = generate_prompts_for_object(object_name)
    images = []
    file_paths = []

    for period, prompt in prompts.items():
        label = f"{object_name} - {period.capitalize()}"
        labeled_image = generate_image_for_prompt(prompt, label)

        file_path = f"/tmp/{object_name}_{period}.png"
        labeled_image.save(file_path)
        images.append((file_path, label))
        file_paths.append(file_path)

    # Create GIF
    gif_path = f"/tmp/{object_name}_evolution.gif"
    pil_images = [Image.open(p) for p in file_paths]
    pil_images[0].save(gif_path, save_all=True, append_images=pil_images[1:], duration=1000, loop=0)

    return images, gif_path
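
# Illustrative usage:
#     gallery_items, gif = generate_object_history("telephone")
#     # gallery_items -> [("/tmp/telephone_past.png", "telephone - Past"), ...]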

# === Gradio UI ===
def create_gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("# TimeMetamorphy: Evolution Visualizer")

        with gr.Row():
            with gr.Column():
                object_input = gr.Textbox(label="Enter Object (e.g., car, phone)")
                generate_button = gr.Button("Generate Evolution")
                gallery = gr.Gallery(label="Generated Images", columns=3)
                gif_display = gr.Image(label="Generated GIF")

        generate_button.click(fn=generate_object_history, inputs=object_input, outputs=[gallery, gif_display])

    return demo

# === Launch app ===
if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.launch(share=True)