SkyWork commited on
Commit
2e4e847
·
1 Parent(s): fb2458f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +156 -3
README.md CHANGED
@@ -1,3 +1,156 @@
1
- ---
2
- license: creativeml-openrail-m
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SkyPaint-Chinese-EN-v-1.0
2
+ #### SkyPaint是由奇点智源开发的中英双语文本生成图像的项目,目前还在持续更新优化中
3
+ - 项目地址: [SkyWorkAIGC-SkyPaint](https://github.com/SkyWorkAIGC/SkyPaint)
4
+
5
+ # 模型介绍
6
+ SkyPaint文本生成图片模型主要由两大部分组成,即提示词文本编码器模型和扩散模型两大部分。因此我们的优化也分为两步,首先基于[OpenAI-CLIP](https://github.com/openai/CLIP)优化了提示词文本编码器模型使得SkyPaint具有中英文识别能力,然后优化了扩散模型,使得SkyPaint具有现代艺术能力可以产生高质量图片。
7
+
8
+ # 模型功能
9
+ * 支持汉语和英文以及中英文混合提示词输入
10
+ * 支持生成现代艺术风格的高质量图片
11
+ * 支持stable_diffusion_1.x官方模型及相关微调模型的英文提示词
12
+ * 保留stable_diffusion提示词的使用习惯和方法
13
+
14
+ ### SkyCLIP模型简介
15
+ SkyCLIP是我们采用一种高效的训练中英双语CLIP模型的方法得到的CLIP模型,该方法仅需要使用文本数据即可实现对[OpenAI-CLIP](https://github.com/openai/CLIP)模型的高效蒸馏,大幅降低了数据门槛,同时训练所需算力要求相较于原始CLIP模型减少90%以上,方便开源社区可以进行复现/微调。该方法仅改变了OpenAI-CLIP的文本编码器,可搭配使用OpenAI-CLIP的图像编码器实现图文检索功能。
16
+
17
+ ### SkyCLIP训练数据来源
18
+ * 中英文机器翻译任务平行语料
19
+ * 联合国中英文平行语料
20
+ * [LAION](https://laion.ai/)中英文语料(部分)
21
+ * [Wukong](https://wukong-dataset.github.io/wukong-dataset/index.html)中文语料(部分)
22
+ * [AI-Challenger](https://github.com/AIChallenger)翻译任务中英文语料
23
+ * 古诗词中英文语料
24
+ * 提示词手册/魔法书中常见词组合而成的中英文语料
25
+
26
+ ### SkyCLIP训练方法
27
+ 将OpenAI-CLIP的text_encoder作为教师模型并且冻结参数,学生模型采用和教师模型同样大小的多语言BERT模型,训练时英文输入通过教师模型获取相应的t_en_hiddent_state,英文和中文分别通过学生模型获取相应s_en_hiddent_state,s_zh_hidden_state,采用l1、l2、cos距离等构造损失函数使得学生模型的中英文hiddent_state逐渐靠近教师模型的hiddent_state。由于平行语料的中文和英文存在天然的不等长性质,为了使得平行的中文和英文尽量接近,训练过程中我们还添加了中文解码器,使用学生模型的中英文hiddent_state作为解码器的hidden_state输入,通过翻译任务来辅助实现中文和英文的对齐目的。
28
+
29
+ ### SkyCLIP模型评估
30
+ 目前我们主要评估了SkyCLIP在[Flickr30K-CN](https://github.com/li-xirong/cross-lingual-cap)的zero-shot表现,主要对比了若干具备中文能力的相关开源模型,为确保对比的公平性,具有多个模型尺寸的我们均选取基于OpenAI-CLIP ViT-L/14尺寸的模型,我们评估的流程参考了[Chinese-CLIP](https://github.com/OFA-Sys/Chinese-CLIP)所提供的评估脚本。
31
+
32
+ **Flickr30K-CN Retrieval**:
33
+ <table border="1" width="150%">
34
+ <tr align="center">
35
+ <th>Task</th><th colspan="3">Text-to-Image</th><th colspan="3">Image-to-Text</th>
36
+ <th rowspan="3">MR</th>
37
+ </tr>
38
+ <tr align="center">
39
+ <th>Setup</th><th colspan="3">Zero-shot</th><th colspan="3">Zero-shot</th>
40
+ </tr>
41
+ <tr align="center">
42
+ <td>Metric</td><td>R@1</td><td>R@5</td><td>R@10</td><td>R@1</td><td>R@5</td><td>R@10</td>
43
+ </tr>
44
+ <tr align="center">
45
+ <td width="120%">Taiyi-326M</td><td>53.8</td><td>79.9</td><td>86.6</td><td>64.0</td><td>90.4</td><td>96.1</td><td>78.47</td>
46
+ </tr>
47
+ <tr align="center">
48
+ <td width="120%">AltCLIP</td><td>50.7</td><td>75.4</td><td>83.1</td><td>73.4</td><td>92.8</td><td>96.9</td><td>78.72</td>
49
+ </tr>
50
+ <tr align="center">
51
+ <td width="120%">Wukong</td><td>51.9</td><td>78.6</td><td>85.9</td><td>75</td><td>94.4</td><td>97.7</td><td>80.57</td>
52
+ </tr>
53
+ <tr align="center">
54
+ <td width="120%">R2D2</td><td>42.6</td><td>69.5</td><td>78.6</td><td>63.0</td><td>90.1</td><td>96.4</td><td>73.37</td>
55
+ </tr>
56
+ <tr align="center">
57
+ <td width="120%">CN-CLIP</td><td>68.1</td><td>89.7</td><td>94.5</td><td>80.2</td><td>96.6</td><td>98.2</td><td>87.87</td>
58
+ </tr>
59
+ <tr align="center">
60
+ <td width="120%">SkyCLIP</td><td>58.8</td><td>82.6</td><td>89.6</td><td>78.8</td><td>96.1</td><td>98.3</td><td>84.04</td>
61
+ </tr>
62
+ </table>
63
+ <br>
64
+
65
+ ### SkyCLIP计算图文相似度
66
+ ```py
67
+ from PIL import Image
68
+ import requests
69
+ import clip
70
+ import torch
71
+ from transformers import BertTokenizer
72
+ from transformers import CLIPProcessor, CLIPModel, CLIPTextModel
73
+ import numpy as np
74
+
75
+ query_texts = ['一个人', '一辆汽车', '两个男人', '两个女人'] # 这里是输入提示词,可以随意替换。
76
+ # 加载SkyCLIP 中英文双语 text_encoder
77
+ text_tokenizer = BertTokenizer.from_pretrained("./tokenizer")
78
+ text_encoder = CLIPTextModel.from_pretrained("./text_encoder").eval()
79
+ text = text_tokenizer(query_texts, return_tensors='pt', padding=True)['input_ids']
80
+
81
+ url = "http://images.cocodataset.org/val2017/000000040083.jpg" #这里可以换成任意图片的url
82
+ # 加载CLIP的image encoder
83
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
84
+ clip_text_proj = clip_model.text_projection
85
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
86
+ image = processor(images=Image.open(requests.get(url, stream=True).raw), return_tensors="pt")
87
+
88
+ with torch.no_grad():
89
+ image_features = clip_model.get_image_features(**image)
90
+ text_features = text_encoder(text)[0]
91
+ # sep_token对应于openai-clip的eot_token
92
+ sep_index = torch.nonzero(text == student_tokenizer.sep_token_id)
93
+ text_features = text_features[torch.arange(text.shape[0]), sep_index[:, 1]]
94
+ # 乘text投影矩阵
95
+ text_features = clip_text_proj(text_features)
96
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
97
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
98
+ # 计算余弦相似度 logit_scale是尺度系数
99
+ logit_scale = clip_model.logit_scale.exp()
100
+ logits_per_image = logit_scale * image_features @ text_features.t()
101
+ logits_per_text = logits_per_image.t()
102
+ probs = logits_per_image.softmax(dim=-1).cpu().numpy()
103
+ print(np.around(probs, 3))
104
+
105
+ ```
106
+
107
+
108
+ ### 扩散模型 Diffusion Model
109
+ 我们的数据采用了筛选过的Laion数据集作为训练数据,同时在文本前面加上了 'sai-v1 art' 作为tag使模型能够更快速的学习到我们想要的风格及质量。
110
+ 预训练模型采用了[stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) 作为预训练,使用了16块A100训练了50个小时。
111
+ 目前模型还在持续优化中,后续会有更稳定的模型更新
112
+
113
+ # 效果展示
114
+
115
+ ### 中文
116
+ 机械狗
117
+ ![](results/1.png)
118
+
119
+ 城堡 大海 夕阳 宫崎骏动画
120
+ ![](results/2.png)
121
+
122
+ 花落知多少
123
+ ![](results/3.png)
124
+
125
+ 半鸡半人,强壮
126
+ ![](results/4.png)
127
+
128
+ 鸡你太美
129
+ ![](results/5.png)
130
+
131
+ ## 测试用例
132
+
133
+ 模型下载地址 [SkyPaint-v1.0](https://sai-hk.oss-cn-hongkong.aliyuncs.com/zb/skypaint-v-1.0.zip?OSSAccessKeyId=LTAI5tHuxqp63n5qw5eeB6Ji&Expires=1673528832&Signature=4PTeknRoXuHWmeQHXqgu8kB0q%2Bw%3D)
134
+
135
+ ```py
136
+ from diffusers import StableDiffusionPipeline
137
+
138
+ device = 'cuda'
139
+ pipe = StableDiffusionPipeline.from_pretrained("path_to_our_model").to(device)
140
+
141
+ prompts = [
142
+ '机械狗',
143
+ '城堡 大海 夕阳 宫崎骏动画',
144
+ '花落知多少',
145
+ '鸡你太美',
146
+ ]
147
+
148
+ for prompt in prompts:
149
+ prompt = 'sai-v1 art, ' + prompt
150
+ image = pipe(prompt).images[0]
151
+ image.save("%s.jpg" % prompt)
152
+ ```
153
+
154
+ # License
155
+ - [MIT License](LICENSE)
156
+ - [CreativeML Open RAIL-M](LICENSE-MODEL)