## CLIP-Vit-Bert-Chinese pretrained model

This is a Chinese CLIP pretrained model. It was pretrained with multimodal contrastive learning on 1.4 million Chinese image-text pairs, following the LiT-tuning (Locked-image Text tuning) strategy.

Github: [CLIP-Chinese](https://github.com/yangjianxin1/CLIP-Chinese)

Blog: [CLIP-Chinese:中文多模态对比学习CLIP预训练模型]()

## Model and Training Detail

The model consists of a text encoder and an image encoder: the text encoder is a BERT and the image encoder is a ViT, and we call the combined model BertCLIP. For training, the BERT side is initialized from the Langboat/mengzi-bert-base weights and the ViT side from the openai/clip-vit-large-patch32 weights. Training follows the LiT-tuning (Locked-image Text tuning) strategy: the ViT weights are frozen, and only the remaining weights of BertCLIP are trained.
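In code, the locking step of LiT-tuning amounts to disabling gradients on the image tower. A minimal sketch with toy stand-in modules (the layer shapes are illustrative, not the real ViT/BERT):

```python
import torch.nn as nn

# Toy stand-ins for the two towers; the real model uses a ViT image
# encoder and a BERT text encoder.
vision_tower = nn.Linear(8, 4)  # stands in for the (locked) ViT
text_tower = nn.Linear(8, 4)    # stands in for the (trainable) BERT

# LiT-tuning: lock every parameter of the image encoder...
for p in vision_tower.parameters():
    p.requires_grad_(False)

# ...so only the text side still receives gradient updates.
trainable = [p for p in text_tower.parameters() if p.requires_grad]
```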
## Usage

First, clone the project and install its dependencies:
```bash
git clone https://github.com/yangjianxin1/CLIP-Chinese
cd CLIP-Chinese
pip install -r requirements.txt
```
The following script loads the pretrained weights, preprocesses the image and text inputs, and runs the model:
```python
from transformers import CLIPProcessor
from component.model import BertCLIPModel
from PIL import Image
import requests

model_name_or_path = 'YeungNLP/clip-vit-bert-chinese-1M'
# Load the pretrained model weights
model = BertCLIPModel.from_pretrained(model_name_or_path)
CLIPProcessor.tokenizer_class = 'BertTokenizerFast'
# Initialize the processor
processor = CLIPProcessor.from_pretrained(model_name_or_path)
# Preprocess the inputs
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text=["一只小狗在摇尾巴", "一只小猪在吃饭"], images=image, return_tensors="pt", padding=True)
inputs.pop('token_type_ids')  # the model's input does not include token_type_ids

outputs = model(**inputs)

# For each image, compute its similarity to every text
logits_per_image = outputs.logits_per_image  # image-text similarity scores
probs = logits_per_image.softmax(dim=1)  # normalize the scores into probabilities

# For each text, compute its similarity to every image
logits_per_text = outputs.logits_per_text  # text-image similarity scores
probs = logits_per_text.softmax(dim=1)  # normalize the scores into probabilities

# Get the text embeddings
text_embeds = outputs.text_embeds
# Get the image embeddings
image_embeds = outputs.image_embeds
```
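For reference, `logits_per_image` is just the scaled cosine similarity between the two sets of embeddings. A self-contained sketch with random stand-in tensors (the embedding size and scale value here are illustrative):

```python
import torch
import torch.nn.functional as F

# Stand-ins for outputs.image_embeds / outputs.text_embeds
image_embeds = torch.randn(1, 512)   # one image
text_embeds = torch.randn(2, 512)    # two candidate captions

# L2-normalize, then scale the dot products by the learned temperature
image_embeds = F.normalize(image_embeds, dim=-1)
text_embeds = F.normalize(text_embeds, dim=-1)
logit_scale = torch.tensor(100.0)    # exp() of the learned temperature; value illustrative

logits_per_image = logit_scale * image_embeds @ text_embeds.t()  # (1, 2)
logits_per_text = logits_per_image.t()                           # (2, 1)
probs = logits_per_image.softmax(dim=1)
```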
Load the image encoder on its own for downstream tasks:
```python
from PIL import Image
import requests
from transformers import CLIPProcessor, CLIPVisionModel

model_name_or_path = 'YeungNLP/clip-vit-bert-chinese-1M'
model = CLIPVisionModel.from_pretrained(model_name_or_path)
CLIPProcessor.tokenizer_class = 'BertTokenizerFast'
processor = CLIPProcessor.from_pretrained(model_name_or_path)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, return_tensors="pt")

outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state
pooled_output = outputs.pooler_output
```
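The pooled image features can feed a downstream head directly, e.g. a linear probe on top of the frozen encoder. A sketch with a random stand-in for `pooler_output` (the hidden size and class count are illustrative):

```python
import torch
import torch.nn as nn

# Stand-in for outputs.pooler_output: (batch_size, hidden_size)
pooled_output = torch.randn(4, 1024)

# Linear probe: one trainable layer over frozen image features
probe = nn.Linear(1024, 10)      # 10 hypothetical classes
logits = probe(pooled_output)    # (4, 10)
preds = logits.argmax(dim=-1)    # predicted class per image
```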
Load the text encoder on its own for downstream tasks:
```python
from component.model import BertCLIPTextModel
from transformers import BertTokenizerFast

model_name_or_path = 'YeungNLP/clip-vit-bert-chinese-1M'
model = BertCLIPTextModel.from_pretrained(model_name_or_path)
tokenizer = BertTokenizerFast.from_pretrained(model_name_or_path)

inputs = tokenizer(["一只小狗在摇尾巴", "一只小猪在吃饭"], padding=True, return_tensors="pt")
inputs.pop('token_type_ids')  # the model's input does not include token_type_ids

outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state
pooled_output = outputs.pooler_output
```
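Likewise, the pooled text features can serve tasks such as semantic similarity or retrieval. A sketch with a random stand-in for `pooler_output` (the hidden size is illustrative):

```python
import torch
import torch.nn.functional as F

# Stand-in for outputs.pooler_output: (num_texts, hidden_size)
pooled_output = torch.randn(3, 768)

# Rank all texts by cosine similarity to the first one
query = pooled_output[0].unsqueeze(0)                     # (1, 768)
sims = F.cosine_similarity(query, pooled_output, dim=-1)  # (3,)
ranking = sims.argsort(descending=True)                   # best match first
```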