Update README.md
Browse files
README.md
CHANGED
@@ -13,24 +13,40 @@ base_model: Salesforce/blip-image-captioning-base
|
|
13 |
|
14 |
# Example Usage
|
15 |
```python
|
16 |
-
|
17 |
-
|
|
|
|
|
18 |
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
-
|
22 |
-
```python
|
23 |
-
past_the_code
|
24 |
-
```
|
25 |
|
26 |
-
|
|
|
27 |
|
28 |
-
|
29 |
-
|
30 |
-
```
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
-
| Original | Original prompt | Generated prompt by image | Generated image |
|
34 |
-
| -------- | --------------- | ------------------------- | --------------- |
|
35 |
-
| pass | pass | pass | pass |
|
36 |
|
|
|
13 |
|
14 |
# Example Usage
|
15 |
```python
|
16 |
+
import torch
|
17 |
+
import requests
|
18 |
+
from PIL import Image
|
19 |
+
from transformers import BlipProcessor, BlipForConditionalGeneration
|
20 |
|
21 |
+
def prepare(text):
    """Post-process a model-generated prompt string.

    Tightens spacing around prompt punctuation, removes literal ',(())'
    artifacts, and collapses runs of three or more parentheses down to two.

    Args:
        text: Raw caption/prompt text produced by the model.

    Returns:
        The cleaned-up prompt string.
    """
    # Strip the space on either side of each punctuation character.
    # (Same characters and same order as the original chained replaces.)
    for ch in ('.', '<', '>', '(', ')', ':', '_'):
        text = text.replace(ch + ' ', ch).replace(' ' + ch, ch)
    # Drop literal ',(())' sequences the model sometimes emits.
    text = text.replace(',(())', '')
    # Bounded passes (10, as before) reducing '(((' / ')))' to '((' / '))'.
    for _ in range(10):
        text = text.replace(')))', '))').replace('(((', '((')
    return text
|
33 |
|
34 |
+
# Location of the fine-tuned BLIP checkpoint on disk.
checkpoint = "blip-image2promt-stable-diffusion-v0.15"

# Load processor and model; run the model in fp16 on the GPU.
processor = BlipProcessor.from_pretrained(checkpoint)
model = BlipForConditionalGeneration.from_pretrained(
    checkpoint, torch_dtype=torch.float16
).to("cuda")

# Fetch a demo image over HTTP and decode it as RGB.
img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')

# unconditional image captioning
inputs = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)

generated_ids = model.generate(**inputs, max_new_tokens=100)

caption = processor.decode(generated_ids[0], skip_special_tokens=True)

print(prepare(caption))
|
50 |
+
```
|
51 |
|
|
|
|
|
|
|
52 |
|