File size: 5,009 Bytes
8868d43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ad8ff2
8868d43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import streamlit as st
from PIL import Image
import torch
from transformers import (
    ViTFeatureExtractor,
    ViTForImageClassification,
    ViTImageProcessor,
    pipeline,
    AutoTokenizer,
    AutoModelForSeq2SeqLM
)
from diffusers import StableDiffusionPipeline

# Load models
def _load_vit_classifier(repo_id):
    """Return ``(model, image_processor)`` for a ViT image-classification checkpoint."""
    model = ViTForImageClassification.from_pretrained(repo_id)
    # ViTImageProcessor is the non-deprecated replacement for ViTFeatureExtractor; it
    # reads the same preprocessor config and accepts the same
    # ``(images=..., return_tensors='pt')`` call that predict() makes.
    processor = ViTImageProcessor.from_pretrained(repo_id)
    return model, processor


@st.cache_resource
def load_models():
    """Load every model the app uses.

    Decorated with ``st.cache_resource`` so the checkpoints are instantiated once
    and the same objects are reused on later script reruns.

    Returns:
        An 11-tuple, in this order: age model and preprocessor, gender model and
        preprocessor, emotion model and preprocessor, the DETR object-detection
        pipeline, action model and preprocessor, the prompt-enhancer
        text2text-generation pipeline, and the BK-SDM-Tiny Stable Diffusion
        pipeline.
    """
    age_model, age_transforms = _load_vit_classifier('nateraw/vit-age-classifier')
    gender_model, gender_transforms = _load_vit_classifier('rizvandwiki/gender-classification-2')
    emotion_model, emotion_transforms = _load_vit_classifier('dima806/facial_emotions_image_detection')

    object_detector = pipeline("object-detection", model="facebook/detr-resnet-50")

    action_model, action_transforms = _load_vit_classifier(
        'rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224')

    prompt_enhancer_tokenizer = AutoTokenizer.from_pretrained("gokaygokay/Flux-Prompt-Enhance")
    prompt_enhancer_model = AutoModelForSeq2SeqLM.from_pretrained("gokaygokay/Flux-Prompt-Enhance")
    prompt_enhancer = pipeline('text2text-generation',
                               model=prompt_enhancer_model,
                               tokenizer=prompt_enhancer_tokenizer,
                               repetition_penalty=1.2,
                               device="cpu")

    # BK-SDM-Tiny for image generation. Half precision is only used on a CUDA
    # device: on CPU, float16 inference is slow and some PyTorch builds lack Half
    # kernels for ops the pipeline needs, so fall back to float32 there. The
    # pipeline is moved explicitly to the device that matches its dtype.
    use_cuda = torch.cuda.is_available()
    pipe = StableDiffusionPipeline.from_pretrained(
        "nota-ai/bk-sdm-tiny",
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    )
    pipe = pipe.to("cuda" if use_cuda else "cpu")

    return (age_model, age_transforms, gender_model, gender_transforms,
            emotion_model, emotion_transforms, object_detector,
            action_model, action_transforms, prompt_enhancer, pipe)

# Fetch the cached model bundle once and expose each component under a
# module-level name; the helper functions below read these globals directly.
models = load_models()
(age_model, age_transforms,
 gender_model, gender_transforms,
 emotion_model, emotion_transforms,
 object_detector,
 action_model, action_transforms,
 prompt_enhancer, pipe) = models

def predict(image, model, transforms):
    """Classify *image* and return the index of the top-scoring class.

    Args:
        image: PIL-style image; converted to RGB first when its ``mode`` differs
            (e.g. an RGBA PNG upload or a greyscale image).
        model: classifier whose call result exposes ``.logits``.
        transforms: preprocessor called as
            ``transforms(images=[image], return_tensors='pt')``, returning keyword
            arguments for *model*.

    Returns:
        int: argmax class index for the single input image.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')

    inputs = transforms(images=[image], return_tensors='pt')
    # Inference only: disable autograd so no graph is recorded and intermediate
    # activations are not kept alive for a backward pass that never happens.
    with torch.no_grad():
        logits = model(**inputs).logits
    # softmax is monotonic, so taking argmax of the raw logits selects the same class.
    return logits.argmax(1).item()

def detect_attributes(image):
    """Run every analysis model on *image* and return human-readable labels.

    Args:
        image: PIL image (any mode; it is converted to RGB once up front).

    Returns:
        dict with keys:
            'age', 'gender', 'emotion', 'action': label strings looked up in each
                classifier's ``config.id2label``.
            'objects': object labels from the detector, de-duplicated and kept in
                first-detection order.
    """
    # Convert once here rather than letting predict() re-convert the same image
    # for each of the four classifiers.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    age = predict(image, age_model, age_transforms)
    gender = predict(image, gender_model, gender_transforms)
    emotion = predict(image, emotion_model, emotion_transforms)
    action = predict(image, action_model, action_transforms)

    detections = object_detector(image)
    # The detector yields one entry per detected instance, so the same label can
    # appear many times; without de-duplication the prompt would read e.g.
    # "Image has person, person, person". dict.fromkeys keeps first-seen order.
    object_labels = list(dict.fromkeys(det['label'] for det in detections))

    return {
        'age': age_model.config.id2label[age],
        'gender': gender_model.config.id2label[gender],
        'emotion': emotion_model.config.id2label[emotion],
        'action': action_model.config.id2label[action],
        'objects': object_labels,
    }

def generate_prompt(attributes):
    """Compose a plain-English scene description from detected *attributes*.

    Reads the keys 'age', 'gender', 'emotion', 'action' and 'objects' (a list of
    label strings). Every sentence ends with ". " including the trailing space.
    The "Image has ..." sentence is omitted when 'objects' is empty.
    """
    subject = "A {age} year old {gender} person feeling {emotion} while {action}. ".format_map(attributes)
    detected = attributes['objects']
    if not detected:
        return subject
    return subject + "Image has " + ", ".join(detected) + ". "

def enhance_prompt(prompt):
    """Rewrite *prompt* into a richer text-to-image prompt.

    Sends ``"enhance prompt: " + prompt`` through the module-level
    ``prompt_enhancer`` text2text-generation pipeline, with output length bounded
    by ``max_length=256``. Returns the ``'generated_text'`` of the first result.
    """
    # Instruction prefix; presumably the one the Flux-Prompt-Enhance checkpoint
    # was fine-tuned with. Confirm against its model card.
    request = "enhance prompt: " + prompt
    candidates = prompt_enhancer(request, max_length=256)
    first = candidates[0]
    return first['generated_text']

@st.cache_data
def generate_image(prompt):
    """Render *prompt* with the module-level BK-SDM-Tiny pipeline.

    Runs 50 inference steps and returns the first image produced. Because of
    ``st.cache_data``, a repeated call with the same prompt returns the cached
    result instead of sampling a new image.
    """
    steps = 50
    with torch.no_grad():  # inference only; no autograd graph is needed
        output = pipe(prompt, num_inference_steps=steps)
    return output.images[0]

# Page layout: upload an image, then on request run attribute detection,
# prompt generation/enhancement and text-to-image generation in sequence.
st.title("Image Attribute Detection and Image Generation")

# Single JPEG/PNG upload; the value is None until the user picks a file.
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # Decode the upload with PIL and echo it back to the user.
    image = Image.open(uploaded_file)
    # NOTE(review): recent Streamlit releases deprecate `use_column_width` in
    # favour of `use_container_width`; confirm against the pinned Streamlit version.
    st.image(image, caption='Uploaded Image', use_column_width=True)

    # st.button is truthy only on the script run triggered by the click, so the
    # outputs below are shown for that run only.
    if st.button('Analyze Image'):
        # Step 1: classifier and object-detector labels for the uploaded image.
        with st.spinner('Detecting attributes...'):
            attributes = detect_attributes(image)

        st.write("Detected Attributes:")
        for key, value in attributes.items():
            st.write(f"{key.capitalize()}: {value}")

        # Step 2: template prompt from the attributes, then the model-enhanced rewrite.
        with st.spinner('Generating prompt...'):
            initial_prompt = generate_prompt(attributes)
            enhanced_prompt = enhance_prompt(initial_prompt)
        
        st.write("Initial Prompt:")
        st.write(initial_prompt)
        st.write("Enhanced Prompt:")
        st.write(enhanced_prompt)

        # Step 3: text-to-image generation from the enhanced prompt.
        with st.spinner('Generating image...'):
            generated_image = generate_image(enhanced_prompt)
        st.image(generated_image, caption='Generated Image', use_column_width=True)