StevenTang commited on
Commit
ec65910
·
verified ·
1 Parent(s): d761cfb

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +38 -0
README.md ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LIVE-BART
2
+ The LIVE-BART model was proposed in [**Learning to Imagine: Visually-Augmented Natural Language Generation**](https://arxiv.org/pdf/2305.16944.pdf) by Tianyi Tang, Yushuo Chen, Yifan Du, Junyi Li, Wayne Xin Zhao and Ji-Rong Wen.
3
+
4
+ The detailed information and instructions can be found [https://github.com/RUCAIBox/LIVE](https://github.com/RUCAIBox/LIVE).
5
+
6
+ **You should install the `transformers` at [https://github.com/RUCAIBox/LIVE](https://github.com/RUCAIBox/LIVE).**
7
+
8
+ ```python
9
+ import torch
10
+ import torch.nn as nn
11
+ from transformers import BartForConditionalGeneration, AutoModel
12
+
13
+ class LiveModel(nn.Module):
14
+ def __init__(self):
15
+ super().__init__()
16
+
17
+ self.model = BartForConditionalGeneration.from_pretrained('RUCAIBox/live-bart-base', image_fusion_encoder=True)
18
+ self.vision_model = AutoModel.from_pretrained('openai/clip-vit-base-patch32').vision_model
19
+ hidden_size = self.model.config.hidden_size
20
+ self.trans = nn.Sequential(
21
+ nn.Linear(self.vision_model.config.hidden_size, hidden_size * 4),
22
+ nn.ReLU(),
23
+ nn.Linear(hidden_size * 4, hidden_size),
24
+ )
25
+
26
+ model = LiveModel()
27
+ trans = torch.load('trans.bart.pth')
28
+ model.trans.load_state_dict(trans)
29
+
30
+ # kwargs to model.forward() and model.generate()
31
+ # input_ids [batch_size, seq_len], same to hugging face
32
+ # attention_masks [batch_size, seq_len], same to hugging face
33
+ # labels [batch_size, seq_len], same to hugging face
34
+ # image_embeds [batch_size, image_num*patch_num, image_hidden_size], should be transfered using `trans`, image_num can be the sentence num of text, patch_num and image_hidden_size are 50 and 768 for openai/clip-vit-base-patch32, respectively
35
+ # images_mask [batch_size, seq_len, image_num], this is the mask in Figure 1, 1 represents the i-th word should attend to the j-th image
36
+ # images_mask_2d [batch_size, seq_len], 1 represents the i-th word should not be visually augmented, i.e., should not be attend to any image
37
+
38
+ ```