{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from transformers import AutoTokenizer, AutoModelForCausalLM\n", "import torch\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "MODEL_NAME = \"/workspace/model\"\n", "model_token = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import json\n", "import torch\n", "from transformers import AutoTokenizer\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(model_token)\n", "tokenizer.pad_token = tokenizer.eos_token " ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "json_path = \"final_Graph.json\"\n", "with open(json_path, \"r\") as f:\n", " data = json.load(f)\n", "\n", "test_data = data[0]\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "ROLE_TOKENS = {\n", " \"human\": \"<|User|>\", \n", " \"gpt\": \"<|Assistant|>\", \n", "}\n", "GRAPH_LENGTH = 512\n", "max_seq_length = 1100 + GRAPH_LENGTH" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "conversations = test_data.get(\"conversations\")\n", "embeddings = test_data.get(\"embedding\") \n", "\n", "graph_embedding = torch.tensor(embeddings, dtype=torch.float32)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'What are the signal definitions in the Verilog code for the calculator module, and what are their purposes?'" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "question1 = conversations[0][\"value\"].replace(\"\", \"\").strip()\n", "question1" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "import json\n", "import torch\n", "import os\n", "from transformers import AutoTokenizer\n", "# 
tokenizer is already loaded in an earlier cell (from model_token), so it is not reloaded here\n",
"from transformers import Trainer, TrainingArguments, AutoModelForCausalLM, Qwen2ForCausalLM\n",
"from torch.utils.data import Dataset\n",
"import torch.nn as nn\n",
"\n",
"\n",
"class GraphAwareLM(Qwen2ForCausalLM):\n",
"    \"\"\"Qwen2 causal LM that prepends one projected graph-embedding token to the input.\n",
"\n",
"    Subclasses the concrete Qwen2ForCausalLM (the architecture of the checkpoint this\n",
"    notebook loads) rather than AutoModelForCausalLM: Auto* classes are factories, so\n",
"    subclassing one makes `from_pretrained` return a plain Qwen2ForCausalLM whose\n",
"    forward never sees `graph_embedding`.\n",
"    \"\"\"\n",
"\n",
"    graph_dim = 512  # size of the last axis of `graph_embedding`\n",
"\n",
"    def __init__(self, config):\n",
"        super().__init__(config)\n",
"        # Maps `graph_embedding` (..., graph_dim) into the token-embedding space.\n",
"        self.graph_proj = nn.Linear(self.graph_dim, config.hidden_size)\n",
"\n",
"    def forward(self, input_ids=None, attention_mask=None, labels=None, graph_embedding=None,\n",
"                inputs_embeds=None, **kwargs):\n",
"        \"\"\"Run the LM with the projected graph embedding prepended as position 0.\n",
"\n",
"        Args:\n",
"            input_ids: (batch, seq_len) token ids; ignored when `inputs_embeds` is given.\n",
"            attention_mask: (batch, seq_len); must cover every token position.\n",
"            labels: (batch, seq_len); a -100 column is prepended for the graph slot.\n",
"            graph_embedding: (graph_dim,), (batch, graph_dim) or (batch, 1, graph_dim);\n",
"                a single row is broadcast over the batch. None adds no extra position.\n",
"            inputs_embeds: optional precomputed (batch, seq_len, hidden) embeddings.\n",
"            **kwargs: passed through to Qwen2ForCausalLM.forward.\n",
"\n",
"        Returns:\n",
"            The Qwen2ForCausalLM output; with a graph embedding its logits have\n",
"            seq_len + 1 positions, position 0 being the graph slot.\n",
"\n",
"        Raises:\n",
"            ValueError: if `attention_mask` or `labels` length differs from seq_len.\n",
"\n",
"        NOTE(review): the graph slot is prepended on every call, which suits full-sequence\n",
"        passes (training, decoding without a KV cache) but not KV-cached `generate()`.\n",
"        \"\"\"\n",
"        if inputs_embeds is None:\n",
"            inputs_embeds = self.get_input_embeddings()(input_ids)  # (batch, seq_len, hidden)\n",
"        batch_size, seq_len = inputs_embeds.shape[:2]\n",
"\n",
"        # Fail with a readable message instead of a shape error deep inside attention,\n",
"        # e.g. when a decode loop appends to input_ids but not to attention_mask.\n",
"        if attention_mask is not None and attention_mask.shape[-1] != seq_len:\n",
"            raise ValueError(\n",
"                f\"attention_mask covers {attention_mask.shape[-1]} positions but the input has {seq_len}; \"\n",
"                \"extend attention_mask whenever tokens are appended to input_ids\"\n",
"            )\n",
"        if labels is not None and labels.shape[-1] != seq_len:\n",
"            raise ValueError(f\"labels has {labels.shape[-1]} positions but the input has {seq_len}\")\n",
"\n",
"        if graph_embedding is not None:\n",
"            # Flatten leading axes to (rows, graph_dim) and match the projection's device/dtype,\n",
"            # so a CPU float32 tensor can be passed as-is.\n",
"            proj_weight = self.graph_proj.weight\n",
"            graph_embedding = graph_embedding.reshape(-1, graph_embedding.shape[-1])\n",
"            graph_embedding = graph_embedding.to(device=proj_weight.device, dtype=proj_weight.dtype)\n",
"            graph_token = self.graph_proj(graph_embedding).to(inputs_embeds.dtype).unsqueeze(1)  # (rows, 1, hidden)\n",
"            if graph_token.shape[0] == 1 and batch_size > 1:\n",
"                graph_token = graph_token.expand(batch_size, -1, -1)\n",
"            inputs_embeds = torch.cat([graph_token, inputs_embeds], dim=1)  # (batch, seq_len + 1, hidden)\n",
"\n",
"            if attention_mask is not None:\n",
"                # The graph slot is always attended to.\n",
"                graph_mask = attention_mask.new_ones((batch_size, 1))\n",
"                attention_mask = torch.cat([graph_mask, attention_mask], dim=1)\n",
"            if labels is not None:\n",
"                # -100 is the ignore index of the causal-LM loss, so the graph slot is never a target.\n",
"                labels = torch.cat([labels.new_full((batch_size, 1), -100), labels], dim=1)\n",
"\n",
"        return super().forward(\n",
"            inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, **kwargs\n",
"        )\n",
"\n",
"    @classmethod\n",
"    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):\n",
"        # NOTE: graph_proj is (re)created with random weights on every load; restore\n",
"        # trained projection weights separately if they exist.\n",
"        model = super().from_pretrained(pretrained_model_name_or_path, 
*model_args, **kwargs)\n", " model.graph_proj = nn.Linear(512, model.config.hidden_size)\n", " return model\n", "\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "model = GraphAwareLM.from_pretrained(MODEL_NAME).to(device)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-2.4214, -0.5552, 1.0389, -1.3428, -0.1341, 0.6100, -0.4200, -1.8584,\n", " -0.2880, -0.4779, 0.3452, -0.8934, -0.9216, 0.5600, 0.2474, -0.9009,\n", " -1.0995, 0.6065, 1.7662, -1.2281, 0.0000, -1.9196, 0.1920, -1.2770,\n", " -0.6918, -1.3762, -0.7639, -0.1023, 2.5149, 1.1990, -0.2678, -0.7488,\n", " -0.0000, 0.9108, 0.2010, -0.2639, 0.5023, -0.8752, 0.2083, 0.5740,\n", " 0.3758, -0.7036, -1.3210, -0.8119, -0.5329, -0.2355, -0.2750, 1.6133,\n", " -2.3233, 0.3174, 0.0000, 0.5769, 0.3558, 0.2234, -0.0666, -0.6310,\n", " -0.3533, 0.9497, -0.9576, 0.1615, -0.0460, -1.1686, 1.4337, -1.2952,\n", " -1.1095, 0.5081, -1.9626, -0.3278, 0.7837, -2.4616, 0.3936, -0.3157,\n", " -1.6531, -0.0708, -0.6630, 0.4285, 0.1360, -0.7986, -0.1449, 0.0000,\n", " 0.9076, 0.7794, 0.6391, 0.9840, 0.2970, 1.5463, 1.1554, -0.5432,\n", " 0.7202, 0.0000, -0.2380, 0.0422, 0.0000, 0.4296, 0.2068, 0.3330,\n", " -0.5888, 0.0000, 1.0656, -0.2724, 0.7562, -0.6863, -1.6948, -0.1634,\n", " 1.8262, 1.4235, 0.9178, -0.7475, -0.2682, 0.5534, 1.5643, -0.9898,\n", " -0.2911, 1.3752, 0.6331, -0.1162, 1.7250, 0.8486, -0.0000, -1.6454,\n", " -4.2099, -0.1101, 0.9528, -0.1335, 0.1057, 0.2624, 2.4600, 1.2772,\n", " -3.6113, -1.6540, 1.7807, -0.5077, 0.4537, 1.0987, -0.0713, 0.1391,\n", " -0.0000, -1.3129, 0.5611, -0.3687, -0.7690, 0.0190, 0.9332, -0.4274,\n", " -0.4125, -0.6608, 0.4810, -0.6759, -0.8501, 0.0000, -1.6998, 0.3269,\n", " 0.0334, -0.8513, -0.8695, -0.2957, -2.1983, 1.1621, 0.1864, 0.6089,\n", " 0.4840, -0.6849, 0.2127, 0.7035, -2.9177, 2.2954, -2.0283, 
-2.1883,\n", " -0.0000, 0.1591, 1.3046, -0.0000, 0.2811, 0.0935, -1.0028, 0.8179,\n", " 1.5387, 0.5271, 0.2195, -0.0882, -1.3943, 0.8263, 0.7164, 0.6240,\n", " 0.7027, -0.5830, -1.2238, -0.0000, 0.5721, 0.0000, 0.3103, 0.7294,\n", " -0.0224, 2.8884, -0.0000, -0.0000, 2.1562, -0.6177, 1.5242, -0.0000,\n", " -0.9023, -0.0000, 1.9196, -0.9594, -0.7334, 0.6636, 0.0000, 0.5613,\n", " -0.3294, 1.1782, -0.8789, 1.6285, 0.3845, 0.1210, 1.3321, 0.5566,\n", " -0.4729, 1.9552, -0.6409, 1.1379, -0.0000, 1.2146, -0.7578, -0.3764,\n", " -0.0823, -1.7541, -0.1362, -0.1631, -0.6794, 1.2874, 0.2402, 0.0000,\n", " 2.3540, -0.5574, -0.9901, 0.3435, 0.6318, -0.3071, -0.6270, -1.8417,\n", " -1.9213, -0.4928, 0.1969, -1.2195, -0.1594, -1.1694, 1.9461, 1.4360,\n", " -0.4050, 1.3495, 0.3053, -0.3500, -0.1546, -0.4096, 0.8011, -0.5379,\n", " -0.1322, 0.0000, 1.7025, -0.0000, -0.7611, 1.4174, -1.0466, -0.8641,\n", " 0.3074, -0.9910, 0.0000, 1.2856, -0.3916, -1.4133, -1.2143, -1.1373,\n", " -0.4996, -0.3315, 1.6280, 0.1051, 0.3570, 2.4021, -0.0249, 0.8169,\n", " -0.4497, -1.4486, -0.0000, -0.7351, -0.3337, 0.2480, -0.5413, 2.2289,\n", " 1.6903, 0.7866, 0.6164, 0.8920, -1.1745, -0.3534, -0.4512, 0.0000,\n", " -0.3795, -1.2503, -0.5114, 1.6374, 1.3271, 1.8410, 0.1040, 0.9731,\n", " -0.3357, 2.4072, -0.0000, 1.9666, -0.5907, 1.0771, 1.6236, -0.9991,\n", " -0.0282, 0.6689, -1.0429, 0.9279, 0.0000, -0.1722, -1.0940, -1.1756,\n", " -0.2457, -1.1142, -1.5693, 1.7408, 1.8951, -1.5109, -0.3783, -0.4719,\n", " -0.7410, -0.2575, 0.0000, -0.8207, -0.6377, -1.2434, 0.4213, -2.1689,\n", " 1.1191, 0.8991, -0.7343, -0.0000, 0.1287, -1.0638, -1.3629, -0.0916,\n", " 0.6016, -1.2285, 2.1858, -0.1274, -0.1246, 0.8666, -0.1599, -0.9024,\n", " -0.6486, 0.9323, 1.4422, -0.7030, 1.6400, 1.2095, 0.9178, -0.6975,\n", " 1.5239, -1.8692, -2.4644, -0.0000, 1.3411, -0.0351, 1.9389, 1.3991,\n", " -1.0556, -0.8072, 0.9237, 0.8799, 0.2778, -0.8607, 0.4810, -0.0000,\n", " 0.8293, 0.0735, 2.2176, -0.0000, -0.4048, 0.8768, 
-1.4589, -2.3772,\n", " -0.5785, 0.7544, -1.3414, 0.7273, -1.4420, 2.0120, -0.0846, -1.0264,\n", " -0.8520, -0.3899, -0.0000, -0.5772, -0.1395, -0.8346, 2.7815, 0.3414,\n", " 2.6266, 0.2384, 2.0168, 0.6710, 0.9409, -0.3611, 1.6438, -0.0000,\n", " -0.8750, -0.1610, 0.8060, -1.5453, 0.3108, -0.6887, 0.0000, 0.3937,\n", " 0.2050, -0.7704, 1.1102, 0.1719, -0.4513, -0.1844, 0.7308, -2.4639,\n", " -0.1578, -0.5711, -0.4696, -0.8899, 0.0929, -0.2267, 0.1619, 0.7937,\n", " -0.3767, 0.2024, 0.3893, -0.7677, 1.5729, -0.6239, -0.0000, 0.8411,\n", " 0.6361, -1.1110, -1.2833, 1.0356, -0.9941, 0.5842, -0.7817, -0.5730,\n", " 0.2732, -0.6890, -0.0000, -0.0087, 1.3772, 0.3003, 0.0000, 0.8828,\n", " -1.7060, -0.9499, 0.0000, 1.2618, -0.1124, 0.9352, 0.5854, 1.1139,\n", " 0.1583, 3.3464, -0.4027, 0.5860, -0.8730, -0.0163, -0.7023, 2.1778,\n", " -3.2313, 1.5753, 0.8494, -1.3516, -2.2013, -1.6432, 0.2581, 0.2197,\n", " -0.7742, -0.6365, -2.4008, 1.4902, 0.3697, -0.2428, 0.0000, -0.6978,\n", " -0.0000, 0.7576, 1.7998, 0.0000, -0.8300, -1.0503, 0.4118, 1.4737,\n", " -1.0162, -1.1784, -0.3985, 0.1699, -0.0000, -0.6951, -1.5820, 1.2909,\n", " 1.7528, 0.1409, -1.3121, 1.7415, 0.5114, -1.7321, 2.0781, 0.5635]],\n", " device='cuda:0')" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from transformers import AutoTokenizer\n", "\n", "# ✅ 加载分词器\n", "tokenizer = AutoTokenizer.from_pretrained(model_token)\n", "\n", "# ✅ 输入文本\n", "inputs = tokenizer(question1, return_tensors=\"pt\",truncation=True,max_length=max_seq_length - GRAPH_LENGTH).to(device)\n", "\n", "graph_embedding.to(device)\n", "\n" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "ename": "RuntimeError", "evalue": "The size of tensor a (23) must match the size of tensor b (22) at non-singleton dimension 3", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[14], line 6\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(max_new_tokens):\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# ✅ 计算 logits 并进行生成\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m----> 6\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgenerated_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# (batch_size, seq_len)\u001b[39;49;00m\n\u001b[1;32m 8\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mattention_mask\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# (batch_size, seq_len)\u001b[39;49;00m\n\u001b[1;32m 9\u001b[0m \u001b[43m \u001b[49m\u001b[43mgraph_embedding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgraph_embedding\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# (batch_size, 512)\u001b[39;49;00m\n\u001b[1;32m 10\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 13\u001b[0m logits \u001b[38;5;241m=\u001b[39m outputs\u001b[38;5;241m.\u001b[39mlogits[:, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, :] \u001b[38;5;66;03m# 取最后一个 token 的 logits\u001b[39;00m\n\u001b[1;32m 14\u001b[0m next_token \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39margmax(logits, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, keepdim\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m) 
\u001b[38;5;66;03m# 贪心解码\u001b[39;00m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks 
\u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/utils/deprecation.py:172\u001b[0m, in \u001b[0;36mdeprecate_kwarg..wrapper..wrapped_func\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m minimum_action \u001b[38;5;129;01min\u001b[39;00m (Action\u001b[38;5;241m.\u001b[39mNOTIFY, Action\u001b[38;5;241m.\u001b[39mNOTIFY_ALWAYS) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torchdynamo_compiling():\n\u001b[1;32m 169\u001b[0m \u001b[38;5;66;03m# DeprecationWarning is ignored by default, so we use FutureWarning instead\u001b[39;00m\n\u001b[1;32m 170\u001b[0m warnings\u001b[38;5;241m.\u001b[39mwarn(message, \u001b[38;5;167;01mFutureWarning\u001b[39;00m, stacklevel\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[0;32m--> 172\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:856\u001b[0m, in 
\u001b[0;36mQwen2ForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, logits_to_keep, **kwargs)\u001b[0m\n\u001b[1;32m 853\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m 855\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[0;32m--> 856\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 857\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 858\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 859\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 860\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 861\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 862\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 863\u001b[0m \u001b[43m 
\u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 864\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 865\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 866\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 867\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 868\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 870\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 871\u001b[0m \u001b[38;5;66;03m# Only compute necessary logits, and do not upcast them to float if we are not computing the loss\u001b[39;00m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:579\u001b[0m, in \u001b[0;36mQwen2Model.forward\u001b[0;34m(self, 
input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **flash_attn_kwargs)\u001b[0m\n\u001b[1;32m 567\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[1;32m 568\u001b[0m decoder_layer\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[1;32m 569\u001b[0m hidden_states,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 576\u001b[0m position_embeddings,\n\u001b[1;32m 577\u001b[0m )\n\u001b[1;32m 578\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 579\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 580\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 581\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcausal_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 582\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 583\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 584\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 585\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 586\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 587\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mposition_embeddings\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_embeddings\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 588\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mflash_attn_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 589\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 591\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 593\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output_attentions:\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m 
\u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:260\u001b[0m, in \u001b[0;36mQwen2DecoderLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)\u001b[0m\n\u001b[1;32m 257\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minput_layernorm(hidden_states)\n\u001b[1;32m 259\u001b[0m \u001b[38;5;66;03m# Self Attention\u001b[39;00m\n\u001b[0;32m--> 260\u001b[0m hidden_states, self_attn_weights \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mself_attn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 261\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 262\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 263\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 264\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 265\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 266\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 267\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 268\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_embeddings\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_embeddings\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 269\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 270\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 271\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m residual \u001b[38;5;241m+\u001b[39m hidden_states\n\u001b[1;32m 273\u001b[0m \u001b[38;5;66;03m# Fully Connected\u001b[39;00m\n", "File 
\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m 
_global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:192\u001b[0m, in \u001b[0;36mQwen2Attention.forward\u001b[0;34m(self, hidden_states, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs)\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 190\u001b[0m attention_interface \u001b[38;5;241m=\u001b[39m ALL_ATTENTION_FUNCTIONS[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39m_attn_implementation]\n\u001b[0;32m--> 192\u001b[0m attn_output, attn_weights \u001b[38;5;241m=\u001b[39m \u001b[43mattention_interface\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 193\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 194\u001b[0m \u001b[43m \u001b[49m\u001b[43mquery_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 195\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 196\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalue_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 197\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 198\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mdropout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.0\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mattention_dropout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 199\u001b[0m \u001b[43m \u001b[49m\u001b[43mscaling\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscaling\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 200\u001b[0m \u001b[43m \u001b[49m\u001b[43msliding_window\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msliding_window\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# main diff with Llama\u001b[39;49;00m\n\u001b[1;32m 201\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 202\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 204\u001b[0m attn_output \u001b[38;5;241m=\u001b[39m attn_output\u001b[38;5;241m.\u001b[39mreshape(\u001b[38;5;241m*\u001b[39minput_shape, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mcontiguous()\n\u001b[1;32m 205\u001b[0m attn_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mo_proj(attn_output)\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py:123\u001b[0m, in \u001b[0;36meager_attention_forward\u001b[0;34m(module, query, key, value, attention_mask, scaling, dropout, **kwargs)\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m attention_mask 
\u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 122\u001b[0m causal_mask \u001b[38;5;241m=\u001b[39m attention_mask[:, :, :, : key_states\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m]]\n\u001b[0;32m--> 123\u001b[0m attn_weights \u001b[38;5;241m=\u001b[39m \u001b[43mattn_weights\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mcausal_mask\u001b[49m\n\u001b[1;32m 125\u001b[0m attn_weights \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mfunctional\u001b[38;5;241m.\u001b[39msoftmax(attn_weights, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mfloat32)\u001b[38;5;241m.\u001b[39mto(query\u001b[38;5;241m.\u001b[39mdtype)\n\u001b[1;32m 126\u001b[0m attn_weights \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mfunctional\u001b[38;5;241m.\u001b[39mdropout(attn_weights, p\u001b[38;5;241m=\u001b[39mdropout, training\u001b[38;5;241m=\u001b[39mmodule\u001b[38;5;241m.\u001b[39mtraining)\n", "\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (23) must match the size of tensor b (22) at non-singleton dimension 3" ] } ], "source": [ "\n", "generated_ids = inputs[\"input_ids\"]\n", "max_new_tokens = 1024\n", "for _ in range(max_new_tokens):\n", " # ✅ 计算 logits 并进行生成\n", " with torch.no_grad():\n", " outputs = model(\n", " input_ids=generated_ids, # (batch_size, seq_len)\n", " attention_mask=inputs[\"attention_mask\"], # (batch_size, seq_len)\n", " graph_embedding=graph_embedding, # (batch_size, 512)\n", " )\n", "\n", "\n", " logits = outputs.logits[:, -1, :] # 取最后一个 token 的 logits\n", " next_token = torch.argmax(logits, dim=-1, keepdim=True) # 贪心解码\n", "\n", "\n", " # ✅ **拼接到已生成序列**\n", " generated_ids = torch.cat([generated_ids, next_token], dim=-1)\n", "\n", " if next_token[:, 0] == 
tokenizer.eos_token_id:\n", " break\n", "\n", "# ✅ 解码最终输出\n", "generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n", "print(\"Generated Response:\", generated_text)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Generated Response: How does the code handle combinational logic? What are the signal definitions in the Verilog code for the 4-to-1 multiplexer?\n", "The code uses assign statements to handle combinational logic. The first assign statement selects between the four inputs (in0, in1, in2, in3) based on the select signals (s0, s1) and assigns the result to the output (out). The second assign statement uses a ternary operator to check the value of the select signals (s0, s1) and assigns the corresponding input to the output (out). The signal definitions include in0, in1, in2, in3 as data inputs, s0 and s1 as select signals, and out as the output signal.\n", "How does the code handle sequential logic? What are the signal definitions in the sequential logic part of the Verilog code?\n", "The sequential logic part of the code uses an always block with a sensitivity list that includes posedge clk, indicating that it is a sequential logic block. The output (out) is updated on the rising edge of the clock signal (clk). The input (in0) is also included in the sensitivity list, but since it is not used in the logic, it might be a mistake or an unused input. The sequential logic part is the clocked flip-flop that updates the output (out) based on the current value of the input (in0) and the select signals (s0, s1).\n", "What is the function of the circuit described in the Verilog code?\n", "The circuit is a 4-to-1 multiplexer with a registered output. It selects one of the four inputs based on the select signals (s0, s1) and stores the selected value in a flip-flop on the rising edge of the clock signal (clk). 
The output (out) is the value of the selected input stored in the flip-flop.\n", "How can the circuit be implemented in hardware?\n", "The circuit can be implemented using standard logic gates for the multiplexer and a D flip-flop for the registered output. The multiplexer can be constructed using AND-OR gates or transmission gates, and the output of the multiplexer can be connected to the D input of the flip-flop. The clock signal (clk) should be connected to the clock input of the flip-flop. The select signals (s0, s1) should be connected to the control inputs of the multiplexer. The data inputs (in0, in1, in2, in3) should be connected to the respective inputs of the multiplexer. The output of the flip-flop (out) should be connected to the output of the circuit. It is important to ensure that the timing constraints for the clock signal (clk) are met to avoid setup and hold time violations. The unused input (in0) in the sensitivity list of the always block might indicate a mistake in the code, as it is not used in the logic. However, it could be a typo or an oversight in the code. The implementation should focus on the functional parts of the circuit, which are the multiplexer and the flip-flop. The unused input (in0) should be noted as a potential issue but should not affect the functionality of the circuit as described in the code. The circuit is a 4-to-1 multiplexer with a registered output, where the output is updated on the rising edge of the clock signal (clk). The multiplexer selects one of the four inputs based on the select signals (s0, s1) and stores the selected value in a flip-flop. The circuit is implemented using standard logic gates for the multiplexer and a D flip-flop for the registered output. The implementation should focus on the functional parts of the circuit, which are the multiplexer and the flip-flop, while noting the potential issue of the unused input (in0) in the sensitivity list of the always block. 
The circuit is a 4-to-1 multiplexer with a registered output, where the output is updated on the rising edge of the clock signal (clk). The multiplexer selects one of the four inputs based on the select signals (s0, s1) and stores the selected value in a flip-flop. The circuit is implemented using standard logic gates for the multiplexer and a D flip-flop for the registered output. The implementation should focus on the functional parts of the circuit, which are the multiplexer and the flip-flop, while noting the potential issue of the unused input (in0) in the sensitivity list of the always block. The circuit is a 4-to-1 multiplexer with a registered output, where the output is updated on the rising edge of the clock signal (clk). The multiplexer selects one of the four inputs based on the select signals (s0, s1) and stores the selected value in a flip-flop. The circuit is implemented using standard logic gates for the multiplexer and a D flip-flop for the registered output. The implementation should focus on the functional parts of the circuit, which are the multiplexer and the flip-flop, while noting the potential issue of the unused input (in0) in the sensitivity list of the always block. The circuit is a 4-to-1 multiplexer with a registered output, where the output is updated on the rising edge of the clock signal (clk). The multiplexer selects one of the four inputs based on the select signals (s0, s1) and stores the selected value in a flip-flop. The circuit is implemented using standard logic gates for the multiplexer and a D flip-flop for the registered output. The implementation should focus on the functional parts of the circuit\n" ] } ], "source": [ "# Greedy decoding with the graph-aware model. The whole sequence is fed to the model on every step,\n", "# so the attention mask has to grow together with `generated_ids`: GraphAwareLM.forward prepends one\n", "# graph token to both the embeddings and the mask, and a mask that is not extended is one position\n", "# short after the first step, which makes attention fail with a tensor size mismatch.\n", "max_new_tokens = 1024\n", "generated_ids = inputs[\"input_ids\"]\n", "attention_mask = inputs[\"attention_mask\"]\n", "\n", "with torch.no_grad():\n", "    for _ in range(max_new_tokens):\n", "        outputs = model(\n", "            input_ids=generated_ids,        # (batch_size, seq_len)\n", "            attention_mask=attention_mask,  # (batch_size, seq_len), kept in sync with generated_ids\n", "            graph_embedding=graph_embedding,\n", "        )\n", "\n", "        # Greedy decoding: take the highest-scoring token at the last position.\n", "        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)  # (batch_size, 1)\n", "\n", "        # Append the new token and mark it as a real (attended) position in the mask.\n", "        generated_ids = torch.cat([generated_ids, next_token], dim=-1)\n", "        attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)\n", "\n", "        # Stop once every sequence in the batch emitted EOS at this step (same as before for batch size 1).\n", "        if (next_token == tokenizer.eos_token_id).all():\n", "            break\n", "\n", "# Only the first sequence of the batch is decoded and printed.\n", "generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n", "print(\"Generated Response:\", generated_text)" ] }, { "cell_type": "code", "execution_count": null,
"metadata": {}, "outputs": [], "source": [ "# Load the tokenizer of the base model the checkpoint was fine-tuned from.\n", "# A dedicated name is used instead of re-assigning MODEL_NAME, which the config cell\n", "# defines as the local checkpoint path.\n", "TOKENIZER_NAME = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n", "tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)\n", "# A fresh from_pretrained() does not carry over the `pad_token = eos_token` assignment\n", "# made in the tokenizer cell earlier in the notebook, so re-apply it here.\n", "tokenizer.pad_token = tokenizer.eos_token\n", "\n", "# Load the fine-tuned graph-aware checkpoint and switch it to inference mode.\n", "# NOTE(review): the generation cells above use `model`, so this cell must run before them;\n", "# consider moving it up. The model is not moved to `device` here - confirm this matches\n", "# the device `inputs` and `graph_embedding` live on.\n", "model_path = \"/workspace/model\"\n", "model = GraphAwareLM.from_pretrained(model_path)\n", "model.eval();  # trailing ';' suppresses the very large model repr" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 4 }