AMfeta99's picture
Update app.py
da48de7 verified
raw
history blame
6.35 kB
from PIL import Image, ImageDraw, ImageFont
import tempfile
import gradio as gr
from smolagents import CodeAgent, InferenceClientModel
from smolagents import DuckDuckGoSearchTool, Tool
from diffusers import DiffusionPipeline
import torch
from huggingface_hub import login
import os
# Log in to the Hugging Face Hub only when a token is provided through the
# HF_TOKEN environment variable; otherwise just warn and carry on.
token = os.getenv("HF_TOKEN")
if not token:
    print("Warning: HF_TOKEN not set. You may not be able to access private models or tools.")
else:
    login(token=token)
# =========================================================
# Utility functions
# =========================================================
def add_label_to_image(image, label):
    """Draw ``label`` on a dark backdrop in the bottom-right corner of ``image``.

    The image is modified in place and the same object is returned.

    Args:
        image: PIL image to annotate.
        label: Text to render.

    Returns:
        The ``image`` object that was passed in, now carrying the label.
    """
    # For RGB images Pillow only blends a fill's alpha component when the
    # drawing context is opened in "RGBA" mode; other image modes must use
    # their own mode, so they keep the default context.
    if image.mode == "RGB":
        draw = ImageDraw.Draw(image, "RGBA")
    else:
        draw = ImageDraw.Draw(image)

    font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
    font_size = 30
    try:
        font = ImageFont.truetype(font_path, font_size)
    except (OSError, ImportError):
        # Font file missing/unreadable or FreeType support unavailable:
        # fall back to Pillow's built-in font instead of failing.
        font = ImageFont.load_default()

    left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
    text_width, text_height = right - left, bottom - top

    # Keep a 20 px gap between the text and the right/bottom image edges.
    position = (image.width - text_width - 20, image.height - text_height - 20)

    rect_margin = 10
    rect_position = [
        position[0] - rect_margin,
        position[1] - rect_margin,
        position[0] + text_width + rect_margin,
        position[1] + text_height + rect_margin,
    ]
    draw.rectangle(rect_position, fill=(0, 0, 0, 128))

    # textbbox reports where the glyphs start relative to the anchor point;
    # subtracting that offset puts the glyphs exactly inside the backdrop.
    draw.text((position[0] - left, position[1] - top), label, fill="white", font=font)
    return image
def plot_and_save_agent_image(agent_image, label, save_path=None):
    """Label the image wrapped by ``agent_image`` and optionally save it.

    Args:
        agent_image: Object exposing ``to_raw()``; the value it returns is
            handed to ``add_label_to_image``.
        label: Text drawn onto the image.
        save_path: Destination file path. When falsy, nothing is written.
    """
    annotated = add_label_to_image(agent_image.to_raw(), label)

    if not save_path:
        print("No save path provided. Image not saved.")
        return

    annotated.save(save_path)
    print(f"Image saved to {save_path}")
def generate_prompts_for_object(object_name):
    """Build the image prompts for the three time periods of an object.

    Args:
        object_name: Name of the object to be depicted.

    Returns:
        A dict with the keys ``"past"``, ``"present"`` and ``"future"``
        (inserted in that order), each mapped to its prompt string.
    """
    prompts = {}
    prompts["past"] = f"Show an old version of a {object_name} from its early days."
    prompts["present"] = f"Show a {object_name} with current features/design/technology."
    prompts["future"] = (
        f"Show a futuristic version of a {object_name}, "
        "by predicting advanced features and futuristic design."
    )
    return prompts
# Image generation tool wrapping the "/infer" endpoint of the FLUX.1-schnell
# Space ("KingNish/Realtime-FLUX" is kept here as an alternative Space id).
image_generation_tool = Tool.from_space(
    space_id="black-forest-labs/FLUX.1-schnell",
    name="image_generator",
    description="Generate an image from a prompt",
    api_name="/infer",
)
# =========================================================
# Tool and Agent Initialization
# =========================================================
# Web search tool offered to the agent next to the image generator.
search_tool = DuckDuckGoSearchTool()

# Model behind the agent ("Qwen/Qwen2.5-72B-Instruct" is an alternative id).
llm_engine = InferenceClientModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct")

agent = CodeAgent(
    model=llm_engine,
    tools=[image_generation_tool, search_tool],
)
# =========================================================
# Main logic for image generation
# =========================================================
def _safe_filename_stem(name):
    """Return ``name`` with every character other than alphanumerics, ``-`` and ``_`` replaced by ``_``."""
    return "".join(ch if ch.isalnum() or ch in "-_" else "_" for ch in name)


def generate_object_history(object_name):
    """Generate past/present/future images of an object plus an animated GIF.

    For each time period the agent is asked to research and render the
    prompt. Periods whose agent run fails are logged and skipped, so the
    result only contains the frames that were actually produced.

    Args:
        object_name: Name of the object typed by the user.

    Returns:
        A tuple ``(gallery_items, gif_path)`` where ``gallery_items`` is a list
        of ``(image_path, caption)`` pairs for the generated frames and
        ``gif_path`` is the path of the GIF assembled from those frames.

    Raises:
        gr.Error: If ``object_name`` is empty or no frame could be generated.
    """
    object_name = (object_name or "").strip()
    if not object_name:
        raise gr.Error("Please enter an object name.")

    # The name comes from a free-text field, so it must not be used verbatim
    # in a path (it may contain separators or "..").
    file_stem = _safe_filename_stem(object_name)

    prompts = generate_prompts_for_object(object_name)
    labels = {
        "past": f"{object_name} - Past",
        "present": f"{object_name} - Present",
        "future": f"{object_name} - Future",
    }
    general_instruction = (
        "Search the necessary information and features for the following prompt, "
        "then generate an image of it."
    )

    images = []
    gallery_items = []
    for time_period, prompt in prompts.items():
        print(f"Generating {time_period} frame: {prompt}")
        try:
            result = agent.run(
                general_instruction,
                additional_args={"user_prompt": prompt},
            )
            if isinstance(result, (list, tuple)):
                result = result[0]
            # Also validates that the agent returned an image-like object.
            image = result.to_raw()
        except Exception as e:
            # Best effort: one failed period should not abort the others.
            print(f"Agent failed on {time_period}: {e}")
            continue

        images.append(image)
        image_filename = f"{file_stem}_{time_period}.png"
        plot_and_save_agent_image(result, labels[time_period], save_path=image_filename)
        # NOTE(review): the GIF frames carry the label only if to_raw()
        # returns the same image object on every call - confirm.
        gallery_items.append((image_filename, labels[time_period]))

    if not images:
        # Without this guard images[0] below would raise IndexError.
        raise gr.Error(f"Could not generate any image for '{object_name}'.")

    gif_path = f"{file_stem}_evolution.gif"
    images[0].save(gif_path, save_all=True, append_images=images[1:], duration=1000, loop=0)

    return gallery_items, gif_path
# =========================================================
# Gradio Interface
# =========================================================
def create_gradio_interface():
    """Assemble the Gradio Blocks UI of the evolution generator.

    The page shows a title, a short description, a text box with a button,
    a gallery and a GIF preview; the gallery and GIF start out pointing at
    the "car" example files. Clicking the button runs
    ``generate_object_history`` and routes its two outputs to the gallery
    and the GIF preview.

    Returns:
        The ``gr.Blocks`` instance; launching it is left to the caller.
    """
    intro_text = (
        "\n"
        'Explore how everyday objects evolved over time. Enter an object name like "phone", "car", or "bicycle"\n'
        "and see its past, present, and future visualized with AI!\n"
    )

    # Example content displayed before the first generation.
    example_gallery = [
        ("car_past.png", "Car - Past"),
        ("car_present.png", "Car - Present"),
        ("car_future.png", "Car - Future"),
    ]
    example_gif = "car_evolution.gif"

    with gr.Blocks() as demo:
        gr.Markdown("# TimeMetamorphy: An Object Evolution Generator")
        gr.Markdown(intro_text)

        # NOTE(review): the widget nesting below (everything inside one
        # column of one row) was reconstructed - confirm against the
        # intended layout.
        with gr.Row():
            with gr.Column():
                name_box = gr.Textbox(
                    label="Enter an object name",
                    placeholder="e.g. bicycle, car, phone",
                )
                run_button = gr.Button("Generate Evolution")
                gallery = gr.Gallery(
                    label="Generated Images",
                    columns=3,
                    rows=1,
                    value=example_gallery,
                    type="filepath",
                )
                gif_view = gr.Image(label="Generated GIF", value=example_gif)

        run_button.click(
            fn=generate_object_history,
            inputs=[name_box],
            outputs=[gallery, gif_view],
        )

    return demo
# =========================================================
# Run the app
# =========================================================
# Build the UI at import time and keep it in the module-level name `demo`.
demo = create_gradio_interface()
# Start the app; share=True additionally requests a public share link
# (see the Gradio `launch` documentation).
demo.launch(share=True)