{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "fa17529d-eaa7-473e-9d2d-cc05a0120a51", "metadata": {}, "outputs": [], "source": [ "ROLE_TOKENS = {\n", " \"human\": \"<|User|>\", \n", " \"gpt\": \"<|Assistant|>\", \n", "}\n", "MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" \n", "GRAPH_LENGTH = 512\n", "HF_NAME = \"KSU-HW-SEC/r1q1.5_graph_lora_new\"" ] }, { "cell_type": "code", "execution_count": 2, "id": "bba6e6db-4b79-4461-ba13-75fd41019358", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CUDA 可用: True\n", "GPU 数量: 1\n", "当前 GPU: 0\n", "GPU 名称: NVIDIA A100 80GB PCIe\n" ] } ], "source": [ "# !pip install transformers accelerate datasets\n", "# !pip install galora\n", "# !pip install huggingface_hub\n", "import torch\n", "print(\"CUDA 可用:\", torch.cuda.is_available())\n", "print(\"GPU 数量:\", torch.cuda.device_count())\n", "print(\"当前 GPU:\", torch.cuda.current_device())\n", "print(\"GPU 名称:\", torch.cuda.get_device_name(torch.cuda.current_device()))" ] }, { "cell_type": "code", "execution_count": 3, "id": "ef5551ca-89e2-4488-8e68-1c8d964de039", "metadata": {}, "outputs": [], "source": [ "max_seq_length = 1100 + GRAPH_LENGTH # 最大序列长度" ] }, { "cell_type": "code", "execution_count": 4, "id": "8e283f49-fde4-46e2-9891-dbc304058f0a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train_data 重新加载成功,数据量: 12384\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n", "/usr/local/lib/python3.10/dist-packages/galore_torch/adamw.py:48: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", " warnings.warn(\n", "\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n", "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33m675775971\u001b[0m (\u001b[33myifang_zhao\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" ] }, { "data": { "text/html": [ "Tracking run with wandb version 0.19.7" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in /workspace/wandb/run-20250304_081255-v0v96nik" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run experi0304 to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://wandb.ai/yifang_zhao/huggingface" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://wandb.ai/yifang_zhao/huggingface/runs/v0v96nik" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
[5310/5310 1:23:11, Epoch 3/3]
Step    Training Loss
50      5.349900
100     5.305900
150     4.849500
200     3.910800
250     3.325600
300     3.144900
350     2.904200
400     2.082100
450     1.214300
500     1.011600
550     0.889300
600     0.907300
650     1.190400
700     1.889100
750     4.505600
800     6.402800
850     6.479300
900     7.337900
950     8.937600
1000    8.938700
1050    8.860100
1100    8.693600
1150    9.234000
1200    9.347500
1250    8.010300
1300    5.952900
1350    5.205900
1400    4.969600
1450    4.884600
1500    4.934200
1550    5.156900
1600    5.115500
1650    5.373600
1700    4.481800
1750    3.957000
1800    3.092500
1850    1.791000
1900    1.934400
1950    2.176800
2000    2.112400
2050    2.127900
2100    2.390200
2150    3.091400
2200    3.959500
2250    3.905000
2300    3.777500
2350    3.282900
2400    2.630300
2450    3.705000
2500    4.266300
2550    4.285300
2600    4.634000
2650    4.474700
2700    3.591300
2750    2.486800
2800    1.911800
2850    2.088100
2900    2.015400
2950    1.988500
3000    1.976900
3050    2.097700
3100    1.987400
3150    2.065000
3200    2.112100
3250    2.075300
3300    1.983300
3350    2.181900
3400    2.446500
3450    2.434200
3500    2.357000
3550    2.157400
3600    1.992900
3650    2.018400
3700    2.010200
3750    2.009500
3800    2.034900
3850    2.038800
3900    2.007600
3950    1.983200
4000    2.005300
4050    2.014900
4100    2.018100
4150    2.033900
4200    2.024600
4250    1.995300
4300    2.018000
4350    1.998300
4400    2.032800
4450    1.985900
4500    1.967700
4550    1.989400
4600    2.004700
4650    2.005800
4700    2.014400
4750    2.009200
4800    2.002200
4850    1.914300
4900    2.016900
4950    1.972900
5000    2.010300
5050    2.046600
5100    1.993900
5150    2.084500
5200    2.011900
5250    1.996500
5300    1.997900

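For reference, here is a minimal, self-contained sketch of the idea the following cell implements inside `GraphAwareLM`: project the 512-dim graph embedding to the model's hidden size, prepend it as one extra soft token, and extend the attention mask by one position. The wrapper-module style, the name `GraphPrefixWrapper`, and the `-100` label padding are illustrative assumptions, not code from this notebook.

```python
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM

class GraphPrefixWrapper(nn.Module):
    """Prepend a projected graph embedding as one soft token (illustrative sketch)."""

    def __init__(self, base_model, graph_dim=512):
        super().__init__()
        self.base_model = base_model
        self.graph_proj = nn.Linear(graph_dim, base_model.config.hidden_size)

    def forward(self, input_ids, attention_mask, graph_embedding, labels=None):
        # Token embeddings: (batch, seq_len, hidden)
        tok_emb = self.base_model.get_input_embeddings()(input_ids)
        # Projected graph "token": (batch, 1, hidden)
        graph_tok = self.graph_proj(graph_embedding).unsqueeze(1)
        inputs_embeds = torch.cat([graph_tok, tok_emb], dim=1)
        # Extend the attention mask by one position for the graph token
        graph_mask = attention_mask.new_ones((attention_mask.shape[0], 1))
        attention_mask = torch.cat([graph_mask, attention_mask], dim=1)
        if labels is not None:
            # Assumption: pad labels with -100 (ignore index) so they stay
            # aligned with the logits, which gain one position.
            pad = labels.new_full((labels.shape[0], 1), -100)
            labels = torch.cat([pad, labels], dim=1)
        return self.base_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
        )
```

The `-100` padding step matters because the sequence fed to the backbone is one token longer than the original `input_ids`; without it, label and logit lengths no longer match.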
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "No files have been modified since last commit. Skipping to prevent empty commit.\n" ] }, { "data": { "text/plain": [ "CommitInfo(commit_url='https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new/commit/231f89403dca9aa67966e4f321e62bdb41076960', commit_message='End of training', commit_description='', oid='231f89403dca9aa67966e4f321e62bdb41076960', pr_url=None, repo_url=RepoUrl('https://huggingface.co/KSU-HW-SEC/r1q1.5_graph_lora_new', endpoint='https://huggingface.co', repo_type='model', repo_id='KSU-HW-SEC/r1q1.5_graph_lora_new'), pr_revision=None, pr_num=None)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import json\n", "import torch\n", "import os\n", "from transformers import AutoTokenizer\n", "train_data = torch.load(\"train_data.pt\",weights_only=False)\n", "print(\"train_data 重新加载成功,数据量:\", len(train_data))\n", "if 'train_data' not in globals():\n", " train_data_path = \"train_data.pt\"\n", " \n", " if os.path.exists(train_data_path): #确保文件存在\n", " train_data = torch.load(train_data_path, weights_only=False)\n", " print(\"train_data 重新加载成功,数据量:\", len(train_data))\n", " else:\n", " print(f\"未找到 {train_data_path},请检查路径!\")\n", " exit()\n", "#检查是否已经定义了 MODEL_NAME,否则赋值默认值\n", "if \"MODEL_NAME\" not in globals():\n", " MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" # 默认模型\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n", "\n", "\n", "from transformers import Trainer, TrainingArguments, AutoModelForCausalLM\n", "\n", "# model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n", "\n", "\n", "from torch.utils.data import Dataset\n", "\n", "class GraphDataset(Dataset):\n", " def __init__(self, data):\n", " self.data = data\n", "\n", " def __len__(self):\n", " return len(self.data)\n", "\n", " def __getitem__(self, idx):\n", " sample = self.data[idx]\n", " return {\n", " \"input_ids\": sample[\"input_ids\"],\n", " \"attention_mask\": sample[\"attention_mask\"],\n", " \"graph_embedding\": sample[\"graph_embedding\"], # 额外输入\n", " \"labels\": sample[\"labels\"],\n", " }\n", "\n", "from transformers import AutoModelForCausalLM, AutoConfig\n", "import torch\n", "import torch.nn as nn\n", "\n", "class GraphAwareLM(AutoModelForCausalLM):\n", " def __init__(self, config):\n", " super().__init__(config)\n", "\n", " # self.model = AutoModelForCausalLM.from_config(config)\n", " \n", " # ✅ 线性变换,把 512 维的 `graph_embedding` 映射到 `hidden_size`\n", " self.graph_proj = nn.Linear(512, config.hidden_size)\n", "\n", " def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n", " \"\"\"\n", " `graph_embedding` 形状: (batch_size, 512)\n", " `input_ids` 形状: (batch_size, seq_len)\n", " \"\"\"\n", " # ✅ 获取 token embedding\n", " inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n", "\n", " # ✅ 变换 graph embedding 到 hidden_size\n", " graph_embedding_token = self.graph_proj(graph_embedding) # (batch_size, hidden_size)\n", "\n", " # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n", " graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n", " inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n", "\n", " # ✅ 调整 attention mask\n", " if attention_mask is not None:\n", " graph_mask = torch.ones((attention_mask.shape[0], 1), 
device=attention_mask.device, dtype=attention_mask.dtype)\n", " attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n", "\n", " # ✅ 传入模型\n", " outputs = self.model(\n", " inputs_embeds=inputs_embeds,\n", " attention_mask=attention_mask,\n", " labels=labels,\n", " )\n", "\n", " return outputs\n", "\n", "from transformers import Trainer\n", "\n", "class GraphTrainer(Trainer):\n", " def compute_loss(self, model, inputs, return_outputs=False, **kwargs):\n", " input_ids = inputs[\"input_ids\"]\n", " attention_mask = inputs[\"attention_mask\"]\n", " labels = inputs[\"labels\"]\n", " graph_embedding = inputs.get(\"graph_embedding\", None) \n", "\n", " if graph_embedding is not None:\n", " outputs = model(\n", " input_ids=input_ids,\n", " attention_mask=attention_mask,\n", " labels=labels,\n", " graph_embedding=graph_embedding, \n", " )\n", " else:\n", " outputs = model(\n", " input_ids=input_ids,\n", " attention_mask=attention_mask,\n", " labels=labels,\n", " )\n", "\n", " loss = outputs.loss\n", " return (loss, outputs) if return_outputs else loss\n", "\n", "\n", "from transformers import AutoConfig\n", "\n", "# 1. 加载模型的配置\n", "config = AutoConfig.from_pretrained(MODEL_NAME)\n", "\n", "# 2. 使用配置创建 GraphAwareLM 实例\n", "model = GraphAwareLM.from_config(config) \n", "\n", "pretrained_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n", "model.load_state_dict(pretrained_model.state_dict(), strict=False)\n", "\n", "# ✅ 载入修改后的 `GraphAwareLM` 模型\n", "# model = GraphAwareLM.from_pretrained(MODEL_NAME)\n", "# model.config.use_sliding_window_attention = False\n", "\n", "# ✅ 训练参数\n", "training_args = TrainingArguments(\n", " output_dir=\"./results\",\n", " per_device_train_batch_size=7,\n", " eval_strategy=\"no\",\n", " save_strategy=\"steps\",\n", " save_steps=3000,\n", " logging_steps=50,\n", " bf16=True,\n", " optim=\"galore_adamw\",\n", " optim_target_modules=\"all-linear\", # ✅ 让 GaLore 作用于所有线性层\n", " optim_args=\"rank=128,scale=2.0\", # ✅ 低秩分解参数\n", " warmup_steps=1000,\n", " num_train_epochs=3,\n", " push_to_hub=True,\n", " hub_model_id=HF_NAME,\n", " hub_strategy=\"every_save\",\n", " run_name = \"experi0304\"\n", ")\n", "\n", "\n", "# ✅ 转换 `train_data` 为 `Dataset`\n", "train_dataset = GraphDataset(train_data)\n", "\n", "# ✅ 训练\n", "trainer = GraphTrainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=train_dataset,\n", ")\n", "\n", "trainer.train()\n", "trainer.save_model(\"/workspace/model\")\n", "trainer.push_to_hub()\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "8d2ebf87-402e-444d-8599-96c313f1b7fa", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "🚀 处理后数据条数: 12384\n", "✅ 示例数据: {'input_ids': tensor([151643, 151643, 151643, ..., 1493, 7525, 624]), 'attention_mask': tensor([0, 0, 0, ..., 1, 1, 1]), 'labels': tensor([151643, 151643, 151643, ..., 1493, 7525, 624]), 'graph_embedding': tensor([-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n", " -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n", " -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n", " -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n", " -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n", " 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n", " -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n", " -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, 
-1.2952,\n", " -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n", " -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n", " 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n", " 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n", " -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n", " 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n", " -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n", " -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n", " -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n", " -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n", " -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n", " 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n", " 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n", " -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n", " 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n", " 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n", " -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n", " -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n", " -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n", " -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n", " -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n", " 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n", " -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n", " -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n", " -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n", " 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n", " -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n", " -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n", " 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n", " -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n", " -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n", " -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n", " -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n", " -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n", " 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n", " 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n", " -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n", " 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n", " -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n", " 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n", " -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n", " -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n", " 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n", " -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n", " 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n", " -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n", " -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n", " 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, 
-0.7817, -0.5730,\n", " 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n", " -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n", " 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n", " -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n", " -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n", " -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n", " -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n", " 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635])}\n", "✅ train_data 已保存到 train_data.pt\n" ] } ], "source": [ "import json\n", "import torch\n", "from transformers import AutoTokenizer\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n", "tokenizer.pad_token = tokenizer.eos_token \n", "\n", "json_path = \"final_Graph.json\"\n", "with open(json_path, \"r\") as f:\n", " data = json.load(f)\n", "\n", "train_data = []\n", "\n", "\n", "for sample in data:\n", " conversations = sample.get(\"conversations\", [])\n", " embeddings = sample.get(\"embedding\", []) \n", "\n", " if not isinstance(embeddings, list) or len(embeddings) == 0:\n", " print(f\"无效的 embedding,跳过样本:{sample}\")\n", " continue\n", "\n", " graph_embedding = torch.tensor(embeddings, dtype=torch.float32).squeeze(0) # [512]\n", "\n", " #拼接所有对话\n", " dialogue_text = \"\"\n", " for conv in conversations:\n", " role = conv[\"from\"] # \"human\" 或 \"gpt\"\n", " content = conv[\"value\"]\n", " content = content.replace(\"\", \"\") #去掉 \n", " role_token = ROLE_TOKENS.get(role, f\"<|{role}|>\") # 兼容性处理\n", " dialogue_text += f\"{role_token} {content}\\n\"\n", "\n", " tokenized = tokenizer(\n", " dialogue_text,\n", " padding=\"max_length\",\n", " truncation=True,\n", " max_length=max_seq_length - GRAPH_LENGTH, # 预留 graph embedding 空间\n", " return_tensors=\"pt\",\n", " )\n", "\n", " input_ids = tokenized[\"input_ids\"].squeeze(0)\n", " attention_mask = tokenized[\"attention_mask\"].squeeze(0)\n", "\n", " train_data.append({\n", " \"input_ids\": input_ids,\n", " \"attention_mask\": attention_mask,\n", " \"labels\": input_ids.clone(),\n", " \"graph_embedding\": graph_embedding, # `graph_embedding` 存入\n", " })\n", "\n", "print(\"🚀 处理后数据条数:\", len(train_data))\n", "print(\"✅ 示例数据:\", train_data[0])\n", "torch.save(train_data, \"train_data.pt\")\n", "print(\"✅ train_data 已保存到 train_data.pt\")\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "a33bffb9-2ff9-4a4d-af2c-b89b30a69f7d", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train_data 重新加载成功,数据量: 12384\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n", "/usr/local/lib/python3.10/dist-packages/galore_torch/adamw.py:49: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", " warnings.warn(\n", "\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n", "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33m675775971\u001b[0m (\u001b[33myifang_zhao\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. 
Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" ] }, { "data": { "text/html": [ "Tracking run with wandb version 0.19.7" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in /workspace/wandb/run-20250304_074031-ofm5zhvd" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run experi0304 to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://wandb.ai/yifang_zhao/huggingface" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://wandb.ai/yifang_zhao/huggingface/runs/ofm5zhvd" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "

[ 89/5310 01:06 < 1:06:24, 1.31 it/s, Epoch 0.05/3]
Step    Training Loss
50      0.000000

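Since the `Trainer` in this run is built without a custom `data_collator`, it should fall back to `default_data_collator`, which stacks tensor-valued keys as-is. A quick, illustrative check (not part of the original notebook) that the extra `graph_embedding` field in `GraphDataset` batches alongside the usual keys:

```python
# Illustrative only: confirm the extra "graph_embedding" key is stacked by the default collator.
from transformers import default_data_collator

batch = default_data_collator([train_dataset[0], train_dataset[1]])
print(batch["input_ids"].shape)        # e.g. (2, 1100) with padding="max_length"
print(batch["graph_embedding"].shape)  # (2, 512)
print(batch["labels"].shape)           # same shape as input_ids
```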
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[6], line 150\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;66;03m# ✅ 训练\u001b[39;00m\n\u001b[1;32m 144\u001b[0m trainer \u001b[38;5;241m=\u001b[39m GraphTrainer(\n\u001b[1;32m 145\u001b[0m model\u001b[38;5;241m=\u001b[39mmodel,\n\u001b[1;32m 146\u001b[0m args\u001b[38;5;241m=\u001b[39mtraining_args,\n\u001b[1;32m 147\u001b[0m train_dataset\u001b[38;5;241m=\u001b[39mtrain_dataset,\n\u001b[1;32m 148\u001b[0m )\n\u001b[0;32m--> 150\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 151\u001b[0m trainer\u001b[38;5;241m.\u001b[39mpush_to_hub()\n\u001b[1;32m 152\u001b[0m trainer\u001b[38;5;241m.\u001b[39msave_model(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/workspace/model\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2232\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 2229\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 2230\u001b[0m \u001b[38;5;66;03m# Disable progress bars when uploading models during checkpoints to avoid polluting stdout\u001b[39;00m\n\u001b[1;32m 2231\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39mdisable_progress_bars()\n\u001b[0;32m-> 2232\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2233\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2234\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2235\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2236\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2237\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2238\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 2239\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2548\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 2541\u001b[0m context \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 2542\u001b[0m functools\u001b[38;5;241m.\u001b[39mpartial(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mno_sync, model\u001b[38;5;241m=\u001b[39mmodel)\n\u001b[1;32m 2543\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mlen\u001b[39m(batch_samples) \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 2544\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mdistributed_type \u001b[38;5;241m!=\u001b[39m DistributedType\u001b[38;5;241m.\u001b[39mDEEPSPEED\n\u001b[1;32m 2545\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m contextlib\u001b[38;5;241m.\u001b[39mnullcontext\n\u001b[1;32m 2546\u001b[0m )\n\u001b[1;32m 2547\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m context():\n\u001b[0;32m-> 2548\u001b[0m tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_items_in_batch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2550\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 2551\u001b[0m args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m 2552\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_xla_available()\n\u001b[1;32m 2553\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch\u001b[38;5;241m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m 2554\u001b[0m ):\n\u001b[1;32m 2555\u001b[0m \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m 2556\u001b[0m tr_loss \u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m+\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:3740\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 3737\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mdistributed_type \u001b[38;5;241m==\u001b[39m DistributedType\u001b[38;5;241m.\u001b[39mDEEPSPEED:\n\u001b[1;32m 3738\u001b[0m kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mscale_wrt_gas\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[0;32m-> 3740\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maccelerator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloss\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3742\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\u001b[38;5;241m.\u001b[39mdetach()\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py:2325\u001b[0m, in \u001b[0;36mAccelerator.backward\u001b[0;34m(self, loss, **kwargs)\u001b[0m\n\u001b[1;32m 2323\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 2324\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscaler \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m-> 2325\u001b[0m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscaler\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscale\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloss\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2326\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m learning_rate \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhas_lomo_optimizer:\n\u001b[1;32m 2327\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlomo_backward(loss, learning_rate)\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/_tensor.py:492\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 482\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 483\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 484\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 485\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 490\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m 491\u001b[0m )\n\u001b[0;32m--> 492\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 493\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m 494\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py:251\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 246\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 248\u001b[0m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[1;32m 249\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 250\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 251\u001b[0m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 252\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 253\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 254\u001b[0m \u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 255\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 256\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 257\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 258\u001b[0m \u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 259\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "import json\n", "import torch\n", "import os\n", "from transformers import AutoTokenizer\n", "train_data = torch.load(\"train_data.pt\",weights_only=False)\n", "print(\"train_data 重新加载成功,数据量:\", len(train_data))\n", "if 'train_data' not in globals():\n", " train_data_path = \"train_data.pt\"\n", " \n", " if os.path.exists(train_data_path): #确保文件存在\n", " train_data = torch.load(train_data_path, weights_only=False)\n", " print(\"train_data 重新加载成功,数据量:\", len(train_data))\n", " else:\n", " print(f\"未找到 {train_data_path},请检查路径!\")\n", " exit()\n", "#检查是否已经定义了 MODEL_NAME,否则赋值默认值\n", "if \"MODEL_NAME\" not in globals():\n", " MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\" # 默认模型\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n", "\n", "\n", "from transformers import Trainer, TrainingArguments, AutoModelForCausalLM\n", "\n", "model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n", "\n", "\n", "from torch.utils.data import Dataset\n", "\n", "class GraphDataset(Dataset):\n", " def __init__(self, data):\n", " self.data = data\n", "\n", " def __len__(self):\n", " return len(self.data)\n", "\n", " def __getitem__(self, idx):\n", " sample = self.data[idx]\n", " return {\n", " \"input_ids\": sample[\"input_ids\"],\n", " \"attention_mask\": sample[\"attention_mask\"],\n", " \"graph_embedding\": sample[\"graph_embedding\"], # 额外输入\n", " \"labels\": sample[\"labels\"],\n", " }\n", "\n", "from transformers import AutoModelForCausalLM\n", "import torch\n", "import torch.nn as nn\n", "\n", "class GraphAwareLM(AutoModelForCausalLM):\n", " def __init__(self, config):\n", " super().__init__(config)\n", " self.model = AutoModelForCausalLM.from_pretrained(config)\n", " \n", " # ✅ 线性变换,把 512 维的 `graph_embedding` 映射到 `hidden_size`\n", " self.graph_proj = nn.Linear(512, config.hidden_size)\n", "\n", " def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n", " \"\"\"\n", " `graph_embedding` 形状: (batch_size, 512)\n", " `input_ids` 形状: (batch_size, seq_len)\n", " \"\"\"\n", " # ✅ 获取 token embedding\n", " inputs_embeds = self.model.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n", "\n", " # ✅ 变换 graph embedding 到 hidden_size\n", " graph_embedding_token = self.graph_proj(graph_embedding) # (batch_size, hidden_size)\n", "\n", " # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n", " graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n", " inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n", "\n", " # ✅ 调整 attention mask\n", " if attention_mask is not None:\n", " graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n", " attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n", "\n", " # ✅ 传入模型\n", 
" outputs = self.model(\n", " inputs_embeds=inputs_embeds,\n", " attention_mask=attention_mask,\n", " labels=labels,\n", " )\n", "\n", " return outputs\n", "\n", "from transformers import Trainer\n", "\n", "class GraphTrainer(Trainer):\n", " def compute_loss(self, model, inputs, return_outputs=False, **kwargs):\n", " input_ids = inputs[\"input_ids\"]\n", " attention_mask = inputs[\"attention_mask\"]\n", " labels = inputs[\"labels\"]\n", " graph_embedding = inputs.get(\"graph_embedding\", None) \n", "\n", " if graph_embedding is not None:\n", " outputs = model(\n", " input_ids=input_ids,\n", " attention_mask=attention_mask,\n", " labels=labels,\n", " graph_embedding=graph_embedding, \n", " )\n", " else:\n", " outputs = model(\n", " input_ids=input_ids,\n", " attention_mask=attention_mask,\n", " labels=labels,\n", " )\n", "\n", " loss = outputs.loss\n", " return (loss, outputs) if return_outputs else loss\n", "\n", "\n", "\n", "# ✅ 载入修改后的 `GraphAwareLM` 模型\n", "model = GraphAwareLM.from_pretrained(MODEL_NAME)\n", "# model.config.use_sliding_window_attention = False\n", "\n", "# ✅ 训练参数\n", "training_args = TrainingArguments(\n", " output_dir=\"./results\",\n", " per_device_train_batch_size=7,\n", " eval_strategy=\"no\",\n", " save_strategy=\"steps\",\n", " save_steps=3000,\n", " logging_steps=50,\n", " fp16=True,\n", " optim=\"galore_adamw\",\n", " optim_target_modules=\"all-linear\", # ✅ 让 GaLore 作用于所有线性层\n", " optim_args=\"rank=128,scale=2.0\", # ✅ 低秩分解参数\n", " warmup_steps=1000,\n", " num_train_epochs=3,\n", " push_to_hub=True,\n", " hub_model_id=HF_NAME,\n", " hub_strategy=\"every_save\",\n", " run_name = \"experi0304\"\n", ")\n", "\n", "\n", "# ✅ 转换 `train_data` 为 `Dataset`\n", "train_dataset = GraphDataset(train_data)\n", "\n", "# ✅ 训练\n", "trainer = GraphTrainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=train_dataset,\n", ")\n", "\n", "trainer.train()\n", "trainer.push_to_hub()\n", "trainer.save_model(\"/workspace/model\")\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "05a48aa8-c597-4ff1-9569-aa210f4f1f5d", "metadata": {}, "outputs": [], "source": [ "from transformers import AutoModelForCausalLM\n", "import torch\n", "import torch.nn as nn\n", "\n", "class GraphAwareLM(AutoModelForCausalLM):\n", "\n", " \n", " def __init__(self, config):\n", " super().__init__(config)\n", " self.graph_proj = nn.Linear(512, config.hidden_size)\n", "\n", " def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None):\n", " \"\"\"\n", " `graph_embedding` 形状: (batch_size, 512)\n", " `input_ids` 形状: (batch_size, seq_len)\n", " \"\"\"\n", " # ✅ 获取 token embedding\n", " inputs_embeds = self.get_input_embeddings()(input_ids) # (batch_size, seq_len, hidden_size)\n", "\n", " # ✅ 变换 graph embedding 到 hidden_size\n", " graph_embedding_token = self.graph_proj(graph_embedding.squeeze(0)) # (batch_size, hidden_size)\n", "\n", " # ✅ 在 `inputs_embeds` 前面拼接 graph_embedding\n", " graph_embedding_token = graph_embedding_token.unsqueeze(1) # (batch_size, 1, hidden_size)\n", " inputs_embeds = torch.cat([graph_embedding_token, inputs_embeds], dim=1) # (batch_size, seq_len+1, hidden_size)\n", "\n", " # ✅ 调整 attention mask\n", " if attention_mask is not None:\n", " graph_mask = torch.ones((attention_mask.shape[0], 1), device=attention_mask.device, dtype=attention_mask.dtype)\n", " attention_mask = torch.cat([graph_mask, attention_mask], dim=1) # (batch_size, seq_len+1)\n", "\n", " # ✅ 传入模型\n", " outputs = self.model(\n", " inputs_embeds=inputs_embeds,\n", 
" attention_mask=attention_mask,\n", " labels=labels,\n", " )\n", "\n", " return outputs\n", "\n", " @classmethod\n", " def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):\n", " model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)\n", " model.graph_proj = nn.Linear(512, model.config.hidden_size)\n", " return model\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "73ae15d9-c9d9-4e64-ac8b-2d5877eac984", "metadata": {}, "outputs": [], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "21c8df04-0dc2-436c-aaaf-74a885f734d9", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.\n" ] }, { "data": { "text/plain": [ "Qwen2ForCausalLM(\n", " (model): Qwen2Model(\n", " (embed_tokens): Embedding(151936, 1536)\n", " (layers): ModuleList(\n", " (0-27): 28 x Qwen2DecoderLayer(\n", " (self_attn): Qwen2Attention(\n", " (q_proj): Linear(in_features=1536, out_features=1536, bias=True)\n", " (k_proj): Linear(in_features=1536, out_features=256, bias=True)\n", " (v_proj): Linear(in_features=1536, out_features=256, bias=True)\n", " (o_proj): Linear(in_features=1536, out_features=1536, bias=False)\n", " )\n", " (mlp): Qwen2MLP(\n", " (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)\n", " (up_proj): Linear(in_features=1536, out_features=8960, bias=False)\n", " (down_proj): Linear(in_features=8960, out_features=1536, bias=False)\n", " (act_fn): SiLU()\n", " )\n", " (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n", " (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)\n", " )\n", " )\n", " (norm): Qwen2RMSNorm((1536,), eps=1e-06)\n", " (rotary_emb): Qwen2RotaryEmbedding()\n", " )\n", " (lm_head): Linear(in_features=1536, out_features=151936, bias=False)\n", " (graph_proj): Linear(in_features=512, out_features=1536, bias=True)\n", ")" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "from transformers import AutoTokenizer\n", "\n", "# 加载 tokenizer\n", "MODEL_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n", "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n", "\n", "# 加载训练好的模型\n", "model_path = \"/workspace/model\"\n", "model = GraphAwareLM.from_pretrained(model_path).to(device)\n", "model.eval() # 设置为推理模式\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "7a8562c0-8d55-4412-8f89-de20bae0f7e9", "metadata": {}, "outputs": [], "source": [ "import json\n", "json_path = \"final_Graph.json\"\n", "with open(json_path, \"r\") as f:\n", " data = json.load(f)\n", "\n", "test_data = data[0]\n", "\n", "conversations = test_data.get(\"conversations\")\n", "embeddings = test_data.get(\"embedding\") \n", "\n", "graph_embedding = torch.tensor(embeddings, dtype=torch.float32).to(device)\n", "\n", "question1 = conversations[4][\"value\"].replace(\"\", \"\").strip()\n", "\n", "from transformers import AutoTokenizer\n", "\n", "# ✅ 输入文本\n", "ROLE_TOKENS = {\n", " \"human\": \"<|User|>\", \n", " \"gpt\": \"<|Assistant|>\", \n", "}\n", "GRAPH_LENGTH = 512\n", "max_seq_length = 1100 + GRAPH_LENGTH\n", "inputs = tokenizer(question1, return_tensors=\"pt\",truncation=True,max_length=max_seq_length - GRAPH_LENGTH).to(device)\n", "\n", "input_ids = inputs[\"input_ids\"]\n", "attention_mask = inputs[\"attention_mask\"]\n" ] }, { 
"cell_type": "code", "execution_count": 5, "id": "62f40327-f102-4259-80a5-8761d5d7d3c6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n", " -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n", " -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n", " -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n", " -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n", " 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n", " -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n", " -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n", " -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n", " -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n", " 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n", " 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n", " -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n", " 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n", " -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n", " -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n", " -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n", " -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n", " -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n", " 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n", " 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, -2.1883,\n", " -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n", " 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n", " 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n", " -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n", " -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n", " -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n", " -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n", " -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n", " 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n", " -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n", " -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n", " -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n", " 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n", " -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n", " -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n", " 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n", " -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n", " -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n", " -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n", " -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n", " -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n", " 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n", " 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n", " -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n", " 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 
1.9389, 1.3991,\n", " -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n", " 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, -1.4589, -2.3772,\n", " -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n", " -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n", " 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n", " -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n", " 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n", " -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n", " -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n", " 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n", " 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n", " -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n", " 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n", " -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n", " -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n", " -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n", " -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n", " 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635]],\n", " device='cuda:0')" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "graph_embedding" ] }, { "cell_type": "code", "execution_count": 15, "id": "067a0cf7-3010-4b6b-b2aa-d4ce95010d9b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "模型回复: How\n" ] } ], "source": [ "# ✅ 进行前向传播\n", "with torch.no_grad():\n", " outputs = model(input_ids=input_ids, attention_mask=attention_mask, graph_embedding=graph_embedding)\n", "\n", "# ✅ 提取 logits 并进行贪心解码\n", "logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n", "predicted_id = torch.argmax(logits, dim=-1) # 选择概率最大的 token\n", "\n", "# ✅ 反向编码为文本\n", "response_text = tokenizer.decode(predicted_id, skip_special_tokens=True)\n", "\n", "print(\"模型回复:\", response_text)" ] }, { "cell_type": "code", "execution_count": 9, "id": "ae38ed68-bc6a-4bc3-aee8-d54d2dd689ef", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Generated Response: Is there any sequential logic in the module, and if so, how is it handled? 
What are the module's inputs and outputs?\n", "What are the module's inputs and outputs?\n", "What are the module's inputs and outputs?\n", "What are the module's inputs and outputs?\n", "What is the module's input, and what is the module's output, and what is the module's output, and what is the module's input, and what is the module's output, and what is the module's input, and what is the module's output, and what is the module's input, and what is the module's output, and what is the module's output, and what is the module's input, and what is the module's output, and what is the module's output, and what is the module's input, and what is the module's output, and what is the module's output, and what is the module's output, and what is the module's output, and what is the module's output, and module's output, and module's input, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output. Is the module's output, and module's output, and module's output, and module's output, and module's output. 
Is the module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module's output, and module\n" ] } ], "source": [ "max_new_tokens = 1024\n", "generated_ids = input_ids.clone()\n", "generated_attention_mask = attention_mask.clone()\n", "for _ in range(max_new_tokens):\n", " # ✅ 计算 logits 并进行生成\n", " with torch.no_grad():\n", " outputs = model(\n", " input_ids=generated_ids, # (batch_size, seq_len)\n", " attention_mask=generated_attention_mask, # (batch_size, seq_len)\n", " graph_embedding=graph_embedding, # (batch_size, 512)\n", " )\n", "\n", "\n", " logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n", " next_token = torch.argmax(logits, dim=-1) # 贪心解码\n", " # print(next_token)\n", "\n", "\n", " # ✅ **拼接到已生成序列**\n", " generated_ids = torch.cat([generated_ids, next_token.unsqueeze(1)], dim=1)\n", "\n", " # print(generated_ids)\n", "\n", " if next_token.item() == tokenizer.eos_token_id:\n", " break\n", "\n", " generated_attention_mask = torch.cat(\n", " [generated_attention_mask, torch.ones((1, 1), dtype=generated_attention_mask.dtype, device=generated_attention_mask.device)], dim=1\n", " ) \n", "\n", "# ✅ 解码最终输出\n", "generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n", "print(\"Generated Response:\", generated_text)" ] }, { "cell_type": "code", "execution_count": 10, "id": "803f41fe-f504-4c2a-96b4-afc2cd437d01", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[151646, 3838, 525, 279, 8286, 17473, 304, 279, 6250,\n", " 50773, 2038, 369, 279, 29952, 4688, 11, 323, 1128,\n", " 525, 862, 9895, 30]], device='cuda:0')" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "generated_ids" ] }, { "cell_type": "code", "execution_count": null, "id": "87d1396b-4d20-4a76-a092-b26a587a76ac", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": 
"ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }