StevenTang
commited on
Create README.md
Browse files
README.md
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# LIVE-BART
|
2 |
+
The LIVE-BART model was proposed in [**Learning to Imagine: Visually-Augmented Natural Language Generation**](https://arxiv.org/pdf/2305.16944.pdf) by Tianyi Tang, Yushuo Chen, Yifan Du, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen.
|
3 |
+
|
4 |
+
The detailed information and instructions can be found [https://github.com/RUCAIBox/LIVE](https://github.com/RUCAIBox/LIVE).
|
5 |
+
|
6 |
+
**You should install the `transformers` at [https://github.com/RUCAIBox/LIVE](https://github.com/RUCAIBox/LIVE).**
|
7 |
+
|
8 |
+
```python
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from transformers import BartForConditionalGeneration, AutoModel
|
12 |
+
|
13 |
+
class LiveModel(nn.Module):
|
14 |
+
def __init__(self):
|
15 |
+
super().__init__()
|
16 |
+
|
17 |
+
self.model = BartForConditionalGeneration.from_pretrained('RUCAIBox/live-bart-base', image_fusion_encoder=True)
|
18 |
+
self.vision_model = AutoModel.from_pretrained('openai/clip-vit-base-patch32').vision_model
|
19 |
+
hidden_size = self.model.config.hidden_size
|
20 |
+
self.trans = nn.Sequential(
|
21 |
+
nn.Linear(self.vision_model.config.hidden_size, hidden_size * 4),
|
22 |
+
nn.ReLU(),
|
23 |
+
nn.Linear(hidden_size * 4, hidden_size),
|
24 |
+
)
|
25 |
+
|
26 |
+
model = LiveModel()
|
27 |
+
trans = torch.load('trans.bart.pth')
|
28 |
+
model.trans.load_state_dict(trans)
|
29 |
+
|
30 |
+
# kwargs to model.forward() and model.generate()
|
31 |
+
# input_ids [batch_size, seq_len], same to hugging face
|
32 |
+
# attention_masks [batch_size, seq_len], same to hugging face
|
33 |
+
# labels [batch_size, seq_len], same to hugging face
|
34 |
+
# image_embeds [batch_size, image_num*patch_num, image_hidden_size], should be transfered using `trans`, image_num can be the sentence num of text, patch_num and image_hidden_size are 50 and 768 for openai/clip-vit-base-patch32, respectively
|
35 |
+
# images_mask [batch_size, seq_len, image_num], this is the mask in Figure 1, 1 represents the i-th word should attend to the j-th image
|
36 |
+
# images_mask_2d [batch_size, seq_len], 1 represents the i-th word should not be visually augmented, i.e., should not be attend to any image
|
37 |
+
|
38 |
+
```
|