ifmain commited on
Commit
b00d4f0
1 Parent(s): 57f243f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +30 -14
README.md CHANGED
@@ -13,24 +13,40 @@ base_model: Salesforce/blip-image-captioning-base
13
 
14
  # Example Usage
15
  ```python
16
- past_the_code
17
- ```
 
 
18
 
19
- ## Junk
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- This model contains references to lore, they can be removed as follows:
22
- ```python
23
- past_the_code
24
- ```
25
 
26
- ## Examples
 
27
 
28
- ```note
29
- paste there images table
30
- ```
31
 
 
 
 
 
 
 
 
 
 
32
 
33
- | Original | Original prompt | Generated prompt by image | Generated image |
34
- | -------- | --------------- | ------------------------- | --------------- |
35
- | pass | pass | pass | pass |
36
 
 
13
 
14
  # Example Usage
15
  ```python
16
+ import torch
17
+ import requests
18
+ from PIL import Image
19
+ from transformers import BlipProcessor, BlipForConditionalGeneration
20
 
21
+ def prepare(text):
22
+ text = text.replace('. ','.').replace(' .','.')
23
+ text = text.replace('< ','<').replace(' <','<')
24
+ text = text.replace('> ','>').replace(' >','>')
25
+ text = text.replace('( ','(').replace(' (','(')
26
+ text = text.replace(') ',')').replace(' )',')')
27
+ text = text.replace(': ',':').replace(' :',':')
28
+ text = text.replace('_ ','_').replace(' _','_')
29
+ text = text.replace(',(())','')
30
+ for i in range(10):
31
+ text = text.replace(')))','))').replace('(((','((')
32
+ return text
33
 
34
+ path_to_model = "blip-image2promt-stable-diffusion-v0.15"
 
 
 
35
 
36
+ processor = BlipProcessor.from_pretrained(path_to_model)
37
+ model = BlipForConditionalGeneration.from_pretrained(path_to_model, torch_dtype=torch.float16).to("cuda")
38
 
39
+ img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
40
+ raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
 
41
 
42
+ # unconditional image captioning
43
+ inputs = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
44
+
45
+ out = model.generate(**inputs, max_new_tokens=100)
46
+
47
+ out_txt = processor.decode(out[0], skip_special_tokens=True)
48
+
49
+ print(prepare(out_txt))
50
+ ```
51
 
 
 
 
52