bgaspra commited on
Commit
e897bc2
·
verified ·
1 Parent(s): 67d3c78

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -17
app.py CHANGED
@@ -14,8 +14,9 @@ dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:
14
 
15
  # Preprocess text data
16
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
 
17
  class CustomDataset(Dataset):
18
- def init(self, dataset):
19
  self.dataset = dataset
20
  self.transform = transforms.Compose([
21
  transforms.Resize((224, 224)),
@@ -23,48 +24,54 @@ class CustomDataset(Dataset):
23
  ])
24
  self.label_encoder = LabelEncoder()
25
  self.labels = self.label_encoder.fit_transform(dataset['Model'])
26
- def len(self):
 
27
  return len(self.dataset)
28
- def getitem(self, idx):
 
29
  image = self.transform(self.dataset[idx]['image'])
30
  text = tokenizer(self.dataset[idx]['prompt'], padding='max_length', truncation=True, return_tensors='pt')
31
  label = self.labels[idx]
32
  return image, text, label
33
-
34
  # Define CNN for image processing
35
  class ImageModel(nn.Module):
36
- def init(self):
37
- super(ImageModel, self).init()
38
  self.model = models.resnet18(pretrained=True)
39
  self.model.fc = nn.Linear(self.model.fc.in_features, 512)
 
40
  def forward(self, x):
41
  return self.model(x)
42
-
43
  # Define MLP for text processing
44
  class TextModel(nn.Module):
45
- def init(self):
46
- super(TextModel, self).init()
47
  self.bert = BertModel.from_pretrained('bert-base-uncased')
48
  self.fc = nn.Linear(768, 512)
 
49
  def forward(self, x):
50
- output = self.bert(x)
51
  return self.fc(output.pooler_output)
52
-
53
  # Combined model
54
  class CombinedModel(nn.Module):
55
- def init(self):
56
- super(CombinedModel, self).init()
57
  self.image_model = ImageModel()
58
  self.text_model = TextModel()
59
  self.fc = nn.Linear(1024, len(dataset['Model']))
 
60
  def forward(self, image, text):
61
  image_features = self.image_model(image)
62
  text_features = self.text_model(text)
63
  combined = torch.cat((image_features, text_features), dim=1)
64
  return self.fc(combined)
65
-
66
  # Instantiate model
67
  model = CombinedModel()
 
68
  # Define predict function
69
  def predict(image):
70
  model.eval()
@@ -72,16 +79,17 @@ def predict(image):
72
  image = transforms.ToTensor()(image).unsqueeze(0)
73
  image = transforms.Resize((224, 224))(image)
74
  text_input = tokenizer("Sample prompt", return_tensors='pt', padding=True, truncation=True)
75
- output = model(image, textinput)
76
- , indices = torch.topk(output, 5)
77
  recommended_models = [dataset['Model'][i] for i in indices[0]]
78
  return recommended_models
79
-
80
  # Set up Gradio interface
81
  interface = gr.Interface(fn=predict,
82
  inputs=gr.Image(type="pil"),
83
  outputs=gr.Textbox(label="Recommended Models"),
84
  title="AI Image Model Recommender",
85
  description="Upload an AI-generated image to receive model recommendations.")
 
86
  # Launch the app
87
  interface.launch()
 
14
 
15
  # Preprocess text data
16
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
17
+
18
  class CustomDataset(Dataset):
19
+ def __init__(self, dataset):
20
  self.dataset = dataset
21
  self.transform = transforms.Compose([
22
  transforms.Resize((224, 224)),
 
24
  ])
25
  self.label_encoder = LabelEncoder()
26
  self.labels = self.label_encoder.fit_transform(dataset['Model'])
27
+
28
+ def __len__(self):
29
  return len(self.dataset)
30
+
31
+ def __getitem__(self, idx):
32
  image = self.transform(self.dataset[idx]['image'])
33
  text = tokenizer(self.dataset[idx]['prompt'], padding='max_length', truncation=True, return_tensors='pt')
34
  label = self.labels[idx]
35
  return image, text, label
36
+
37
  # Define CNN for image processing
38
  class ImageModel(nn.Module):
39
+ def __init__(self):
40
+ super(ImageModel, self).__init__()
41
  self.model = models.resnet18(pretrained=True)
42
  self.model.fc = nn.Linear(self.model.fc.in_features, 512)
43
+
44
  def forward(self, x):
45
  return self.model(x)
46
+
47
  # Define MLP for text processing
48
  class TextModel(nn.Module):
49
+ def __init__(self):
50
+ super(TextModel, self).__init__()
51
  self.bert = BertModel.from_pretrained('bert-base-uncased')
52
  self.fc = nn.Linear(768, 512)
53
+
54
  def forward(self, x):
55
+ output = self.bert(**x)
56
  return self.fc(output.pooler_output)
57
+
58
  # Combined model
59
  class CombinedModel(nn.Module):
60
+ def __init__(self):
61
+ super(CombinedModel, self).__init__()
62
  self.image_model = ImageModel()
63
  self.text_model = TextModel()
64
  self.fc = nn.Linear(1024, len(dataset['Model']))
65
+
66
  def forward(self, image, text):
67
  image_features = self.image_model(image)
68
  text_features = self.text_model(text)
69
  combined = torch.cat((image_features, text_features), dim=1)
70
  return self.fc(combined)
71
+
72
  # Instantiate model
73
  model = CombinedModel()
74
+
75
  # Define predict function
76
  def predict(image):
77
  model.eval()
 
79
  image = transforms.ToTensor()(image).unsqueeze(0)
80
  image = transforms.Resize((224, 224))(image)
81
  text_input = tokenizer("Sample prompt", return_tensors='pt', padding=True, truncation=True)
82
+ output = model(image, text_input)
83
+ _, indices = torch.topk(output, 5)
84
  recommended_models = [dataset['Model'][i] for i in indices[0]]
85
  return recommended_models
86
+
87
  # Set up Gradio interface
88
  interface = gr.Interface(fn=predict,
89
  inputs=gr.Image(type="pil"),
90
  outputs=gr.Textbox(label="Recommended Models"),
91
  title="AI Image Model Recommender",
92
  description="Upload an AI-generated image to receive model recommendations.")
93
+
94
  # Launch the app
95
  interface.launch()