bgaspra committed
Commit 906c9b0 · verified · 1 Parent(s): 0365b37

Update app.py

Files changed (1)
  1. app.py +163 -128
app.py CHANGED
@@ -1,143 +1,178 @@
- import gradio as gr
- import torch
- import torch.nn as nn
- import torchvision.transforms as transforms
- from torchvision import models
- from transformers import BertTokenizer, BertModel
- import pandas as pd
- from datasets import load_dataset
- from torch.utils.data import DataLoader, Dataset
- from sklearn.preprocessing import LabelEncoder
  import requests
- from PIL import Image
- from io import BytesIO
  import numpy as np

- # Load dataset
- dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')

- # Download and cache images
- def download_image(url):
-     try:
-         response = requests.get(url)
-         img = Image.open(BytesIO(response.content))
-         return img
-     except:
-         return None
-
- # Create image cache
- image_cache = {}
- for idx, item in enumerate(dataset):
-     if idx % 100 == 0:  # Status update
-         print(f"Downloaded {idx} images")
-     url = item['url']
-     if url not in image_cache:
-         img = download_image(url)
-         if img is not None:
-             image_cache[url] = img
-
- # Preprocess text data
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
-
- class CustomDataset(Dataset):
-     def __init__(self, dataset, image_cache):
-         self.dataset = dataset
-         self.image_cache = image_cache
-         self.transform = transforms.Compose([
-             transforms.Resize((224, 224)),
-             transforms.ToTensor(),
-         ])
-         self.label_encoder = LabelEncoder()
-         self.labels = self.label_encoder.fit_transform(dataset['Model'])
-
-     def __len__(self):
-         return len(self.dataset)
-
-     def __getitem__(self, idx):
-         url = self.dataset[idx]['url']
-         image = self.transform(self.image_cache[url])
-         text = tokenizer(self.dataset[idx]['prompt'],
-                          padding='max_length',
-                          truncation=True,
-                          return_tensors='pt')
-         label = self.labels[idx]
-         return image, text, label
-
- # Model definitions remain the same
- class ImageModel(nn.Module):
-     def __init__(self):
-         super(ImageModel, self).__init__()
-         self.model = models.resnet18(pretrained=True)
-         self.model.fc = nn.Linear(self.model.fc.in_features, 512)
-
-     def forward(self, x):
-         return self.model(x)
-
- class TextModel(nn.Module):
-     def __init__(self):
-         super(TextModel, self).__init__()
-         self.bert = BertModel.from_pretrained('bert-base-uncased')
-         self.fc = nn.Linear(768, 512)
-
-     def forward(self, x):
-         output = self.bert(**x)
-         return self.fc(output.pooler_output)
-
- class CombinedModel(nn.Module):
-     def __init__(self):
-         super(CombinedModel, self).__init__()
-         self.image_model = ImageModel()
-         self.text_model = TextModel()
-         self.fc = nn.Linear(1024, len(dataset['Model']))
-
-     def forward(self, image, text):
-         image_features = self.image_model(image)
-         text_features = self.text_model(text)
-         combined = torch.cat((image_features, text_features), dim=1)
-         return self.fc(combined)
-
- # Instantiate model
- model = CombinedModel()
-
- # Modified prediction function
- def get_recommendations(input_image):
-     model.eval()
-     with torch.no_grad():
-         # Process input image
-         transform = transforms.Compose([
-             transforms.Resize((224, 224)),
-             transforms.ToTensor()
-         ])
-         input_tensor = transform(input_image).unsqueeze(0)
-
-         # Get dummy text input (since we're focusing on image similarity)
-         text_input = tokenizer("", return_tensors='pt', padding=True, truncation=True)
-
-         # Get model output
-         output = model(input_tensor, text_input)
-         scores, indices = torch.topk(output, 5)
-
-         # Prepare gallery output
-         gallery_images = []
-         for idx in indices[0]:
-             url = dataset[idx]['url']
-             model_name = dataset[idx]['Model']
-             score = scores[0][idx].item()
-
-             # Get image from cache
-             if url in image_cache:
-                 gallery_images.append((image_cache[url], f"{model_name}\nScore: {score:.2f}"))
-
-         return gallery_images
-
- # Set up Gradio interface
  interface = gr.Interface(
      fn=get_recommendations,
-     inputs=gr.Image(type="pil"),
      outputs=gr.Gallery(label="Recommended Images"),
-     title="Image Recommendation System",
-     description="Upload an image and get similar images with their model names and distances."
  )

- # Launch the app
- interface.launch()
+ import os
  import requests
+ from tqdm import tqdm
+ from datasets import load_dataset
  import numpy as np
+ from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
+ from tensorflow.keras.preprocessing import image
+ from sklearn.neighbors import NearestNeighbors
+ import joblib
+ from PIL import UnidentifiedImageError, Image
+ import gradio as gr
+ from tensorflow.keras.models import Sequential, Model
+ from tensorflow.keras.layers import Dense, Dropout
+ from tensorflow.keras.preprocessing.text import Tokenizer
+ from tensorflow.keras.preprocessing.sequence import pad_sequences

+ # Load the dataset
+ dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k")

+ # Filter out NSFW content and null models
+ dataset_filtered = dataset['train'].filter(
+     lambda x: not x['nsfw'] and x['Model'] is not None and x['Model'].strip() != ''
+ )
+
+ # Take a subset of the filtered dataset
+ subset_size = 2700
+ dataset_subset = dataset_filtered.shuffle(seed=42).select(range(subset_size))
+
+ # Directory to save images
+ image_dir = 'civitai_images'
+ os.makedirs(image_dir, exist_ok=True)
+
+ # Load the ResNet50 model pretrained on ImageNet
+ cnn_model = ResNet50(weights='imagenet', include_top=False, pooling='avg')
+
+ # Text processing setup
+ max_words = 10000  # Maximum number of words to keep
+ max_len = 100      # Maximum length of each text sequence
+
+ # Initialize and fit tokenizer on prompts
+ tokenizer = Tokenizer(num_words=max_words)
+ prompts = [sample['prompt'] for sample in dataset_subset]
+ tokenizer.fit_on_texts(prompts)
+
+ # Create MLP model for text processing. A reconstruction head is appended so
+ # the autoencoder-style MSE fit below has matching input/output shapes; the
+ # 32-dim 'text_embedding' layer is what we actually use as the text feature.
+ def create_mlp_model(input_dim):
+     model = Sequential([
+         Dense(256, activation='relu', input_dim=input_dim),
+         Dropout(0.3),
+         Dense(128, activation='relu'),
+         Dropout(0.2),
+         Dense(64, activation='relu'),
+         Dense(32, activation='relu', name='text_embedding'),
+         Dense(input_dim, name='reconstruction')
+     ])
+     return model
+
+ # Function to extract text features
+ def extract_text_features(prompt):
+     # Convert text to sequence and pad
+     sequence = tokenizer.texts_to_sequences([prompt])
+     padded = pad_sequences(sequence, maxlen=max_len)
+     # Get the 32-dim embedding from the trained encoder sub-model
+     return text_encoder.predict(padded)
+
+ # Function to extract image features
+ def extract_image_features(img_path, model):
+     img = image.load_img(img_path, target_size=(224, 224))
+     img_array = image.img_to_array(img)
+     img_array = np.expand_dims(img_array, axis=0)
+     img_array = preprocess_input(img_array)
+     features = model.predict(img_array)
+     return features.flatten()
+
+ # Prepare text data
+ text_sequences = tokenizer.texts_to_sequences(prompts)
+ padded_sequences = pad_sequences(text_sequences, maxlen=max_len)
+
+ # Create and train MLP model (autoencoder-style fit on the padded sequences)
+ mlp_model = create_mlp_model(max_len)
+ mlp_model.compile(optimizer='adam', loss='mse')
+ mlp_model.fit(padded_sequences, padded_sequences, epochs=5, batch_size=32, validation_split=0.2)
+
+ # Encoder sub-model that exposes the 32-dim text embedding for feature extraction
+ text_encoder = Model(mlp_model.input, mlp_model.get_layer('text_embedding').output)
+
+ # Extract features for both images and text
+ image_features = []
+ text_features = []
+ image_paths = []
+ model_names = []
+
+ for sample in tqdm(dataset_subset):
+     img_url = sample['url']
+     model_name = sample['Model']
+     prompt = sample['prompt']
+
+     img_path = os.path.join(image_dir, os.path.basename(img_url))
+
+     try:
+         # Download and process image
+         response = requests.get(img_url)
+         response.raise_for_status()
+
+         if 'image' not in response.headers.get('Content-Type', ''):
+             raise ValueError("URL does not contain an image")
+
+         with open(img_path, 'wb') as f:
+             f.write(response.content)
+
+         # Extract image features
+         img_features = extract_image_features(img_path, cnn_model)
+
+         # Extract text features
+         txt_features = extract_text_features(prompt)
+
+         # Store features and metadata
+         image_features.append(img_features)
+         text_features.append(txt_features.flatten())
+         image_paths.append(img_path)
+         model_names.append(model_name)
+
+     # ValueError is caught too, so non-image URLs are skipped instead of crashing the loop
+     except (UnidentifiedImageError, requests.exceptions.RequestException, ValueError) as e:
+         print(f"Error processing {img_url}: {e}")
+         if os.path.exists(img_path):
+             os.remove(img_path)
+
+ # Convert features to numpy arrays
+ image_features = np.array(image_features)
+ text_features = np.array(text_features)
+
+ # Combine image and text features
+ combined_features = np.concatenate([image_features, text_features], axis=1)
+
+ # Build the NearestNeighbors model
+ nbrs = NearestNeighbors(n_neighbors=5, algorithm='ball_tree').fit(combined_features)
+
+ # Save models and features
+ joblib.dump(nbrs, 'nearest_neighbors_model.pkl')
+ joblib.dump(mlp_model, 'mlp_model.pkl')
+ joblib.dump(tokenizer, 'tokenizer.pkl')
+ np.save('combined_features.npy', combined_features)
+ np.save('image_paths.npy', image_paths)
+ np.save('model_names.npy', model_names)
+
+ # Function to get recommendations
+ def get_recommendations(img, prompt="", n_neighbors=5):
+     # Process input image
+     img_path = "temp_input_image.png"
+     img.save(img_path)
+     img_features = extract_image_features(img_path, cnn_model)
+
+     # Process input text
+     txt_features = extract_text_features(prompt)
+
+     # Combine features
+     input_features = np.concatenate([img_features, txt_features.flatten()])
+
+     # Get recommendations (pass n_neighbors through instead of ignoring it)
+     distances, indices = nbrs.kneighbors([input_features], n_neighbors=n_neighbors)
+
+     recommended_images = [image_paths[idx] for idx in indices.flatten()]
+     recommended_model_names = [model_names[idx] for idx in indices.flatten()]
+     recommended_distances = distances.flatten()
+
+     return [(Image.open(path), f'{name}, Distance: {dist:.2f}')
+             for path, name, dist in zip(recommended_images, recommended_model_names, recommended_distances)]

+ # Gradio interface
  interface = gr.Interface(
      fn=get_recommendations,
+     inputs=[
+         gr.Image(type="pil"),
+         gr.Textbox(label="Prompt")
+     ],
      outputs=gr.Gallery(label="Recommended Images"),
+     title="Image and Text Recommendation System",
+     description="Upload an image and/or enter a prompt to get similar images with their model names and distances."
  )

+ if __name__ == "__main__":
+     interface.launch()
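
As a quick sanity check on the artifacts this commit persists, here is a minimal offline-query sketch (an illustration under assumptions, not part of the app: it presumes nearest_neighbors_model.pkl, combined_features.npy, image_paths.npy, and model_names.npy from the run above sit in the working directory):

import joblib
import numpy as np

# Load the saved neighbor index and metadata produced by app.py
nbrs = joblib.load('nearest_neighbors_model.pkl')
features = np.load('combined_features.npy')
image_paths = np.load('image_paths.npy')
model_names = np.load('model_names.npy')

# Probe with the first indexed item; its own entry should come back at distance 0
distances, indices = nbrs.kneighbors(features[:1], n_neighbors=5)
for dist, idx in zip(distances[0], indices[0]):
    print(f"{model_names[idx]}  {image_paths[idx]}  distance={dist:.2f}")

Each combined feature vector is 2080-dimensional here: 2048 from ResNet50's average-pooled output plus 32 from the text embedding.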