# autogen / app.py
# Committed by tzintsunzu: "Create app.py" (4a0d8b9, verified)
import os
import pandas as pd
import torch
import gc
import re
import random
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from diffusers import StableDiffusionPipeline
import gradio as gr
# Text-generation model, tokenizer and pipeline, all run on the CPU.
# NOTE(review): no quantization options are passed to from_pretrained here, so
# the weights load at whatever precision the checkpoint ships with — confirm
# before describing this as an 8-bit model.
# NOTE(review): trust_remote_code=True lets the hub repo run its own Python code
# at load time; drop it if this checkpoint does not need custom modeling code.
model_name = 'HuggingFaceTB/SmolLM-1.7B-Instruct'
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
text_generator = pipeline(
    task='text-generation',
    model=model,
    tokenizer=tokenizer,
    device=-1,  # -1 selects the CPU
)
# Stable Diffusion 2.1 "base" checkpoint, moved to the CPU right after loading.
# NOTE(review): chosen by the author as a "smaller model"; confirm the size
# trade-off against the non-base 2.1 checkpoint if that matters.
model_id = "stabilityai/stable-diffusion-2-1-base"
pipe = StableDiffusionPipeline.from_pretrained(model_id).to("cpu")
# Directory that gradio_interface saves generated PNGs into; created if missing.
output_dir = 'generated_images'
os.makedirs(name=output_dir, exist_ok=True)
# NOTE(review): 0o777 makes the directory writable by every local user —
# confirm this is needed (e.g. a shared volume) before keeping it.
os.chmod(path=output_dir, mode=0o777)
# Function to generate a detailed visual description prompt
def generate_description_prompt(user_prompt, user_examples):
    """Ask the text model for one visual description enclosed in double quotes.

    Args:
        user_prompt: The generation task entered by the user.
        user_examples: Text the description should differ from (the caller in
            this file passes a single subject string).

    Returns:
        The first non-empty double-quoted span of the model's reply, stripped
        and re-wrapped in double quotes (e.g. '"a red fox in snow"'), or None
        if the reply contains no such span or generation raised (the error is
        printed).
    """
    prompt = f'generate enclosed in quotes in the format "<description>" description according to guidelines of {user_prompt} different from {user_examples}'
    try:
        outputs = text_generator(
            prompt,
            # max_length would also become the prompt's truncation length and
            # count the prompt's own tokens, so a long prompt left no room for
            # a reply; bound only the newly generated tokens instead.
            max_new_tokens=150,
            num_return_sequences=1,
            truncation=True,
            # By default the pipeline echoes the prompt, and the prompt's own
            # '"<description>"' template would then be the first quoted match.
            # Search only the model's continuation.
            return_full_text=False,
        )
        generated_text = outputs[0]['generated_text']
    except Exception as e:
        # Best-effort: a failed generation is reported and treated as "no description".
        print(f"Error generating description for prompt '{user_prompt}': {e}")
        return None

    for match in re.finditer(r'"(.*?)"', generated_text):
        generated_description = match.group(1).strip()
        if generated_description:  # skip empty "" pairs
            return f'"{generated_description}"'
    return None
# Module-level subject pool shared by every call to generate_description:
# seed_words holds candidate subjects (user examples plus previously generated
# descriptions); used_words holds subjects that already produced a description.
seed_words, used_words = [], set()
def generate_description(user_prompt, user_examples_list):
    """Pick an unused subject from the seed pool and generate a description for it.

    Mutates the module-level ``seed_words`` and ``used_words``. New, non-blank
    user examples are added to the pool. On success the chosen subject is
    marked used, and the generated description is added to the pool as a
    candidate subject for later calls.

    Args:
        user_prompt: The generation task entered by the user.
        user_examples_list: Example strings to add to the subject pool.

    Returns:
        ``(description, subject)``. ``description`` is the ASCII-only,
        double-quoted string from ``generate_description_prompt``. Returns
        ``(None, None)`` when no unused subject remains or no usable
        description was produced.
    """
    # Add each example only once. Appending the same examples on every call
    # grew the pool without bound and over-weighted repeated words in
    # random.choice. Blank entries are not usable subjects.
    pooled = set(seed_words)
    for word in user_examples_list:
        if word.strip() and word not in pooled:
            seed_words.append(word)
            pooled.add(word)

    # Select a subject that has not been used
    available_subjects = [word for word in seed_words if word not in used_words]
    if not available_subjects:
        print("No more available subjects to use.")
        return None, None

    subject = random.choice(available_subjects)
    generated_description = generate_description_prompt(user_prompt, subject)
    if not generated_description:
        return None, None

    # Drop non-ASCII characters.
    clean_description = generated_description.encode('ascii', 'ignore').decode('ascii')
    bare_description = clean_description.strip('"')
    if not bare_description.strip():
        # Nothing survived the ASCII filter: nothing to draw or to reuse as a subject.
        return None, None

    print(f"Generated description for subject '{subject}': {clean_description}")
    used_words.add(subject)
    # Feed the description back into the pool (without quotes) as a future subject.
    if bare_description not in pooled:
        seed_words.append(bare_description)
    return clean_description, subject
# Function to generate an image based on the description
def generate_image(description, seed=42, *, width=512, height=512,
                   num_inference_steps=10, guidance_scale=7.5):
    """Render *description* with the module-level Stable Diffusion pipeline.

    The new keyword-only parameters default to the values previously
    hard-coded, so existing calls behave exactly as before.

    Args:
        description: Text inserted into the image prompt after
            "detailed photorealistic full shot of".
        seed: Seed for a fresh ``torch.Generator`` created on every call; with
            the default, every call starts from the same seed.
        width: Output width in pixels.
        height: Output height in pixels.
        num_inference_steps: Number of denoising steps passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    prompt = f'detailed photorealistic full shot of {description}'
    generator = torch.Generator().manual_seed(seed)
    result = pipe(
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        generator=generator,
        guidance_scale=guidance_scale,
    )
    return result.images[0]
# Gradio interface
def _save_new_png(image, directory):
    """Save *image* as ``image_<n>.png`` in *directory* without overwriting; return the path."""
    # Start from the file count (the original naming scheme), but never clobber an
    # existing file. A count-based name collides once files were deleted or other
    # files share the directory. Mode 'xb' fails if the name is already taken.
    index = len(os.listdir(directory))
    while True:
        path = os.path.join(directory, f"image_{index}.png")
        try:
            with open(path, 'xb') as fh:
                image.save(fh, format='PNG')
        except FileExistsError:
            index += 1
        else:
            return path


def gradio_interface(user_prompt, user_examples):
    """Gradio callback: generate a description, render it, and save the image.

    Args:
        user_prompt: The generation task typed by the user.
        user_examples: Comma-separated examples; surrounding whitespace and
            double quotes are stripped from each, and blank entries are dropped.

    Returns:
        ``(image, description)`` on success, or
        ``(None, "Failed to generate description.")`` when no description
        could be produced.
    """
    candidates = (raw.strip().strip('"') for raw in (user_examples or '').split(','))
    # Blank entries (empty textbox, trailing comma) would otherwise become subjects.
    user_examples_list = [example for example in candidates if example]

    generated_description, _subject = generate_description(user_prompt, user_examples_list)
    if not generated_description:
        return None, "Failed to generate description."

    image = generate_image(generated_description)
    image_path = _save_new_png(image, output_dir)
    # Still readable and writable by everyone, but without the execute bits
    # that 0o777 put on a data file.
    os.chmod(image_path, 0o666)
    return image, generated_description
# Web UI: two text inputs feed gradio_interface, which returns an image and its description.
task_box = gr.Textbox(lines=2, placeholder="Enter the generation task or general thing you are looking for")
examples_box = gr.Textbox(lines=2, placeholder='Provide a few examples (enclosed in quotes and separated by commas)')
image_output = gr.Image(label="Generated Image")
description_output = gr.Textbox(label="Generated Description")

iface = gr.Interface(
    fn=gradio_interface,
    inputs=[task_box, examples_box],
    outputs=[image_output, description_output],
    title="Description and Image Generator",
    description="Generate detailed descriptions and images based on your input.",
)
# Listen on all network interfaces, port 7860.
iface.launch(server_name="0.0.0.0", server_port=7860)
# Clear GPU memory when the process is closed
def clear_gpu_memory():
    """Collect unreferenced Python objects, then release cached CUDA memory if CUDA is available.

    NOTE(review): nothing in the visible code calls this function; register it
    (e.g. with atexit) or call it explicitly if the cleanup is wanted.
    """
    # Collect first. empty_cache() can only hand back cached blocks whose
    # tensors are already freed, so calling it before gc.collect() leaves
    # behind memory held by objects that the collection is about to reclaim.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("GPU memory cleared.")