import torch
import torch.nn as nn

# Configuration class
class Config:
    def __init__(self):
        # Model architecture parameters
        self.hidden_size = 768
        self.num_attention_heads = 12
        self.num_hidden_layers = 12
        self.intermediate_size = 3072
        self.hidden_dropout_prob = 0.1
        self.attention_probs_dropout_prob = 0.1
        # Image settings
        self.image_size = 224
        self.image_channels = 3
        self.patch_size = 16
        # Text settings
        self.max_position_embeddings = 512
        self.vocab_size = 30522
        self.type_vocab_size = 2
        # Audio settings
        self.audio_sample_rate = 16000
        self.audio_frame_size = 1024
        self.audio_hop_size = 512
        # Task switches
        self.enable_vqa = True
        self.enable_caption = True
        self.enable_retrieval = True
        self.enable_asr = True  # speech recognition
        self.enable_realtime_asr = True  # streaming/real-time speech recognition
        # Training settings
        self.batch_size = 32
        self.learning_rate = 1e-4
        self.weight_decay = 0.01
        self.warmup_steps = 10000
        self.max_steps = 100000
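
# Note: not every field above is consumed by the model definitions below
# (e.g., patch_size, type_vocab_size, and the training settings are unused
# in this file); they are kept as configuration only.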

# Model component definitions
class ImageEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # 3x3 conv with no padding shrinks 224 -> 222; 2x2 max pool halves to 111.
        feature_size = (config.image_size - 2) // 2
        self.encoder_layer = nn.Sequential(
            nn.Conv2d(config.image_channels, 64, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
            nn.Linear(64 * feature_size * feature_size, config.hidden_size),
        )

    def forward(self, image):
        # image: (batch, channels, height, width) -> (batch, hidden_size)
        image_features = self.encoder_layer(image)
        return image_features

class TextEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            dim_feedforward=config.intermediate_size,
            dropout=config.hidden_dropout_prob,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            self.transformer_layer,
            num_layers=config.num_hidden_layers,
        )

    def forward(self, text):
        # text is expected pre-embedded: (batch, seq_len, hidden_size);
        # mean-pool over the sequence to get one vector per example.
        text_features = self.transformer_encoder(text).mean(dim=1)
        return text_features

class AudioEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # Maps one second of raw 16 kHz audio (16000 samples) to hidden_size.
        self.encoder_layer = nn.Sequential(
            nn.Linear(config.audio_sample_rate, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, config.hidden_size),
        )

    def forward(self, audio):
        # audio: (batch, audio_sample_rate) -> (batch, hidden_size)
        audio_features = self.encoder_layer(audio)
        return audio_features
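
# Note: a real streaming ASR front end would typically window the waveform
# using audio_frame_size / audio_hop_size from Config; this encoder instead
# consumes the whole clip in one shot.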

class FusionLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # Concatenate the three modality vectors and project back to hidden_size.
        self.fusion_layer = nn.Linear(config.hidden_size * 3, config.hidden_size)

    def forward(self, image_features, text_features, audio_features):
        fused_features = torch.cat((image_features, text_features, audio_features), dim=1)
        fused_features = self.fusion_layer(fused_features)
        return fused_features

# Task heads: each is a single linear projection from the fused features
# to vocabulary logits.
class VQALayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.vqa_layer = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, fused_features):
        vqa_output = self.vqa_layer(fused_features)
        return vqa_output


class CaptionLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.caption_layer = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, fused_features):
        caption_output = self.caption_layer(fused_features)
        return caption_output


class RetrievalLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.retrieval_layer = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, fused_features):
        retrieval_output = self.retrieval_layer(fused_features)
        return retrieval_output


class ASRLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.asr_layer = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, fused_features):
        asr_output = self.asr_layer(fused_features)
        return asr_output


class RealtimeASRLayer(nn.Module):
    # Structurally identical to ASRLayer; kept separate for the streaming task.
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.realtime_asr_layer = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, fused_features):
        realtime_asr_output = self.realtime_asr_layer(fused_features)
        return realtime_asr_output


class TextOutputLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.text_output_layer = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, fused_features):
        text_output = self.text_output_layer(fused_features)
        return text_output

# Main model definition
class AutoModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.image_encoder = ImageEncoder(config)
        self.text_encoder = TextEncoder(config)
        self.audio_encoder = AudioEncoder(config)
        self.fusion_layer = FusionLayer(config)
        self.vqa_layer = VQALayer(config)
        self.caption_layer = CaptionLayer(config)
        self.retrieval_layer = RetrievalLayer(config)
        self.asr_layer = ASRLayer(config)
        self.realtime_asr_layer = RealtimeASRLayer(config)
        self.text_output_layer = TextOutputLayer(config)

    def forward(self, image, text, audio):
        # Encode each modality, fuse, then run every task head.
        image_features = self.image_encoder(image)
        text_features = self.text_encoder(text)
        audio_features = self.audio_encoder(audio)
        fused_features = self.fusion_layer(image_features, text_features, audio_features)
        vqa_output = self.vqa_layer(fused_features)
        caption_output = self.caption_layer(fused_features)
        retrieval_output = self.retrieval_layer(fused_features)
        asr_output = self.asr_layer(fused_features)
        realtime_asr_output = self.realtime_asr_layer(fused_features)
        text_output = self.text_output_layer(fused_features)
        return vqa_output, caption_output, retrieval_output, asr_output, realtime_asr_output, text_output
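

# A minimal illustrative helper (hypothetical, not called by anything below)
# showing how the enable_* flags in Config could gate the task heads instead
# of always running all six in forward().
def run_enabled_heads(model, fused_features):
    """Return a dict of outputs only for the heads enabled in model.config."""
    outputs = {}
    if model.config.enable_vqa:
        outputs["vqa"] = model.vqa_layer(fused_features)
    if model.config.enable_caption:
        outputs["caption"] = model.caption_layer(fused_features)
    if model.config.enable_retrieval:
        outputs["retrieval"] = model.retrieval_layer(fused_features)
    if model.config.enable_asr:
        outputs["asr"] = model.asr_layer(fused_features)
    if model.config.enable_realtime_asr:
        outputs["realtime_asr"] = model.realtime_asr_layer(fused_features)
    return outputs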

# Smoke test
if __name__ == "__main__":
    config = Config()
    model = AutoModel(config)
    # Dummy inputs: text is passed as pre-embedded vectors, audio as raw samples.
    image = torch.randn(1, config.image_channels, config.image_size, config.image_size)
    text = torch.randn(1, config.max_position_embeddings, config.hidden_size)
    audio = torch.randn(1, config.audio_sample_rate)
    vqa_output, caption_output, retrieval_output, asr_output, realtime_asr_output, text_output = model(image, text, audio)

    # Print output shapes
    print("VQA output shape:", vqa_output.shape)
    print("Caption output shape:", caption_output.shape)
    print("Retrieval output shape:", retrieval_output.shape)
    print("ASR output shape:", asr_output.shape)
    print("Realtime ASR output shape:", realtime_asr_output.shape)
    print("Text output shape:", text_output.shape)

    # Print total parameter count
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\nTotal parameters: {total_params}")

    # Save model weights
    save_path = "save.pth"
    torch.save(model.state_dict(), save_path)
    print(f"Model weights saved to: {save_path}")