AMfeta99's picture
Update app.py
adcd167 verified
from PIL import Image, ImageDraw, ImageFont
import tempfile
import gradio as gr
from smolagents import CodeAgent, InferenceClientModel
from smolagents import DuckDuckGoSearchTool, Tool
from diffusers import DiffusionPipeline
import torch
from smolagents import OpenAIServerModel
import os
from huggingface_hub import login
# Credentials come from the environment; either one may be missing.
openai_key = os.getenv("OPENAI_API_KEY")
hf_token = os.getenv("HF_TOKEN")

# Log in to the Hugging Face Hub only when a token is available.
if not hf_token:
    print("Warning: HF_TOKEN not set.")
else:
    login(token=hf_token)

# The OpenAI key is only reported here; it is passed to the model further down.
if not openai_key:
    print("Warning: OPENAI_API_KEY not set.")
else:
    print("OpenAI API key is set")

# Summary of which credentials were found (the values themselves are never printed).
print("HF_TOKEN set?", "Yes" if hf_token else "No")
print("OPENAI_API_KEY set?", "Yes" if openai_key else "No")
# =========================================================
# Utility functions
# =========================================================
def add_label_to_image(
    image,
    label,
    *,
    font_path="/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
    font_size=30,
):
    """Draw *label* on a dark box in the bottom-right corner of *image*.

    The image is modified in place and also returned, so callers may use
    either the return value or the object they passed in.

    Args:
        image: PIL image to draw on.
        label: Text to render.
        font_path: TrueType font file to use; if it cannot be loaded the
            default PIL font is used instead.
        font_size: Point size requested for the TrueType font.

    Returns:
        The same image object, with the label drawn on it.
    """
    draw = ImageDraw.Draw(image)

    try:
        font = ImageFont.truetype(font_path, font_size)
    except (OSError, ImportError):
        # Font file missing/unreadable, or FreeType support not available:
        # fall back to the built-in font rather than failing the whole frame.
        font = ImageFont.load_default()

    # The bounding box of text anchored at (0, 0) usually does not start at
    # (0, 0) itself; keep the offsets so the glyphs can be centred in the box.
    left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
    text_width = right - left
    text_height = bottom - top

    edge_offset = 20  # distance of the text area from the right/bottom edges
    rect_margin = 10  # padding between the text area and the background box
    x = image.width - text_width - edge_offset
    y = image.height - text_height - edge_offset

    # NOTE(review): the alpha component (128) presumably only has an effect on
    # images that carry an alpha channel; on plain RGB the box is solid black.
    draw.rectangle(
        [
            x - rect_margin,
            y - rect_margin,
            x + text_width + rect_margin,
            y + text_height + rect_margin,
        ],
        fill=(0, 0, 0, 128),
    )
    # Shift by the bbox offsets so the visible glyphs land exactly on (x, y).
    draw.text((x - left, y - top), label, fill="white", font=font)
    return image
def plot_and_save_agent_image(agent_image, label, save_path=None):
    """Stamp *label* onto *agent_image* and, if requested, write it to disk.

    Args:
        agent_image: Image object handed to ``add_label_to_image``.
        label: Caption to draw on the image.
        save_path: Destination file; when falsy nothing is written and a
            notice is printed instead.
    """
    stamped = add_label_to_image(agent_image, label)

    if not save_path:
        print("No save path provided. Image not saved.")
        return

    stamped.save(save_path)
    print(f"Image saved to {save_path}")
def generate_prompts_for_object(object_name):
    """Return the image prompts for an object, keyed by time period.

    The mapping has the keys ``"past"``, ``"present"`` and ``"future"``, in
    that order, each holding a prompt sentence that mentions *object_name*.
    """
    past_prompt = f"Show an old version of a {object_name} from its early days."
    present_prompt = f"Show a {object_name} with current features/design/technology."
    future_prompt = f"Show a futuristic version of a {object_name}, by predicting advanced features and futuristic design."

    prompts = {}
    prompts["past"] = past_prompt
    prompts["present"] = present_prompt
    prompts["future"] = future_prompt
    return prompts
# =========================================================
# Tool and Agent Initialization
# =========================================================
# Remote image generator wrapped as an agent tool. Other Space IDs the author
# left commented out at this spot: "KingNish/Realtime-FLUX" and
# "AMfeta99/FLUX.1-schnell".
image_generation_tool = Tool.from_space(
    "black-forest-labs/FLUX.1-schnell",
    name="image_generator",
    description="Generate an image from a prompt",
    api_name="/infer",
)

# Web search tool handed to the agent alongside the image generator.
search_tool = DuckDuckGoSearchTool()

# Hugging Face inference model; it is constructed here but not passed to the
# agent below. ("Qwen/Qwen2.5-72B-Instruct" was a commented-out alternative.)
llm_engine2 = InferenceClientModel(
    "Qwen/Qwen2.5-Coder-32B-Instruct",
    provider="together",
)

# OpenAI model initialised through smolagents; change model_id to use a
# different OpenAI model.
llm_engine = OpenAIServerModel(
    api_key=openai_key,
    api_base="https://api.openai.com/v1",
    model_id="gpt-4o-mini",
)

# Code agent driven by the OpenAI-backed model, with both tools available.
agent = CodeAgent(
    model=llm_engine,
    tools=[image_generation_tool, search_tool],
)
# =========================================================
# Main logic for image generation
# =========================================================
from PIL import Image
def _open_agent_result(result):
    """Turn whatever the agent run returned into a PIL image."""
    # The image tool is expected to yield (filepath, seed); keep the first item.
    if isinstance(result, (list, tuple)):
        result = result[0]

    # NOTE(review): agent outputs may presumably be image wrapper objects that
    # expose to_raw() (an earlier revision of this file called it) - confirm.
    to_raw = getattr(result, "to_raw", None)
    if callable(to_raw):
        return to_raw()
    if isinstance(result, Image.Image):
        return result

    # Otherwise treat the result as a path (or file object) to open.
    return Image.open(result)


def generate_object_history(object_name):
    """Generate past/present/future images of an object plus an animated GIF.

    For every time period the agent is run with the period's prompt, the
    resulting image is labelled and saved as ``<name>_<period>.png`` in the
    working directory, and all frames that succeeded are combined into
    ``<name>_evolution.gif``. A failure in one period is printed and skipped
    so the remaining periods are still attempted.

    Args:
        object_name: Free-text object name (comes from the UI text box).

    Returns:
        A ``(image_paths, gif_path)`` tuple. ``image_paths`` lists the PNG
        files written (possibly empty). ``gif_path`` is the GIF file, or
        ``None`` when no frame was produced, so callers are never handed a
        path to a file that does not exist.
    """
    object_name = (object_name or "").strip()
    if not object_name:
        # Nothing meaningful to ask the agent for; skip the three agent runs.
        print("No object name provided. Nothing to generate.")
        return [], None

    # The name is typed by the user and ends up in file names: replace
    # anything that is not alphanumeric, "-" or "_" (path separators, dots,
    # spaces, ...) so the files always stay in the working directory.
    file_stem = "".join(
        ch if ch.isalnum() or ch in "-_" else "_" for ch in object_name
    )

    general_instruction = (
        "Search the necessary information and features for the following prompt, "
        "then generate an image of it."
    )

    images = []
    image_paths = []
    for time_period, prompt in generate_prompts_for_object(object_name).items():
        print(f"Generating {time_period} frame: {prompt}")
        try:
            result = agent.run(
                general_instruction,
                additional_args={
                    "prompt": prompt,
                    "width": 256,
                    "height": 256,
                    "seed": 0,
                    "randomize_seed": False,
                    "num_inference_steps": 4,
                },
            )
            image = _open_agent_result(result)

            # Label the frame and write it once, under the naming convention.
            image_filename = f"{file_stem}_{time_period}.png"
            plot_and_save_agent_image(
                image,
                f"{object_name} - {time_period.title()}",
                save_path=image_filename,
            )
        except Exception as e:
            # Best effort: one failed period must not abort the others.
            print(f"Agent failed on {time_period}: {e}")
        else:
            image_paths.append(image_filename)
            images.append(image)

    if not images:
        return image_paths, None

    # Animated GIF: one second per frame, looping forever.
    gif_path = f"{file_stem}_evolution.gif"
    images[0].save(
        gif_path,
        save_all=True,
        append_images=images[1:],
        duration=1000,
        loop=0,
    )
    return image_paths, gif_path
# =========================================================
# Gradio Interface
# =========================================================
def create_gradio_interface():
    """Build the Gradio Blocks UI for the object-evolution generator.

    The page has a title, a short description, a troubleshooting note, a text
    box for the object name, a button, a gallery for the generated images and
    an image slot for the GIF. Clicking the button calls
    ``generate_object_history`` with the text box content and routes its two
    return values to the gallery and the GIF slot.

    Returns:
        The assembled ``gr.Blocks`` app; it is not launched here.
    """
    # Pre-rendered "car" example shown before the first run. Only files that
    # actually exist are passed on, so a checkout without these assets can
    # still build the interface.
    default_images = [
        path
        for path in ("car_past.png", "car_present.png", "car_future.png")
        if os.path.isfile(path)
    ]
    default_gif_path = "car_evolution.gif"
    if not os.path.isfile(default_gif_path):
        default_gif_path = None

    with gr.Blocks() as demo:
        gr.Markdown("# TimeMetamorphy: An Object Evolution Generator")

        # The text is kept flush-left on purpose: indented lines inside a
        # Markdown string could be rendered as a code block.
        gr.Markdown("""
Explore how everyday objects evolved over time. Enter an object name like "phone", "car", or "bicycle"
and see its past, present, and future visualized with AI!
""")

        gr.HTML("<p style='color: red; font-weight: bold;'>🚨 Note: If you experience issues connecting to the API (while using the HF Space), If that happens feel free to run the exact same app/code in this Colab Notebook (it solve the issue).<a href='https://colab.research.google.com/drive/1aKBJWkRBKhW8VFEu8p1zaxJr9VDzPaRz?usp=sharing' target='_blank' style='color: red; text-decoration: underline;'> Open Notebook</a>.</p>")

        # NOTE(review): the original indentation was lost; all four components
        # are assumed to sit in this single column - confirm the layout.
        with gr.Row():
            with gr.Column():
                object_name_input = gr.Textbox(
                    label="Enter an object name",
                    placeholder="e.g. bicycle, car, phone",
                )
                generate_button = gr.Button("Generate Evolution")
                image_gallery = gr.Gallery(
                    label="Generated Images",
                    columns=3,
                    rows=1,
                    value=default_images or None,
                    type="filepath",
                )
                gif_output = gr.Image(
                    label="Generated GIF",
                    value=default_gif_path,
                    type="filepath",
                )

        generate_button.click(
            fn=generate_object_history,
            inputs=[object_name_input],
            outputs=[image_gallery, gif_output],
        )

    return demo
# Build the Blocks app when this module is executed and start the Gradio server.
# `demo` is kept as a module-level name; launch() is called with its defaults.
demo = create_gradio_interface()
demo.launch()