{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "fa17529d-eaa7-473e-9d2d-cc05a0120a51", "metadata": {}, "outputs": [], "source": [ "ROLE_TOKENS = {\n", " \"human\": \"<|User|>\", \n", " \"gpt\": \"<|Assistant|>\", \n", "}\n", "MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" \n", "GRAPH_LENGTH = 512\n", "HF_NAME = \"KSU-HW-SEC/r1q1.5_graph_lora_new3\"" ] }, { "cell_type": "code", "execution_count": 2, "id": "bba6e6db-4b79-4461-ba13-75fd41019358", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CUDA 可用: True\n", "GPU 数量: 1\n", "当前 GPU: 0\n", "GPU 名称: NVIDIA A100 80GB PCIe\n" ] } ], "source": [ "# !pip install transformers accelerate datasets\n", "# !pip install galora\n", "# !pip install huggingface_hub\n", "import torch\n", "print(\"CUDA 可用:\", torch.cuda.is_available())\n", "print(\"GPU 数量:\", torch.cuda.device_count())\n", "print(\"当前 GPU:\", torch.cuda.current_device())\n", "print(\"GPU 名称:\", torch.cuda.get_device_name(torch.cuda.current_device()))" ] }, { "cell_type": "code", "execution_count": 3, "id": "ef5551ca-89e2-4488-8e68-1c8d964de039", "metadata": {}, "outputs": [], "source": [ "max_seq_length = 1100 + GRAPH_LENGTH # 最大序列长度" ] }, { "cell_type": "code", "execution_count": 4, "id": "8e283f49-fde4-46e2-9891-dbc304058f0a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train_data 重新加载成功,数据量: 12384\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n", "/usr/local/lib/python3.10/dist-packages/galore_torch/adamw.py:48: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", " warnings.warn(\n", "\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n", "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33m675775971\u001b[0m (\u001b[33myifang_zhao\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" ] }, { "data": { "text/html": [ "Tracking run with wandb version 0.19.7" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in /workspace/wandb/run-20250304_134403-e0v0giuw" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run experi030403 to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://wandb.ai/yifang_zhao/huggingface" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://wandb.ai/yifang_zhao/huggingface/runs/e0v0giuw" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [5310/5310 1:33:59, Epoch 3/3]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
505.319300
1003.641300
1501.521800
2001.027500
2500.922400
3000.866900
3500.800500
4000.721600
4500.740400
5000.737000
5500.713500
6000.747000
6500.869500
7001.473300
7500.753000
8000.741300
8500.751400
9000.787600
9500.783200
10000.780200
10501.012900
11001.411700
11501.536400
12000.853800
12500.756500
13000.750800
13500.747400
14000.844400
14500.858400
15001.053400
15501.591600
16001.498900
16501.471700
17001.221100
17501.802300
18001.826000
18501.857300
19001.561800
19501.398800
20001.398900
20501.381600
21000.890300
21500.763700
22000.753100
22500.745500
23001.186100
23500.862000
24001.024600
24501.028400
25001.008500
25500.942800
26000.849700
26500.771400
27000.794100
27500.819200
28000.937500
28501.064500
29001.189300
29501.071100
30001.003300
30501.073900
31001.043100
31501.282600
32002.145400
32501.925800
33002.005600
33502.122600
34002.163000
34502.046600
35002.152200
35502.151700
36005.394900
36504.677800
37004.122200
37503.710200
38003.350800
38503.126300
39002.988700
39502.872000
40002.848200
40502.823900
41002.781200
41502.735000
42002.725900
42502.644400
43002.700000
43502.650100
44002.704500
44502.596700
45002.510500
45502.515800
46002.498100
46502.458900
47002.449700
47502.425000
48002.362300
48502.232000
49002.361500
49502.302300
50002.333900
50502.367200
51002.288300
51502.426100
52002.344100
52502.283500
53002.296500

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "No files have been modified since last commit. Skipping to prevent empty commit.\n" ] }, { "data": { "text/plain": [ "CommitInfo(commit_url='https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new3/commit/b9472b66316be8654c6f7c173fa4561889bd3446', commit_message='End of training', commit_description='', oid='b9472b66316be8654c6f7c173fa4561889bd3446', pr_url=None, repo_url=RepoUrl('https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new3', endpoint='https://huggingface.co', repo_type='model', repo_id='KSU-HW-SEC/r1q1.5_graph_lora_new3'), pr_revision=None, pr_num=None)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import json\n", "import torch\n", "import os\n", "from transformers import AutoTokenizer\n", "train_data = torch.load(\"train_data.pt\",weights_only=False)\n", "print(\"train_data 重新加载成功,数据量:\", len(train_data))\n", "if 'train_data' not in globals():\n", " train_data_path = \"train_data.pt\"\n", " \n", " if os.path.exists(train_data_path): #确保文件存在\n", " train_data = torch.load(train_data_path, weights_only=False)\n", " print(\"train_data 重新加载成功,数据量:\", len(train_data))\n", " else:\n", " print(f\"未找到 {train_data_path},请检查路径!\")\n", " exit()\n", "#检查是否已经定义了 MODEL_NAME,否则赋值默认值\n", "if \"MODEL_NAME\" not in globals():\n", " MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" # 默认模型\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n", "\n", "\n", "from transformers import Trainer, TrainingArguments, AutoModelForCausalLM\n", "\n", "# model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n", "\n", "\n", "from torch.utils.data import Dataset\n", "\n", "class GraphDataset(Dataset):\n", " def __init__(self, data):\n", " self.data = data\n", "\n", " def __len__(self):\n", " return len(self.data)\n", "\n", " def __getitem__(self, idx):\n", " sample = self.data[idx]\n", " return {\n", " \"input_ids\": sample[\"input_ids\"],\n", " \"attention_mask\": sample[\"attention_mask\"],\n", " \"graph_embedding\": sample[\"graph_embedding\"], # 额外输入\n", " \"labels\": sample[\"labels\"],\n", " }\n", "\n", "from transformers import AutoModelForCausalLM, AutoConfig\n", "import torch\n", "import torch.nn as nn\n", "\n", "class GraphAwareLM(AutoModelForCausalLM):\n", " def __init__(self, pretrained_model_name_or_path, num_heads=8):\n", " super().__init__(AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path).config)\n", " \n", " # ✅ 载入 LLM 预训练模型\n", " self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)\n", "\n", " # ✅ 1. 线性变换,将 `graph_embedding` 从 512 维映射到 `hidden_size`\n", " self.linear1 = nn.Linear(512, self.config.hidden_size)\n", "\n", " # ✅ 2. 多头注意力层\n", " self.multihead_attn = nn.MultiheadAttention(embed_dim=self.config.hidden_size, num_heads=num_heads, batch_first=True)\n", "\n", " # ✅ 3. 线性变换\n", " self.linear2 = nn.Linear(self.config.hidden_size, self.config.hidden_size)\n", "\n", " # ✅ 4. 残差连接 + LayerNorm\n", " self.norm = nn.LayerNorm(self.config.hidden_size)\n", " \n", "\n", " def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n", " \"\"\"\n", " `graph_embedding` 形状: (batch_size, 512)\n", " `input_ids` 形状: (batch_size, seq_len)\n", " \"\"\"\n", " # ✅ 获取 token embedding\n", " inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n", "\n", " # ✅ 1. 线性变换 `graph_embedding`\n", " graph_embedding_token = self.linear1(graph_embedding) # (batch_size, 1, hidden_size)\n", "\n", " # ✅ 2. 多头注意力计算(自注意力机制)\n", " attn_output, _ = self.multihead_attn(graph_embedding_token, graph_embedding_token, graph_embedding_token)\n", " \n", " # ✅ 3. 线性层 + 残差连接\n", " graph_embedding_token = self.linear2(attn_output) + graph_embedding_token # (batch_size, 1, hidden_size)\n", "\n", " # ✅ 4. 归一化\n", " graph_embedding_token = self.norm(graph_embedding_token)\n", "\n", " # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n", " graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n", " inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n", "\n", " # ✅ 调整 attention mask\n", " if attention_mask is not None:\n", " graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n", " attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n", "\n", " # ✅ 传入模型\n", " outputs = self.model(\n", " inputs_embeds=inputs_embeds,\n", " attention_mask=attention_mask,\n", " labels=labels,\n", " )\n", "\n", " return outputs\n", "\n", "from transformers import Trainer\n", "\n", "class GraphTrainer(Trainer):\n", " def compute_loss(self, model, inputs, return_outputs=False, **kwargs):\n", " input_ids = inputs[\"input_ids\"]\n", " attention_mask = inputs[\"attention_mask\"]\n", " labels = inputs[\"labels\"]\n", " graph_embedding = inputs.get(\"graph_embedding\", None) \n", "\n", " if graph_embedding is not None:\n", " outputs = model(\n", " input_ids=input_ids,\n", " attention_mask=attention_mask,\n", " labels=labels,\n", " graph_embedding=graph_embedding, \n", " )\n", " else:\n", " outputs = model(\n", " input_ids=input_ids,\n", " attention_mask=attention_mask,\n", " labels=labels,\n", " )\n", "\n", " loss = outputs.loss\n", " return (loss, outputs) if return_outputs else loss\n", "\n", "\n", "from transformers import AutoConfig\n", "\n", "# ✅ 载入微调模型\n", "model = GraphAwareLM.from_pretrained(MODEL_NAME)\n", "\n", "# ✅ 训练参数\n", "training_args = TrainingArguments(\n", " output_dir=\"./results3\",\n", " per_device_train_batch_size=7,\n", " eval_strategy=\"no\",\n", " save_strategy=\"steps\",\n", " save_steps=3000,\n", " logging_steps=50,\n", " bf16=True,\n", " optim=\"galore_adamw\",\n", " optim_target_modules=\"all-linear\", # ✅ 让 GaLore 作用于所有线性层\n", " optim_args=\"rank=128,scale=2.0\", # ✅ 低秩分解参数\n", " warmup_steps=1000,\n", " num_train_epochs=3,\n", " push_to_hub=True,\n", " hub_model_id=HF_NAME,\n", " hub_strategy=\"every_save\",\n", " run_name = \"experi030403\"\n", ")\n", "\n", "\n", "# ✅ 转换 `train_data` 为 `Dataset`\n", "train_dataset = GraphDataset(train_data)\n", "\n", "# ✅ 训练\n", "trainer = GraphTrainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=train_dataset,\n", ")\n", "\n", "trainer.train()\n", "trainer.save_model(\"/workspace/model3\")\n", "trainer.push_to_hub()\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "7a72ac3b-561e-41d3-ae93-99f20acf3188", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "RepoUrl('https://huggingface.co/YiFzhao/r1q1.5_graph_lora_new2-3000', endpoint='https://huggingface.co', repo_type='model', repo_id='YiFzhao/r1q1.5_graph_lora_new2-3000')" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from huggingface_hub import HfApi\n", "\n", "api = HfApi()\n", "repo_name = \"r1q1.5_graph_lora-results3\" # 你的模型名称\n", "api.create_repo(repo_name, exist_ok=True)" ] }, { "cell_type": "code", "execution_count": 3, "id": "73c434b9-5d58-4819-8526-24aa18ca1010", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8b896f21685e4086b0b59404b2b1a866", "version_major": 2, "version_minor": 0 }, "text/plain": [ "model-00002-of-00002.safetensors: 0%| | 0.00/2.11G [00:00\", \"\") #去掉 \n", " role_token = ROLE_TOKENS.get(role, f\"<|{role}|>\") # 兼容性处理\n", " dialogue_text += f\"{role_token} {content}\\n\"\n", "\n", " tokenized = tokenizer(\n", " dialogue_text,\n", " padding=\"max_length\",\n", " truncation=True,\n", " max_length=max_seq_length - GRAPH_LENGTH, # 预留 graph embedding 空间\n", " return_tensors=\"pt\",\n", " )\n", "\n", " input_ids = tokenized[\"input_ids\"].squeeze(0)\n", " attention_mask = tokenized[\"attention_mask\"].squeeze(0)\n", "\n", " train_data.append({\n", " \"input_ids\": input_ids,\n", " \"attention_mask\": attention_mask,\n", " \"labels\": input_ids.clone(),\n", " \"graph_embedding\": graph_embedding, # `graph_embedding` 存入\n", " })\n", "\n", "print(\"🚀 处理后数据条数:\", len(train_data))\n", "print(\"✅ 示例数据:\", train_data[0])\n", "torch.save(train_data, \"train_data.pt\")\n", "print(\"✅ train_data 已保存到 train_data.pt\")\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "05a48aa8-c597-4ff1-9569-aa210f4f1f5d", "metadata": {}, "outputs": [], "source": [ "from transformers import AutoModelForCausalLM, AutoConfig\n", "import torch\n", "import torch.nn as nn\n", "\n", "class GraphAwareLM(AutoModelForCausalLM):\n", " def __init__(self, pretrained_model_name_or_path, num_heads=8):\n", " super().__init__(AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path).config)\n", " \n", " # ✅ 载入 LLM 预训练模型\n", " self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)\n", "\n", " # ✅ 1. 线性变换,将 `graph_embedding` 从 512 维映射到 `hidden_size`\n", " self.linear1 = nn.Linear(512, self.config.hidden_size)\n", "\n", " # ✅ 2. 多头注意力层\n", " self.multihead_attn = nn.MultiheadAttention(embed_dim=self.config.hidden_size, num_heads=num_heads, batch_first=True)\n", "\n", " # ✅ 3. 线性变换\n", " self.linear2 = nn.Linear(self.config.hidden_size, self.config.hidden_size)\n", "\n", " # ✅ 4. 残差连接 + LayerNorm\n", " self.norm = nn.LayerNorm(self.config.hidden_size)\n", " \n", "\n", " def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n", " \"\"\"\n", " `graph_embedding` 形状: (batch_size, 512)\n", " `input_ids` 形状: (batch_size, seq_len)\n", " \"\"\"\n", " # ✅ 获取 token embedding\n", " inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n", "\n", " # ✅ 1. 线性变换 `graph_embedding`\n", " graph_embedding_token = self.linear1(graph_embedding) # (batch_size, 1, hidden_size)\n", "\n", " # ✅ 2. 多头注意力计算(自注意力机制)\n", " attn_output, _ = self.multihead_attn(graph_embedding_token, graph_embedding_token, graph_embedding_token)\n", " \n", " # ✅ 3. 线性层 + 残差连接\n", " graph_embedding_token = self.linear2(attn_output) + graph_embedding_token # (batch_size, 1, hidden_size)\n", "\n", " # ✅ 4. 归一化\n", " graph_embedding_token = self.norm(graph_embedding_token)\n", "\n", " # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n", " graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n", " inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n", "\n", " # ✅ 调整 attention mask\n", " if attention_mask is not None:\n", " graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n", " attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n", "\n", " # ✅ 传入模型\n", " outputs = self.model(\n", " inputs_embeds=inputs_embeds,\n", " attention_mask=attention_mask,\n", " labels=labels,\n", " )\n", "\n", " return outputs\n", "\n", " def generate(self, inputs, graph_embedding, max_length=500, temperature=0.7, top_k=50, top_p=0.9):\n", " \"\"\"\n", " ✅ 自定义 `generate()` 方法,支持 `graph_embedding`\n", " `input_text`: 需要生成文本的输入\n", " `graph_embedding`: 形状为 (1, 512) 的张量\n", " \"\"\"\n", "\n", " # ✅ 2. 处理 `graph_embedding`\n", " graph_embedding_token = self.linear1(graph_embedding) # (1, 1, hidden_size)\n", " attn_output, _ = self.multihead_attn(graph_embedding_token, graph_embedding_token, graph_embedding_token)\n", " graph_embedding_token = self.linear2(attn_output) + graph_embedding_token # (1, 1, hidden_size)\n", " graph_embedding_token = self.norm(graph_embedding_token)\n", "\n", " # ✅ 3. 获取 Token Embeddings 并拼接\n", " inputs_embeds = self.model.get_input_embeddings()(inputs[\"input_ids\"]) # (1, seq_len, hidden_size)\n", " inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (1, seq_len+1, hidden_size)\n", "\n", " # ✅ 4. 调整 `attention_mask`\n", " if \"attention_mask\" in inputs:\n", " graph_mask = torch.ones((inputs[\"attention_mask\"].shape[0], 1), device=inputs[\"attention_mask\"].device, dtype=inputs[\"attention_mask\"].dtype)\n", " attention_mask = torch.cat([graph_mask, inputs[\"attention_mask\"]], dim=1) # (1, seq_len+1)\n", " else:\n", " attention_mask = None\n", "\n", " # ✅ 5. 进行文本生成\n", " with torch.no_grad():\n", " output_ids = self.model.generate(\n", " inputs_embeds=inputs_embeds,\n", " attention_mask=attention_mask,\n", " max_length=max_length,\n", " temperature=temperature,\n", " top_k=top_k,\n", " top_p=top_p,\n", " num_return_sequences=1\n", " )\n", "\n", " # ✅ 6. 解码输出\n", " generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)\n", " return generated_text\n", "\n", " @classmethod\n", " def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):\n", " # ✅ 1. 调用 `super().from_pretrained()` 加载 LLM\n", " model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)\n", "\n", " # ✅ 2. 初始化 `MLP + MultiheadAttention` 结构\n", " model.linear1 = nn.Linear(512, model.config.hidden_size)\n", " model.multihead_attn = nn.MultiheadAttention(embed_dim=model.config.hidden_size, num_heads=8, batch_first=True)\n", " model.linear2 = nn.Linear(model.config.hidden_size, model.config.hidden_size)\n", " model.norm = nn.LayerNorm(model.config.hidden_size)\n", "\n", " return model" ] }, { "cell_type": "code", "execution_count": 2, "id": "73ae15d9-c9d9-4e64-ac8b-2d5877eac984", "metadata": {}, "outputs": [], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" ] }, { "cell_type": "code", "execution_count": 7, "id": "21c8df04-0dc2-436c-aaaf-74a885f734d9", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0b50f0cd6d784f598cc64a40cff40f38", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/2 [00:00\", \"\").strip()\n", "\n", "from transformers import AutoTokenizer\n", "\n", "# ✅ 输入文本\n", "ROLE_TOKENS = {\n", " \"human\": \"<|User|>\", \n", " \"gpt\": \"<|Assistant|>\", \n", "}\n", "GRAPH_LENGTH = 512\n", "max_seq_length = 1100 + GRAPH_LENGTH\n", "inputs = tokenizer(question1, return_tensors=\"pt\",truncation=True,max_length=max_seq_length - GRAPH_LENGTH).to(device)\n", "\n", "input_ids = inputs[\"input_ids\"]\n", "attention_mask = inputs[\"attention_mask\"]\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "4bd7493f-ca8d-4c28-914d-95b1c30f8fcc", "metadata": {}, "outputs": [ { "ename": "AttributeError", "evalue": "'Tensor' object has no attribute 'update'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[8], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m generated_text \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgraph_embedding\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1982\u001b[0m, in \u001b[0;36mGenerationMixin.generate\u001b[0;34m(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)\u001b[0m\n\u001b[1;32m 1979\u001b[0m tokenizer \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtokenizer\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;66;03m# Pull this out first, we only use it for stopping criteria\u001b[39;00m\n\u001b[1;32m 1980\u001b[0m assistant_tokenizer \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124massistant_tokenizer\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;66;03m# only used for assisted generation\u001b[39;00m\n\u001b[0;32m-> 1982\u001b[0m generation_config, model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_prepare_generation_config\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgeneration_config\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1983\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_validate_model_kwargs(model_kwargs\u001b[38;5;241m.\u001b[39mcopy())\n\u001b[1;32m 1984\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_validate_assistant(assistant_model, tokenizer, assistant_tokenizer)\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1549\u001b[0m, in \u001b[0;36mGenerationMixin._prepare_generation_config\u001b[0;34m(self, generation_config, **kwargs)\u001b[0m\n\u001b[1;32m 1547\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torchdynamo_compiling():\n\u001b[1;32m 1548\u001b[0m generation_config \u001b[38;5;241m=\u001b[39m copy\u001b[38;5;241m.\u001b[39mdeepcopy(generation_config)\n\u001b[0;32m-> 1549\u001b[0m model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[43mgeneration_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mupdate\u001b[49m(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1550\u001b[0m \u001b[38;5;66;03m# If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model\u001b[39;00m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m using_model_generation_config:\n", "\u001b[0;31mAttributeError\u001b[0m: 'Tensor' object has no attribute 'update'" ] } ], "source": [ "generated_text = model.generate(inputs, graph_embedding)" ] }, { "cell_type": "code", "execution_count": 5, "id": "62f40327-f102-4259-80a5-8761d5d7d3c6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n", " -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n", " -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n", " -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n", " -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n", " 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n", " -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n", " -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n", " -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n", " -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n", " 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n", " 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n", " -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n", " 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n", " -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n", " -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n", " -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n", " -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n", " -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n", " 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n", " 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n", " -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n", " 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n", " 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n", " -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n", " -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n", " -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n", " -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n", " -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n", " 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n", " -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n", " -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n", " -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n", " 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n", " -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n", " -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n", " 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n", " -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n", " -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n", " -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n", " -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n", " -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n", " 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n", " 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n", " -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n", " 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n", " -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n", " 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n", " -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n", " -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n", " 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n", " -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n", " 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n", " -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n", " -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n", " 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n", " 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n", " -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n", " 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n", " -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n", " -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n", " -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n", " -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n", " 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635],\n", " device='cuda:0')" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "graph_embedding" ] }, { "cell_type": "code", "execution_count": 15, "id": "067a0cf7-3010-4b6b-b2aa-d4ce95010d9b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "模型回复: How\n" ] } ], "source": [ "# ✅ 进行前向传播\n", "with torch.no_grad():\n", " outputs = model(input_ids=input_ids, attention_mask=attention_mask, graph_embedding=graph_embedding)\n", "\n", "# ✅ 提取 logits 并进行贪心解码\n", "logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n", "predicted_id = torch.argmax(logits, dim=-1) # 选择概率最大的 token\n", "\n", "# ✅ 反向编码为文本\n", "response_text = tokenizer.decode(predicted_id, skip_special_tokens=True)\n", "\n", "print(\"模型回复:\", response_text)" ] }, { "cell_type": "code", "execution_count": 9, "id": "ae38ed68-bc6a-4bc3-aee8-d54d2dd689ef", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Generated Response: What are the signal definitions in the Verilog code for the calculator module, and what are their purposes? The Verilog code defines the inputs A, B, and C, and the output Y. A and B are the operands, C is the carry-in, and Y is the result. The purpose of the module is to perform a 2-bit adder, which adds two 2-bit numbers, and the output is the sum. The inputs A and B are the operands, C is the carry-in, and Y is the result. The module is designed to handle the addition operation of two 2-bit numbers, with a carry-in, and a 3-bit output. The implementation involves using logic gates to perform the addition operation, with the sum output connected to the gates. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is straightforward, involving basic gates and an adder circuit. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is simple, with no complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is straightforward, with no complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is simple, with no complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is straightforward, with no need for complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is simple, with no need for complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is straightforward, with no need for complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is simple, with no need for complex logic gates or delays. The carry-in is used to control whether the carry-out is active or not. The output Y is the result of the addition operation. The implementation is straightforward, with\n" ] } ], "source": [ "max_new_tokens = 500\n", "generated_ids = input_ids.clone()\n", "generated_attention_mask = attention_mask.clone()\n", "for _ in range(max_new_tokens):\n", " # ✅ 计算 logits 并进行生成\n", " with torch.no_grad():\n", " outputs = model(\n", " input_ids=generated_ids, # (batch_size, seq_len)\n", " attention_mask=generated_attention_mask, # (batch_size, seq_len)\n", " graph_embedding=graph_embedding, # (batch_size, 512)\n", " )\n", "\n", "\n", " logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n", " next_token = torch.argmax(logits, dim=-1) # 贪心解码\n", " # print(next_token)\n", "\n", "\n", " # ✅ **拼接到已生成序列**\n", " generated_ids = torch.cat([generated_ids, next_token.unsqueeze(1)], dim=1)\n", "\n", " # print(generated_ids)\n", "\n", " if next_token.item() == tokenizer.eos_token_id:\n", " break\n", "\n", " generated_attention_mask = torch.cat(\n", " [generated_attention_mask, torch.ones((1, 1), dtype=generated_attention_mask.dtype, device=generated_attention_mask.device)], dim=1\n", " ) \n", "\n", "# ✅ 解码最终输出\n", "generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n", "print(\"Generated Response:\", generated_text)" ] }, { "cell_type": "code", "execution_count": 10, "id": "803f41fe-f504-4c2a-96b4-afc2cd437d01", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[151646, 3838, 525, 279, 8286, 17473, 304, 279, 6250,\n", " 50773, 2038, 369, 279, 29952, 4688, 11, 323, 1128,\n", " 525, 862, 9895, 30]], device='cuda:0')" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "generated_ids" ] }, { "cell_type": "code", "execution_count": null, "id": "87d1396b-4d20-4a76-a092-b26a587a76ac", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }