{ "cells": [ { "cell_type": "code", "execution_count": 2, "id": "fa17529d-eaa7-473e-9d2d-cc05a0120a51", "metadata": {}, "outputs": [], "source": [ "ROLE_TOKENS = {\n", " \"human\": \"<|User|>\", \n", " \"gpt\": \"<|Assistant|>\", \n", "}\n", "MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" \n", "GRAPH_LENGTH = 512\n", "HF_NAME = \"KSU-HW-SEC/r1q1.5_graph_lora_new2\"" ] }, { "cell_type": "code", "execution_count": 3, "id": "bba6e6db-4b79-4461-ba13-75fd41019358", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CUDA 可用: True\n", "GPU 数量: 1\n", "当前 GPU: 0\n", "GPU 名称: NVIDIA A100 80GB PCIe\n" ] } ], "source": [ "# !pip install transformers accelerate datasets\n", "# !pip install galora\n", "# !pip install huggingface_hub\n", "import torch\n", "print(\"CUDA 可用:\", torch.cuda.is_available())\n", "print(\"GPU 数量:\", torch.cuda.device_count())\n", "print(\"当前 GPU:\", torch.cuda.current_device())\n", "print(\"GPU 名称:\", torch.cuda.get_device_name(torch.cuda.current_device()))" ] }, { "cell_type": "code", "execution_count": 4, "id": "ef5551ca-89e2-4488-8e68-1c8d964de039", "metadata": {}, "outputs": [], "source": [ "max_seq_length = 1100 + GRAPH_LENGTH # 最大序列长度" ] }, { "cell_type": "code", "execution_count": 4, "id": "8e283f49-fde4-46e2-9891-dbc304058f0a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train_data 重新加载成功,数据量: 12384\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n", "/usr/local/lib/python3.10/dist-packages/galore_torch/adamw.py:48: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", " warnings.warn(\n", "\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. 
Please refer to https://wandb.me/wandb-core for more information.\n", "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33m675775971\u001b[0m (\u001b[33myifang_zhao\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" ] }, { "data": { "text/html": [ "Tracking run with wandb version 0.19.7" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in /workspace/wandb/run-20250304_111730-i9v1vlu1" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run experi030402 to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://wandb.ai/yifang_zhao/huggingface" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://wandb.ai/yifang_zhao/huggingface/runs/i9v1vlu1" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [5310/5310 1:34:08, Epoch 3/3]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
505.319300
1003.641300
1501.521800
2001.027500
2500.922400
3000.866900
3500.800500
4000.721600
4500.740400
5000.737000
5500.713500
6000.747000
6500.869500
7001.473300
7500.753000
8000.741300
8500.751400
9000.787600
9500.783200
10000.780200
10501.012900
11001.411700
11501.536400
12000.853800
12500.756500
13000.750800
13500.747400
14000.844400
14500.858400
15001.053400
15501.591600
16001.498900
16501.471700
17001.221100
17501.802300
18001.826000
18501.857300
19001.561800
19501.398800
20001.398900
20501.381600
21000.890300
21500.763700
22000.753100
22500.745500
23001.186100
23500.862000
24001.024600
24501.028400
25001.008500
25500.942800
26000.849700
26500.771400
27000.794100
27500.819200
28000.937500
28501.064500
29001.189300
29501.071100
30001.003300
30501.073900
31001.043100
31501.282600
32002.145400
32501.925800
33002.005600
33502.122600
34002.163000
34502.046600
35002.152200
35502.151700
36005.394900
36504.677800
37004.122200
37503.710200
38003.350800
38503.126300
39002.988700
39502.872000
40002.848200
40502.823900
41002.781200
41502.735000
42002.725900
42502.644400
43002.700000
43502.650100
44002.704500
44502.596700
45002.510500
45502.515800
46002.498100
46502.458900
47002.449700
47502.425000
48002.362300
48502.232000
49002.361500
49502.302300
50002.333900
50502.367200
51002.288300
51502.426100
52002.344100
52502.283500
53002.296500

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "No files have been modified since last commit. Skipping to prevent empty commit.\n" ] }, { "data": { "text/plain": [ "CommitInfo(commit_url='https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new2/commit/291285a5f2155c79a0da893645d8df9bbca98f63', commit_message='End of training', commit_description='', oid='291285a5f2155c79a0da893645d8df9bbca98f63', pr_url=None, repo_url=RepoUrl('https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new2', endpoint='https://huggingface.co', repo_type='model', repo_id='KSU-HW-SEC/r1q1.5_graph_lora_new2'), pr_revision=None, pr_num=None)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import json\n", "import torch\n", "import os\n", "from transformers import AutoTokenizer\n", "train_data = torch.load(\"train_data.pt\",weights_only=False)\n", "print(\"train_data 重新加载成功,数据量:\", len(train_data))\n", "if 'train_data' not in globals():\n", " train_data_path = \"train_data.pt\"\n", " \n", " if os.path.exists(train_data_path): #确保文件存在\n", " train_data = torch.load(train_data_path, weights_only=False)\n", " print(\"train_data 重新加载成功,数据量:\", len(train_data))\n", " else:\n", " print(f\"未找到 {train_data_path},请检查路径!\")\n", " exit()\n", "#检查是否已经定义了 MODEL_NAME,否则赋值默认值\n", "if \"MODEL_NAME\" not in globals():\n", " MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" # 默认模型\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n", "\n", "\n", "from transformers import Trainer, TrainingArguments, AutoModelForCausalLM\n", "\n", "# model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n", "\n", "\n", "from torch.utils.data import Dataset\n", "\n", "class GraphDataset(Dataset):\n", " def __init__(self, data):\n", " self.data = data\n", "\n", " def __len__(self):\n", " return len(self.data)\n", "\n", " def __getitem__(self, idx):\n", " sample = 
self.data[idx]\n",
"        # The stored labels are a plain clone of the max_length-padded input_ids; set padding\n",
"        # positions to -100 so the LM loss ignores them instead of training on pad tokens.\n",
"        labels = sample[\"labels\"].masked_fill(sample[\"attention_mask\"] == 0, -100)\n",
"        return {\n",
"            \"input_ids\": sample[\"input_ids\"],\n",
"            \"attention_mask\": sample[\"attention_mask\"],\n",
"            \"graph_embedding\": sample[\"graph_embedding\"],  # extra, non-token model input\n",
"            \"labels\": labels,\n",
"        }\n",
"\n",
"\n",
"from transformers import AutoModelForCausalLM, AutoConfig, Qwen2ForCausalLM\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"\n",
"class GraphAwareLM(Qwen2ForCausalLM):\n",
"    \"\"\"Qwen2 causal LM that prepends one projected graph-embedding token to the prompt.\n",
"\n",
"    Why a concrete subclass: subclassing AutoModelForCausalLM does not work, because Auto*\n",
"    classes are factories. GraphAwareLM.from_pretrained(...) then returned a plain\n",
"    Qwen2ForCausalLM (MODEL_NAME is a Qwen2 checkpoint), so graph_proj and this forward were\n",
"    never used and graph_embedding was silently ignored. Subclassing the concrete architecture\n",
"    makes the inherited from_pretrained build this class, and save_pretrained/from_pretrained\n",
"    round-trip the trained graph_proj weights.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, config, graph_dim=512):\n",
"        super().__init__(config)\n",
"        # Keep the graph width in the config so a saved checkpoint rebuilds the same projection.\n",
"        self.config.graph_dim = getattr(config, \"graph_dim\", graph_dim)\n",
"        # Projects a graph embedding into the token-embedding space (one extra input position).\n",
"        self.graph_proj = nn.Linear(self.config.graph_dim, config.hidden_size)\n",
"\n",
"    def _embed_with_graph(self, input_ids, graph_embedding, attention_mask=None, inputs_embeds=None):\n",
"        \"\"\"Return (inputs_embeds, attention_mask) with the projected graph token prepended.\n",
"\n",
"        graph_embedding may be (batch, graph_dim) or a single (graph_dim,) vector; a batch of one\n",
"        is broadcast over the token batch. The mask, if given, gets a leading 1 for the graph slot.\n",
"        \"\"\"\n",
"        if inputs_embeds is None:\n",
"            inputs_embeds = self.get_input_embeddings()(input_ids)  # (batch, seq_len, hidden)\n",
"        if graph_embedding.dim() == 1:\n",
"            graph_embedding = graph_embedding.unsqueeze(0)\n",
"        weight = self.graph_proj.weight\n",
"        graph_token = self.graph_proj(graph_embedding.to(device=weight.device, dtype=weight.dtype))\n",
"        graph_token = graph_token.unsqueeze(1).to(inputs_embeds.dtype)  # (batch, 1, hidden)\n",
"        graph_token = graph_token.expand(inputs_embeds.shape[0], -1, -1)\n",
"        inputs_embeds = torch.cat([graph_token, inputs_embeds], dim=1)  # (batch, seq_len+1, hidden)\n",
"        if attention_mask is not None:\n",
"            graph_mask = attention_mask.new_ones((attention_mask.shape[0], 1))\n",
"            attention_mask = torch.cat([graph_mask, attention_mask], dim=1)  # (batch, seq_len+1)\n",
"        return inputs_embeds, attention_mask\n",
"\n",
"    def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None, **kwargs):\n",
"        \"\"\"Causal-LM forward with an optional graph-embedding prefix token.\n",
"\n",
"        Without graph_embedding the call goes unchanged to Qwen2ForCausalLM.forward (this is also\n",
"        the path generate() takes). With it, one graph token is prepended and labels get a leading\n",
"        -100 so they stay aligned with the seq_len+1 logits. Extra kwargs are forwarded.\n",
"        \"\"\"\n",
"        if graph_embedding is None:\n",
"            return super().forward(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)\n",
"        inputs_embeds, attention_mask = self._embed_with_graph(\n",
"            input_ids, graph_embedding, attention_mask, kwargs.pop(\"inputs_embeds\", None)\n",
"        )\n",
"        if labels is not None:\n",
"            # The graph slot has no target; without this pad the loss hits a length mismatch.\n",
"            ignore = labels.new_full((labels.shape[0], 1), -100)\n",
"            labels = torch.cat([ignore, labels], dim=1)\n",
"        return super().forward(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, **kwargs)\n",
"\n",
"\n",
"from transformers import Trainer\n",
"\n",
"\n",
"class GraphTrainer(Trainer):\n",
"    \"\"\"Trainer whose loss step forwards the optional graph_embedding column to the model.\"\"\"\n",
"\n",
"    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):\n",
"        model_inputs = {\n",
"            \"input_ids\": inputs[\"input_ids\"],\n",
"            \"attention_mask\": inputs[\"attention_mask\"],\n",
"            \"labels\": inputs[\"labels\"],\n",
"        }\n",
"        # Only pass graph_embedding when present, so models without that argument still work.\n",
"        if inputs.get(\"graph_embedding\") is not None:\n",
"            model_inputs[\"graph_embedding\"] = inputs[\"graph_embedding\"]\n",
"        outputs = model(**model_inputs)\n",
"        loss = outputs.loss\n",
"        return (loss, outputs) if return_outputs else loss\n",
"\n",
"\n",
"# Load the base checkpoint into GraphAwareLM. graph_proj is not in the base checkpoint, so\n",
"# transformers initialises it and reports it as a newly initialised weight.\n",
"model = GraphAwareLM.from_pretrained(MODEL_NAME)\n",
"\n",
"# Training arguments.\n",
"# NOTE(review): the loss logged for the previous run diverged mid-training (~0.7 -> ~5.4);\n",
"# if it recurs, try a lower learning_rate or GaLore scale.\n",
"training_args = TrainingArguments(\n",
"    output_dir=\"./results2\",\n",
"    per_device_train_batch_size=7,\n",
"    eval_strategy=\"no\",\n",
"    save_strategy=\"steps\",\n",
"    save_steps=3000,\n",
"    logging_steps=50,\n",
"    bf16=True,\n",
"    optim=\"galore_adamw\",\n",
"    optim_target_modules=\"all-linear\",  # apply GaLore to every linear layer\n",
"    optim_args=\"rank=128,scale=2.0\",  # GaLore low-rank projection settings\n",
"    warmup_steps=1000,\n",
"    num_train_epochs=3,\n",
"    push_to_hub=True,\n",
"    hub_model_id=HF_NAME,\n",
"    hub_strategy=\"every_save\",\n",
"    run_name=\"experi030402\",\n",
")\n",
"\n",
"train_dataset = GraphDataset(train_data)\n",
"\n",
"trainer = GraphTrainer(\n",
"    model=model,\n",
"    args=training_args,\n",
"    train_dataset=train_dataset,\n",
")\n",
"\n",
"trainer.train()\n",
"trainer.save_model(\"/workspace/model2\")\n",
"trainer.push_to_hub()"
 ] }, { "cell_type": "code", "execution_count": 7, "id": "7a72ac3b-561e-41d3-ae93-99f20acf3188", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "RepoUrl('https://huggingface.co/YiFzhao/r1q1.5_graph_lora-wandb', endpoint='https://huggingface.co', repo_type='model', repo_id='YiFzhao/r1q1.5_graph_lora-wandb')" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from huggingface_hub import HfApi\n", "\n", "api = HfApi()\n", "repo_name = \"r1q1.5_graph_lora-wandb\" # 你的模型名称\n", "api.create_repo(repo_name, exist_ok=True)" ] }, { "cell_type": "code", "execution_count": 6, "id": "73c434b9-5d58-4819-8526-24aa18ca1010", "metadata": {}, "outputs":
[ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "727ca342a20348d38a4a1c6d286963e0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "optimizer.pt: 0%| | 0.00/4.32G [00:00\", \"\") #去掉 \n", " role_token = ROLE_TOKENS.get(role, f\"<|{role}|>\") # 兼容性处理\n", " dialogue_text += f\"{role_token} {content}\\n\"\n", "\n", " tokenized = tokenizer(\n", " dialogue_text,\n", " padding=\"max_length\",\n", " truncation=True,\n", " max_length=max_seq_length - GRAPH_LENGTH, # 预留 graph embedding 空间\n", " return_tensors=\"pt\",\n", " )\n", "\n", " input_ids = tokenized[\"input_ids\"].squeeze(0)\n", " attention_mask = tokenized[\"attention_mask\"].squeeze(0)\n", "\n", " train_data.append({\n", " \"input_ids\": input_ids,\n", " \"attention_mask\": attention_mask,\n", " \"labels\": input_ids.clone(),\n", " \"graph_embedding\": graph_embedding, # `graph_embedding` 存入\n", " })\n", "\n", "print(\"🚀 处理后数据条数:\", len(train_data))\n", "print(\"✅ 示例数据:\", train_data[0])\n", "torch.save(train_data, \"train_data.pt\")\n", "print(\"✅ train_data 已保存到 train_data.pt\")\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "05a48aa8-c597-4ff1-9569-aa210f4f1f5d", "metadata": {}, "outputs": [], "source": [
"from transformers import AutoModelForCausalLM, AutoConfig, Qwen2ForCausalLM\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"\n",
"class GraphAwareLM(Qwen2ForCausalLM):\n",
"    \"\"\"Qwen2 causal LM that prepends one projected graph-embedding token to the prompt.\n",
"\n",
"    Subclasses the concrete architecture instead of AutoModelForCausalLM. Auto* classes are\n",
"    factories, so the old GraphAwareLM.from_pretrained(...) returned a plain Qwen2ForCausalLM\n",
"    without generate_with_graph (calling it raised AttributeError). Its from_pretrained override\n",
"    also replaced graph_proj with a fresh random layer, discarding trained weights. The inherited\n",
"    PreTrainedModel.from_pretrained now builds this class and loads graph_proj when the checkpoint\n",
"    contains it; otherwise transformers initialises it as a missing weight.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, config, graph_dim=512):\n",
"        super().__init__(config)\n",
"        # Keep the graph width in the config so a saved checkpoint rebuilds the same projection.\n",
"        self.config.graph_dim = getattr(config, \"graph_dim\", graph_dim)\n",
"        # Projects a graph embedding into the token-embedding space (one extra input position).\n",
"        self.graph_proj = nn.Linear(self.config.graph_dim, config.hidden_size)\n",
"\n",
"    def _embed_with_graph(self, input_ids, graph_embedding, attention_mask=None, inputs_embeds=None):\n",
"        \"\"\"Return (inputs_embeds, attention_mask) with the projected graph token prepended.\n",
"\n",
"        graph_embedding may be (batch, graph_dim) or a single (graph_dim,) vector; a batch of one\n",
"        is broadcast over the token batch. The mask, if given, gets a leading 1 for the graph slot.\n",
"        \"\"\"\n",
"        if inputs_embeds is None:\n",
"            inputs_embeds = self.get_input_embeddings()(input_ids)  # (batch, seq_len, hidden)\n",
"        if graph_embedding.dim() == 1:\n",
"            graph_embedding = graph_embedding.unsqueeze(0)\n",
"        weight = self.graph_proj.weight\n",
"        graph_token = self.graph_proj(graph_embedding.to(device=weight.device, dtype=weight.dtype))\n",
"        graph_token = graph_token.unsqueeze(1).to(inputs_embeds.dtype)  # (batch, 1, hidden)\n",
"        graph_token = graph_token.expand(inputs_embeds.shape[0], -1, -1)\n",
"        inputs_embeds = torch.cat([graph_token, inputs_embeds], dim=1)  # (batch, seq_len+1, hidden)\n",
"        if attention_mask is not None:\n",
"            graph_mask = attention_mask.new_ones((attention_mask.shape[0], 1))\n",
"            attention_mask = torch.cat([graph_mask, attention_mask], dim=1)  # (batch, seq_len+1)\n",
"        return inputs_embeds, attention_mask\n",
"\n",
"    def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None, **kwargs):\n",
"        \"\"\"Causal-LM forward with an optional graph-embedding prefix token.\n",
"\n",
"        Without graph_embedding the call goes unchanged to Qwen2ForCausalLM.forward (this is also\n",
"        the path generate() takes). With it, one graph token is prepended and labels get a leading\n",
"        -100 so they stay aligned with the seq_len+1 logits. Extra kwargs are forwarded.\n",
"        \"\"\"\n",
"        if graph_embedding is None:\n",
"            return super().forward(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)\n",
"        inputs_embeds, attention_mask = self._embed_with_graph(\n",
"            input_ids, graph_embedding, attention_mask, kwargs.pop(\"inputs_embeds\", None)\n",
"        )\n",
"        if labels is not None:\n",
"            # The graph slot has no target; without this pad the loss hits a length mismatch.\n",
"            ignore = labels.new_full((labels.shape[0], 1), -100)\n",
"            labels = torch.cat([ignore, labels], dim=1)\n",
"        return super().forward(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, **kwargs)\n",
"\n",
"    @torch.no_grad()\n",
"    def generate_with_graph(self, inputs, graph_embedding, max_length=500, temperature=0.7, top_k=50, top_p=0.9):\n",
"        \"\"\"Sample a continuation conditioned on a graph embedding and return it as text.\n",
"\n",
"        Args:\n",
"            inputs: tokenizer output holding input_ids and, optionally, attention_mask.\n",
"            graph_embedding: (graph_dim,) or (1, graph_dim) tensor.\n",
"            max_length, temperature, top_k, top_p: forwarded to generate(); sampling is enabled\n",
"                explicitly so these settings actually take effect.\n",
"        Returns:\n",
"            The decoded first sequence (decoded with the notebook-level `tokenizer`).\n",
"        \"\"\"\n",
"        inputs_embeds, attention_mask = self._embed_with_graph(\n",
"            inputs[\"input_ids\"], graph_embedding, inputs.get(\"attention_mask\")\n",
"        )\n",
"        # generate() calls forward() without graph_embedding, so the prefix is added only once.\n",
"        output_ids = self.generate(\n",
"            inputs_embeds=inputs_embeds,\n",
"            attention_mask=attention_mask,\n",
"            max_length=max_length,\n",
"            do_sample=True,\n",
"            temperature=temperature,\n",
"            top_k=top_k,\n",
"            top_p=top_p,\n",
"            num_return_sequences=1,\n",
"        )\n",
"        return tokenizer.decode(output_ids[0], skip_special_tokens=True)"
 ] }, { "cell_type": "code", "execution_count": 11, "id": "73ae15d9-c9d9-4e64-ac8b-2d5877eac984", "metadata": {}, "outputs": [], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" ] }, { "cell_type": "code", "execution_count": 12, "id": "21c8df04-0dc2-436c-aaaf-74a885f734d9", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7ad289c5523340f39799ad11e3bc1bb5", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/2 [00:00\", \"\").strip()\n", "\n", "from transformers import AutoTokenizer\n", "\n", "# ✅ 输入文本\n", "ROLE_TOKENS = {\n", " \"human\": \"<|User|>\", \n", " \"gpt\": \"<|Assistant|>\", \n", "}\n", "GRAPH_LENGTH = 512\n", "max_seq_length = 1100 + GRAPH_LENGTH\n", "inputs = tokenizer(question1, return_tensors=\"pt\",truncation=True,max_length=max_seq_length - 
GRAPH_LENGTH).to(device)\n", "\n", "input_ids = inputs[\"input_ids\"]\n", "attention_mask = inputs[\"attention_mask\"]\n" ] }, { "cell_type": "code", "execution_count": 15, "id": "4bd7493f-ca8d-4c28-914d-95b1c30f8fcc", "metadata": {}, "outputs": [ { "ename": "AttributeError", "evalue": "'Qwen2ForCausalLM' object has no attribute 'generate_with_graph'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[15], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m generated_text \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate_with_graph\u001b[49m(inputs, graph_embedding)\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1695\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1693\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m modules:\n\u001b[1;32m 1694\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1695\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m object has no attribute \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", "\u001b[0;31mAttributeError\u001b[0m: 'Qwen2ForCausalLM' object has no attribute 'generate_with_graph'" ] } ], "source": [ "generated_text = model.generate_with_graph(inputs, graph_embedding)" ] }, { "cell_type": 
"code", "execution_count": 5, "id": "62f40327-f102-4259-80a5-8761d5d7d3c6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n", " -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n", " -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n", " -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n", " -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n", " 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n", " -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n", " -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n", " -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n", " -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n", " 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n", " 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n", " -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n", " 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n", " -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n", " -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n", " -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n", " -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n", " -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n", " 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n", " 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n", " -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n", " 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n", " 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n", " -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n", " 
-0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n", " -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n", " -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n", " -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n", " 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n", " -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n", " -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n", " -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n", " 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n", " -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n", " -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n", " 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n", " -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n", " -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n", " -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n", " -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n", " -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n", " 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n", " 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n", " -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n", " 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n", " -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n", " 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n", " -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n", " -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n", " 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n", " -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 
0.3937,\n", " 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n", " -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n", " -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n", " 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n", " 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n", " -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n", " 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n", " -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n", " -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n", " -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n", " -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n", " 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635],\n", " device='cuda:0')" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "graph_embedding" ] }, { "cell_type": "code", "execution_count": 15, "id": "067a0cf7-3010-4b6b-b2aa-d4ce95010d9b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "模型回复: How\n" ] } ], "source": [ "# ✅ 进行前向传播\n", "with torch.no_grad():\n", " outputs = model(input_ids=input_ids, attention_mask=attention_mask, graph_embedding=graph_embedding)\n", "\n", "# ✅ 提取 logits 并进行贪心解码\n", "logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n", "predicted_id = torch.argmax(logits, dim=-1) # 选择概率最大的 token\n", "\n", "# ✅ 反向编码为文本\n", "response_text = tokenizer.decode(predicted_id, skip_special_tokens=True)\n", "\n", "print(\"模型回复:\", response_text)" ] }, { "cell_type": "code", "execution_count": 17, "id": "ae38ed68-bc6a-4bc3-aee8-d54d2dd689ef", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Generated Response: Is there any sequential logic in the module, and if so, how is it handled? 
`data` is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit 
input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit input, and the output is the output of the `data` is a 1-bit data, and the output is the output of the `data` is a 1-bit\n" ] } ], "source": [ "max_new_tokens = 1024\n", "generated_ids = input_ids.clone()\n", "generated_attention_mask = attention_mask.clone()\n", "for _ in range(max_new_tokens):\n", " # ✅ 计算 logits 并进行生成\n", " with torch.no_grad():\n", " outputs = model(\n", " input_ids=generated_ids, # (batch_size, seq_len)\n", " attention_mask=generated_attention_mask, # (batch_size, seq_len)\n", " graph_embedding=graph_embedding, # (batch_size, 512)\n", " )\n", "\n", "\n", " logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n", " 
next_token = torch.argmax(logits, dim=-1) # 贪心解码\n", " # print(next_token)\n", "\n", "\n", " # ✅ **拼接到已生成序列**\n", " generated_ids = torch.cat([generated_ids, next_token.unsqueeze(1)], dim=1)\n", "\n", " # print(generated_ids)\n", "\n", " if next_token.item() == tokenizer.eos_token_id:\n", " break\n", "\n", " generated_attention_mask = torch.cat(\n", " [generated_attention_mask, torch.ones((1, 1), dtype=generated_attention_mask.dtype, device=generated_attention_mask.device)], dim=1\n", " ) \n", "\n", "# ✅ 解码最终输出\n", "generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n", "print(\"Generated Response:\", generated_text)" ] }, { "cell_type": "code", "execution_count": 10, "id": "803f41fe-f504-4c2a-96b4-afc2cd437d01", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[151646, 3838, 525, 279, 8286, 17473, 304, 279, 6250,\n", " 50773, 2038, 369, 279, 29952, 4688, 11, 323, 1128,\n", " 525, 862, 9895, 30]], device='cuda:0')" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "generated_ids" ] }, { "cell_type": "code", "execution_count": null, "id": "87d1396b-4d20-4a76-a092-b26a587a76ac", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }