bgaspra committed
Commit c02513c · verified · 1 Parent(s): 64606e4

Update app.py

Files changed (1)
  1. app.py +109 -164
app.py CHANGED
@@ -1,177 +1,122 @@
- import os
- import requests
- from tqdm import tqdm
- from datasets import load_dataset
- import numpy as np
- from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
- from tensorflow.keras.preprocessing import image
- from sklearn.neighbors import NearestNeighbors
- import joblib
- from PIL import UnidentifiedImageError, Image
  import gradio as gr
- from tensorflow.keras.models import Sequential
- from tensorflow.keras.layers import Dense, Dropout
- from tensorflow.keras.preprocessing.text import Tokenizer
- from tensorflow.keras.preprocessing.sequence import pad_sequences
-
- # Load the dataset
- dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k")
-
- # Filter out NSFW content and null models
- dataset_filtered = dataset['train'].filter(
-     lambda x: not x['nsfw'] and x['Model'] is not None and x['Model'].strip() != ''
- )
-
- # Take a subset of the filtered dataset
- subset_size = 2700
- dataset_subset = dataset_filtered.shuffle(seed=42).select(range(subset_size))
-
- # Directory to save images
- image_dir = 'civitai_images'
- os.makedirs(image_dir, exist_ok=True)
-
- # Load the ResNet50 model pretrained on ImageNet
- cnn_model = ResNet50(weights='imagenet', include_top=False, pooling='avg')
-
- # Text processing setup
- max_words = 10000  # Maximum number of words to keep
- max_len = 100  # Maximum length of each text sequence
-
- # Initialize and fit tokenizer on prompts
- tokenizer = Tokenizer(num_words=max_words)
- prompts = [sample['prompt'] for sample in dataset_subset]
- tokenizer.fit_on_texts(prompts)
-
- # Create MLP model for text processing
- def create_mlp_model(input_dim):
-     model = Sequential([
-         Dense(256, activation='relu', input_dim=input_dim),
-         Dropout(0.3),
-         Dense(128, activation='relu'),
-         Dropout(0.2),
-         Dense(64, activation='relu'),
-         Dense(32, activation='relu')
-     ])
-     return model
-
- # Function to extract text features
- def extract_text_features(prompt):
-     # Convert text to sequence and pad
-     sequence = tokenizer.texts_to_sequences([prompt])
-     padded = pad_sequences(sequence, maxlen=max_len)
-     # Get features from MLP
-     return mlp_model.predict(padded)
-
- # Function to extract image features
- def extract_image_features(img_path, model):
-     img = image.load_img(img_path, target_size=(224, 224))
-     img_array = image.img_to_array(img)
-     img_array = np.expand_dims(img_array, axis=0)
-     img_array = preprocess_input(img_array)
-     features = model.predict(img_array)
-     return features.flatten()
-
- # Prepare text data
- text_sequences = tokenizer.texts_to_sequences(prompts)
- padded_sequences = pad_sequences(text_sequences, maxlen=max_len)
-
- # Create and train MLP model
- mlp_model = create_mlp_model(max_len)
- mlp_model.compile(optimizer='adam', loss='mse')
- mlp_model.fit(padded_sequences, padded_sequences, epochs=5, batch_size=32, validation_split=0.2)
-
- # Extract features for both images and text
- image_features = []
- text_features = []
- image_paths = []
- model_names = []
-
- for sample in tqdm(dataset_subset):
-     img_url = sample['url']
-     model_name = sample['Model']
-     prompt = sample['prompt']
-
-     img_path = os.path.join(image_dir, os.path.basename(img_url))
-
-     try:
-         # Download and process image
-         response = requests.get(img_url)
-         response.raise_for_status()
-
-         if 'image' not in response.headers['Content-Type']:
-             raise ValueError("URL does not contain an image")
-
-         with open(img_path, 'wb') as f:
-             f.write(response.content)
-
-         # Extract image features
-         img_features = extract_image_features(img_path, cnn_model)
-
-         # Extract text features
-         txt_features = extract_text_features(prompt)
-
-         # Store features and metadata
-         image_features.append(img_features)
-         text_features.append(txt_features.flatten())
-         image_paths.append(img_path)
-         model_names.append(model_name)
-
-     except (UnidentifiedImageError, requests.exceptions.RequestException) as e:
-         print(f"Error processing {img_url}: {e}")
-         if os.path.exists(img_path):
-             os.remove(img_path)
-
- # Convert features to numpy arrays
- image_features = np.array(image_features)
- text_features = np.array(text_features)
-
- # Combine image and text features
- combined_features = np.concatenate([image_features, text_features], axis=1)
-
- # Build the NearestNeighbors model
- nbrs = NearestNeighbors(n_neighbors=5, algorithm='ball_tree').fit(combined_features)
-
- # Save models and features
- joblib.dump(nbrs, 'nearest_neighbors_model.pkl')
- joblib.dump(mlp_model, 'mlp_model.pkl')
- joblib.dump(tokenizer, 'tokenizer.pkl')
- np.save('combined_features.npy', combined_features)
- np.save('image_paths.npy', image_paths)
- np.save('model_names.npy', model_names)
-
- # Function to get recommendations
- def get_recommendations(img, prompt="", n_neighbors=5):
-     # Process input image
-     img_path = "temp_input_image.png"
-     img.save(img_path)
-     img_features = extract_image_features(img_path, cnn_model)
-
-     # Process input text
-     txt_features = extract_text_features(prompt)
-
-     # Combine features
-     input_features = np.concatenate([img_features, txt_features.flatten()])
-
-     # Get recommendations
-     distances, indices = nbrs.kneighbors([input_features])
-
-     recommended_images = [image_paths[idx] for idx in indices.flatten()]
-     recommended_model_names = [model_names[idx] for idx in indices.flatten()]
-     recommended_distances = distances.flatten()
-
-     return [(Image.open(img_path), f'{name}, Distance: {dist:.2f}')
-             for img_path, name, dist in zip(recommended_images, recommended_model_names, recommended_distances)]

  # Gradio interface
  interface = gr.Interface(
      fn=get_recommendations,
-     inputs=[
-         gr.Image(type="pil"),
-         gr.Textbox(label="Prompt")
-     ],
      outputs=gr.Gallery(label="Recommended Images"),
-     title="Image and Text Recommendation System",
-     description="Upload an image and/or enter a prompt to get similar images with their model names and distances."
  )

  if __name__ == "__main__":
  import gradio as gr
+ import torch
+ import torch.nn as nn
+ import torchvision.transforms as transforms
+ from torchvision import models
+ from transformers import BertTokenizer, BertModel
+ import pandas as pd
+ from datasets import load_dataset
+ from torch.utils.data import DataLoader, Dataset
+ from sklearn.preprocessing import LabelEncoder
+
+ # Load dataset
+ dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
+
+ # Encode model names as integer class labels, shared by the dataset wrapper
+ # and the classifier head
+ label_encoder = LabelEncoder()
+ label_encoder.fit(dataset['Model'])
+
+ # Remember the first dataset row for each model name so a prediction can be
+ # illustrated with a sample image
+ model_to_row = {}
+ for i, name in enumerate(dataset['Model']):
+     model_to_row.setdefault(name, i)
+
+ # Preprocess text data
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+
+ class CustomDataset(Dataset):
+     def __init__(self, dataset):
+         self.dataset = dataset
+         self.transform = transforms.Compose([
+             transforms.Resize((224, 224)),
+             transforms.ToTensor(),
+         ])
+         self.labels = label_encoder.transform(dataset['Model'])
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         image = self.transform(self.dataset[idx]['image'])
+         text = tokenizer(
+             self.dataset[idx]['prompt'],
+             padding='max_length',
+             truncation=True,
+             return_tensors='pt'
+         )
+         # Drop the batch dimension added by return_tensors='pt' so the default
+         # DataLoader collation can stack items
+         text = {key: value.squeeze(0) for key, value in text.items()}
+         label = self.labels[idx]
+         return image, text, label
+
+ # Define CNN for image processing
+ class ImageModel(nn.Module):
+     def __init__(self):
+         super(ImageModel, self).__init__()
+         self.model = models.resnet18(pretrained=True)
+         self.model.fc = nn.Linear(self.model.fc.in_features, 512)
+
+     def forward(self, x):
+         return self.model(x)
+
+ # Define BERT-based encoder for text processing
+ class TextModel(nn.Module):
+     def __init__(self):
+         super(TextModel, self).__init__()
+         self.bert = BertModel.from_pretrained('bert-base-uncased')
+         self.fc = nn.Linear(768, 512)
+
+     def forward(self, x):
+         output = self.bert(**x)
+         return self.fc(output.pooler_output)
+
+ # Combined model
+ class CombinedModel(nn.Module):
+     def __init__(self):
+         super(CombinedModel, self).__init__()
+         self.image_model = ImageModel()
+         self.text_model = TextModel()
+         # One output class per unique model name, matching the encoded labels
+         self.fc = nn.Linear(1024, len(label_encoder.classes_))
+
+     def forward(self, image, text):
+         image_features = self.image_model(image)
+         text_features = self.text_model(text)
+         combined = torch.cat((image_features, text_features), dim=1)
+         return self.fc(combined)
+
+ # Instantiate model
+ model = CombinedModel()
+
+ def get_recommendations(image):
+     model.eval()
+     with torch.no_grad():
+         # Process image
+         transform = transforms.Compose([
+             transforms.Resize((224, 224)),
+             transforms.ToTensor()
+         ])
+         image_tensor = transform(image).unsqueeze(0)
+
+         # Process text: the interface only supplies an image, so a fixed
+         # placeholder prompt stands in for the text branch
+         text_input = tokenizer(
+             "Sample prompt",
+             return_tensors='pt',
+             padding=True,
+             truncation=True
+         )
+
+         # Get the top-5 model-name classes
+         output = model(image_tensor, text_input)
+         scores, indices = torch.topk(output, 5)
+
+         # Prepare gallery output as (image, caption) tuples, the format
+         # gr.Gallery accepts
+         recommendations = []
+         predicted_models = label_encoder.inverse_transform(indices[0].numpy())
+         for name, score in zip(predicted_models, scores[0]):
+             sample_data = dataset[model_to_row[name]]
+             recommendations.append((
+                 sample_data['image'],
+                 f"Model: {name}\nScore: {score:.2f}"
+             ))
+
+         return recommendations

  # Gradio interface
  interface = gr.Interface(
      fn=get_recommendations,
+     inputs=gr.Image(type="pil"),
      outputs=gr.Gallery(label="Recommended Images"),
+     title="Image Recommendation System",
+     description="Upload an image and get similar images with their model names and scores."
  )

  if __name__ == "__main__":
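
The new file wires up CustomDataset and imports DataLoader, but CombinedModel is never trained, so the app serves predictions from a randomly initialized classifier head. A minimal training sketch, assuming the classes and dataset defined above are in scope; the batch size, learning rate, epoch count, loss, and optimizer below are illustrative choices, not part of this commit:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Batch the (image, tokenized prompt, label) triples from CustomDataset
train_loader = DataLoader(CustomDataset(dataset), batch_size=32, shuffle=True)

# Standard classification setup; both choices here are assumptions
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

model.train()
for epoch in range(3):
    running_loss = 0.0
    for images, texts, labels in train_loader:
        optimizer.zero_grad()
        logits = model(images, texts)  # (batch, number of model-name classes)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"epoch {epoch}: mean loss {running_loss / len(train_loader):.4f}")

# Persist the weights so app.py could reload them at startup instead of
# predicting with an untrained head
torch.save(model.state_dict(), "combined_model.pt")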