zeroMN committed on
Commit
0a66d43
·
verified ·
1 Parent(s): 75582ce

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +72 -1
README.md CHANGED
@@ -69,6 +69,8 @@ This model, named `Evolutionary Multi-Modal Model`, is a multimodal transformer
69
 
70
  ### Direct Use
71
  ```bash
 
 
72
  git clone https://huggingface.co/zeroMN/SHMT.git
73
  ```
74
  ### Downstream Use
@@ -90,5 +92,74 @@ Users (both direct and downstream) should be made aware of the following risks,
90
 
91
  ## How to Get Started with the Model
92
  ```python
93
- git clone https://huggingface.co/zeroMN/SHMT.git
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  ```
 
69
 
70
  ### Direct Use
71
  ```bash
72
+ git lfs install
73
+
74
  git clone https://huggingface.co/zeroMN/SHMT.git
75
  ```
76
  ### Downstream Use
 
92
 
93
  ## How to Get Started with the Model
94
  ```python
95
+ import os
96
+ import torch
97
+ import torch.nn as nn
98
+ import numpy as np
99
+ import random
100
+ from transformers import (
101
+ BartForConditionalGeneration,
102
+ AutoModelForCausalLM,
103
+ BertModel,
104
+ Wav2Vec2Model,
105
+ CLIPModel,
106
+ AutoTokenizer
107
+ )
108
+
109
class MultiModalModel(nn.Module):
    """Bundle of pretrained text, code, NLP, speech and vision sub-models.

    ``forward`` dispatches on a task name. Only ``'text_generation'`` and
    ``'code_generation'`` are implemented; any other task name raises
    ``ValueError``.
    """

    def __init__(self):
        super().__init__()

        # Sub-models.
        self.text_generator = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
        self.code_generator = AutoModelForCausalLM.from_pretrained('gpt2')
        self.nlp_encoder = BertModel.from_pretrained('bert-base-uncased')
        self.speech_encoder = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base-960h')
        self.vision_encoder = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')

        # Tokenizers and processors.
        self.text_tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')
        self.code_tokenizer = AutoTokenizer.from_pretrained('gpt2')
        self.nlp_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        # NOTE(review): the two "processor" attributes are loaded through
        # AutoTokenizer, so they presumably handle text only, not raw audio
        # or images -- confirm whether a feature extractor / image processor
        # was intended before relying on them.
        self.speech_processor = AutoTokenizer.from_pretrained('facebook/wav2vec2-base-960h')
        self.vision_processor = AutoTokenizer.from_pretrained('openai/clip-vit-base-patch32')

    def forward(self, task, inputs):
        """Run the sub-model selected by ``task`` and return decoded text.

        Args:
            task: Task name; ``'text_generation'`` or ``'code_generation'``.
            inputs: Mapping with an ``'input_ids'`` entry and an optional
                ``'attention_mask'`` entry (presumably tokenizer output --
                confirm against the caller).

        Returns:
            The first generated sequence, decoded with special tokens
            skipped.

        Raises:
            ValueError: If ``task`` is not one of the supported task names.
        """
        if task == 'text_generation':
            return self._sample_and_decode(
                self.text_generator,
                self.text_tokenizer,
                inputs,
                max_new_tokens=100,
                top_p=0.9,
                temperature=0.8,
            )
        if task == 'code_generation':
            return self._sample_and_decode(
                self.code_generator,
                self.code_tokenizer,
                inputs,
                max_new_tokens=50,
                top_p=0.95,
                temperature=1.2,
            )
        # Logic for further tasks goes here. Until it exists, fail loudly
        # instead of silently returning None for an unknown task name.
        raise ValueError(f"Unsupported task: {task!r}")

    @staticmethod
    def _sample_and_decode(generator, tokenizer, inputs, *, max_new_tokens, top_p, temperature):
        """Sample from ``generator`` and decode the first returned sequence."""
        attention_mask = inputs.get('attention_mask')
        sequences = generator.generate(
            inputs['input_ids'],
            max_new_tokens=max_new_tokens,
            # The tokenizer's EOS id doubles as the padding id here.
            pad_token_id=tokenizer.eos_token_id,
            attention_mask=attention_mask,
            top_p=top_p,
            top_k=50,
            temperature=temperature,
            do_sample=True,
        )
        return tokenizer.decode(sequences[0], skip_special_tokens=True)
154
+
155
def count_parameters(model):
    """Return the total element count of the trainable parameters of ``model``.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total
158
+
159
# Build the composite model.
model = MultiModalModel()

# Count the model's parameters and print the total.
total_params = count_parameters(model)
print(f"模型总参数数量: {total_params}")
165
  ```