Narenameme committed on
Commit 41b56e7 · verified · 1 Parent(s): ebc0826

Upload app.py with huggingface_hub

Files changed (1)
app.py +134 -0
app.py ADDED
@@ -0,0 +1,134 @@
+ import gradio as gr
+ from transformers import XLNetTokenizer, XLNetModel
+ import torch
+ import torch.nn as nn
+ import numpy as np
+
+ # TextEncoder class: XLNet with mean pooling over the token sequence,
+ # producing a (batch, 768, 1, 1) conditioning tensor for the generator
+ class TextEncoder(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.transformer = XLNetModel.from_pretrained("xlnet-base-cased")
+
+     def forward(self, input_ids, token_type_ids, attention_mask):
+         hidden = self.transformer(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask).last_hidden_state
+         context = hidden.mean(dim=1)  # mean-pool over tokens: (B, L, 768) -> (B, 768)
+         context = context.view(*context.shape, 1, 1)  # add spatial dims -> (B, 768, 1, 1)
+         return context
+
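+ # The XLNet tokenizer returns input_ids, token_type_ids, and attention_mask,
+ # which match TextEncoder.forward's signature, so its output dict can be
+ # unpacked directly: text_encoder(**tokenizer(text, return_tensors="pt")).
+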
+ # Generator class: upsamples the concatenated noise + text embedding to a 64x64 RGB image
+ class Generator(nn.Module):
+     def __init__(self, nz=100, ngf=64, nt=768, nc=3):
+         super().__init__()
+         self.layer1 = nn.Sequential(
+             nn.ConvTranspose2d(nz+nt, ngf*8, 4, 1, 0, bias=False),
+             nn.BatchNorm2d(ngf*8)
+         )
+         self.layer2 = nn.Sequential(
+             nn.Conv2d(ngf*8, ngf*2, 1, 1),
+             nn.Dropout2d(inplace=True),
+             nn.BatchNorm2d(ngf*2),
+             nn.ReLU(True)
+         )
+         self.layer3 = nn.Sequential(
+             nn.Conv2d(ngf*2, ngf*2, 3, 1, 1),
+             nn.Dropout2d(inplace=True),
+             nn.BatchNorm2d(ngf*2),
+             nn.ReLU(True)
+         )
+         self.layer4 = nn.Sequential(
+             nn.Conv2d(ngf*2, ngf*8, 3, 1, 1),
+             nn.Dropout2d(inplace=True),
+             nn.BatchNorm2d(ngf*8),
+             nn.ReLU(True)
+         )
+         self.layer5 = nn.Sequential(
+             nn.ConvTranspose2d(ngf*8, ngf*4, 4, 2, 1, bias=False),
+             nn.BatchNorm2d(ngf*4),
+             nn.ReLU(True)
+         )
+         self.layer6 = nn.Sequential(
+             nn.Conv2d(ngf*4, ngf, 1, 1),
+             nn.Dropout2d(inplace=True),
+             nn.BatchNorm2d(ngf),
+             nn.ReLU(True)
+         )
+         self.layer7 = nn.Sequential(
+             nn.Conv2d(ngf, ngf, 3, 1, 1),
+             nn.Dropout2d(inplace=True),
+             nn.BatchNorm2d(ngf),
+             nn.ReLU(True)
+         )
+         self.layer8 = nn.Sequential(
+             nn.Conv2d(ngf, ngf*4, 3, 1, 1),
+             nn.Dropout2d(inplace=True),
+             nn.BatchNorm2d(ngf*4),
+             nn.ReLU(True)
+         )
+         self.layer9 = nn.Sequential(
+             nn.ConvTranspose2d(ngf*4, ngf*2, 4, 2, 1, bias=False),
+             nn.BatchNorm2d(ngf*2),
+             nn.ReLU(True)
+         )
+         self.layer10 = nn.Sequential(
+             nn.ConvTranspose2d(ngf*2, ngf, 4, 2, 1, bias=False),
+             nn.BatchNorm2d(ngf),
+             nn.ReLU(True)
+         )
+         self.layer11 = nn.Sequential(
+             nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
+             nn.Tanh()
+         )
+
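+     # Spatial sizes through the stack: 1x1 -> 4x4 (layer1) -> 8x8 (layer5)
+     # -> 16x16 (layer9) -> 32x32 (layer10) -> 64x64 (layer11); layers 2-4
+     # and 6-8 refine features at constant resolution.
+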
+     def forward(self, noise, encoded_text):
+         # condition the generator by concatenating along the channel dim:
+         # (B, nz, 1, 1) + (B, nt, 1, 1) -> (B, nz+nt, 1, 1)
+         x = torch.cat([noise, encoded_text], dim=1)
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+         x = self.layer4(x)
+         x = self.layer5(x)
+         x = self.layer6(x)
+         x = self.layer7(x)
+         x = self.layer8(x)
+         x = self.layer9(x)
+         x = self.layer10(x)
+         x = self.layer11(x)
+         return x
+
+
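+ # Quick shape check (a sketch using random tensors and untrained weights):
+ #   g = Generator()
+ #   g(torch.randn(1, 100, 1, 1), torch.randn(1, 768, 1, 1)).shape
+ #   # -> torch.Size([1, 3, 64, 64])
+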
+ # Load the tokenizer, the text encoder, and the trained generator weights
+ model_path = "./checkpoint.pth"
+ tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
+ # instantiate the XLNet-based encoder once so it is not reloaded on every request
+ text_encoder = TextEncoder()
+ model = Generator()
+ # note: PyTorch >= 2.6 defaults torch.load to weights_only=True; pass
+ # weights_only=False if the checkpoint stores more than plain tensors
+ model_state_dict = torch.load(model_path, map_location="cpu")
+ generator = model_state_dict['models']['generator']
+ model.load_state_dict(generator)
+
+ text_encoder.to("cpu")
+ text_encoder.eval()
+ model.to("cpu")
+ model.eval()
+
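+ # eval() matters here: Dropout2d and BatchNorm2d behave differently in
+ # training mode, so leaving it out would add noise to every generated image.
+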
+ # Functions to encode text and generate an image
+ def encode_text(text):
+     inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
+     with torch.no_grad():
+         encoded_text = text_encoder(**inputs)
+     return encoded_text
+
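+ # e.g. encode_text("A beautiful red rose") should return a tensor of shape
+ # torch.Size([1, 768, 1, 1]); with a single prompt, padding is effectively a no-op.
+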
+ def generate_image(text):
+     encoded_text = encode_text(text)
+     noise = torch.randn((1, 100, 1, 1), device="cpu")  # must match the Generator's nz=100
+     with torch.no_grad():
+         generated_image = model(noise, encoded_text).detach().squeeze().cpu()
+     gen_image_np = generated_image.numpy()
+     gen_image_np = np.transpose(gen_image_np, (1, 2, 0))  # change from CHW to HWC
+     gen_image_np = (gen_image_np - gen_image_np.min()) / (gen_image_np.max() - gen_image_np.min() + 1e-8)  # normalize to [0, 1]; eps guards a constant image
+     gen_image_np = (gen_image_np * 255).astype(np.uint8)
+     return gen_image_np
+
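+ # Note: since the generator ends in Tanh, the fixed rescale (x + 1) / 2 would
+ # also map outputs to [0, 1]; the per-image min-max stretch above maximizes
+ # contrast but ties brightness to each image's own extremes.
+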
+ # Gradio interface (gr.inputs / gr.outputs were removed in Gradio 4; use the top-level components)
+ inputs = gr.Textbox(label="Enter a flower-related description", value="A beautiful red rose")
+ outputs = gr.Image(type="numpy", label="Generated Flower Image")
+
+ gr.Interface(fn=generate_image, inputs=inputs, outputs=outputs, title="Flower Image Generator", description="Enter a description of a flower to generate an image.").launch()