zeroMN commited on
Commit
7d9df81
·
verified ·
1 Parent(s): 378f3df

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +69 -67
main.py CHANGED
@@ -1,67 +1,69 @@
1
- import os
2
- import torch
3
- import torch.nn as nn
4
- import numpy as np
5
- import random
6
- from transformers import (
7
- BartForConditionalGeneration,
8
- AutoModelForCausalLM,
9
- BertModel,
10
- Wav2Vec2Model,
11
- CLIPModel,
12
- AutoTokenizer
13
- )
14
-
15
- class MultiModalModel(nn.Module):
16
- def __init__(self):
17
- super(MultiModalModel, self).__init__()
18
- # 初始化子模型
19
- self.text_generator = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
20
- self.code_generator = AutoModelForCausalLM.from_pretrained('gpt2')
21
- self.nlp_encoder = BertModel.from_pretrained('bert-base-uncased')
22
- self.speech_encoder = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base-960h')
23
- self.vision_encoder = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
24
-
25
- # 初始化分词器和处理器
26
- self.text_tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')
27
- self.code_tokenizer = AutoTokenizer.from_pretrained('gpt2')
28
- self.nlp_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
29
- self.speech_processor = AutoTokenizer.from_pretrained('facebook/wav2vec2-base-960h')
30
- self.vision_processor = AutoTokenizer.from_pretrained('openai/clip-vit-base-patch32')
31
-
32
- def forward(self, task, inputs):
33
- if task == 'text_generation':
34
- # 确保 attention_mask 在 inputs 中
35
- attention_mask = inputs.get('attention_mask')
36
- print("输入数据:", inputs)
37
- outputs = self.text_generator.generate(
38
- inputs['input_ids'],
39
- max_new_tokens=100, # 增加生成的最大新令牌数
40
- pad_token_id=self.text_tokenizer.eos_token_id,
41
- attention_mask=attention_mask,
42
- top_p=0.9, # 调整 top_p 值
43
- top_k=50, # 保持 top_k 值
44
- temperature=0.8, # 调整 temperature 值
45
- do_sample=True
46
- )
47
- print("生成的输出:", outputs)
48
- return self.text_tokenizer.decode(outputs[0], skip_special_tokens=True)
49
- # 根据需要添加其他任务的逻辑...
50
-
51
- # 主函数
52
- if __name__ == "__main__":
53
- # 初始化模型
54
- model = MultiModalModel()
55
-
56
- # 示例任务和输入数据
57
- task = "text_generation"
58
- input_text = "This is a sample input."
59
- tokenizer = model.text_tokenizer
60
- inputs = tokenizer(input_text, return_tensors='pt')
61
-
62
- # 添加 attention_mask 键值对
63
- inputs['attention_mask'] = torch.ones_like(inputs['input_ids'])
64
-
65
- # 模型推理
66
- result = model(task, inputs)
67
- print("最终输出结果:", result)
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ import random
6
+ from transformers import (
7
+ BartForConditionalGeneration,
8
+ AutoModelForCausalLM,
9
+ BertModel,
10
+ Wav2Vec2Model,
11
+ CLIPModel,
12
+ AutoTokenizer
13
+ )
14
+
15
+ class MultiModalModel(nn.Module):
16
+ def __init__(self):
17
+ super(MultiModalModel, self).__init__()
18
+ # 初始化子模型
19
+ self.text_generator = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
20
+ self.code_generator = AutoModelForCausalLM.from_pretrained('gpt2')
21
+ self.nlp_encoder = BertModel.from_pretrained('bert-base-uncased')
22
+ self.speech_encoder = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base-960h')
23
+ self.vision_encoder = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
24
+
25
+ # 初始化分词器和处理器
26
+ self.text_tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')
27
+ self.code_tokenizer = AutoTokenizer.from_pretrained('gpt2')
28
+ self.nlp_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
29
+ self.speech_processor = AutoTokenizer.from_pretrained('facebook/wav2vec2-base-960h')
30
+ self.vision_processor = AutoTokenizer.from_pretrained('openai/clip-vit-base-patch32')
31
+
32
+ def forward(self, task, inputs):
33
+ if task == 'text_generation':
34
+ # 确保 attention_mask 在 inputs 中
35
+ attention_mask = inputs.get('attention_mask')
36
+ print("输入数据:", inputs)
37
+ outputs = self.text_generator.generate(
38
+ inputs['input_ids'],
39
+ max_new_tokens=100, # 增加生成的最大新令牌数
40
+ pad_token_id=self.text_tokenizer.eos_token_id,
41
+ attention_mask=attention_mask,
42
+ top_p=0.9, # 调整 top_p 值
43
+ top_k=50, # 保持 top_k 值
44
+ temperature=0.8, # 调整 temperature 值
45
+ do_sample=True
46
+ )
47
+ print("生成的输出:", outputs)
48
+ return self.text_tokenizer.decode(outputs[0], skip_special_tokens=True)
49
+ # 根据需要添加其他任务的逻辑...
50
+
51
+ # 主函数
52
+ if __name__ == "__main__":
53
+ # 初始化模型
54
+ model = MultiModalModel()
55
+
56
+ # 示例任务和输入数据
57
+ task = "text_generation"
58
+ input_text = "This is a sample input."
59
+ tokenizer = model.text_tokenizer
60
+ inputs = tokenizer(input_text, return_tensors='pt')
61
+
62
+ # 添加 attention_mask 键值对
63
+ inputs['attention_mask'] = torch.ones_like(inputs['input_ids'])
64
+
65
+ # 模型推理
66
+ result = model(task, inputs)
67
+ print("最终输出结果:", result)
68
+
69
+ trust_remote_code=True