Spaces:
Sleeping
Sleeping
Upload 6 files (#5)
Browse files- Upload 6 files (f7a18c9b20fe67e6e2fa85becf0bd3103cfce0fb)
Co-authored-by: Feiyu Chen <[email protected]>
- app.py +65 -150
- cache_utils.py +322 -0
- configuration_phi.py +195 -0
- modeling_attn_mask_utils.py +500 -0
- modeling_phi.py +836 -0
- tokenization_codegen.py +389 -0
app.py
CHANGED
@@ -3,15 +3,73 @@ import gradio as gr
|
|
3 |
import os
|
4 |
import torch
|
5 |
import random
|
6 |
-
import nltk_u
|
7 |
import pandas as pd
|
8 |
from sklearn.model_selection import train_test_split
|
9 |
import time
|
10 |
|
11 |
-
from model import RNN_model
|
12 |
from timeit import default_timer as timer
|
13 |
from typing import Tuple, Dict
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
# Import data
|
16 |
df= pd.read_csv('Symptom2Disease.csv')
|
17 |
df.drop('Unnamed: 0', axis= 1, inplace= True)
|
@@ -19,72 +77,7 @@ df.drop('Unnamed: 0', axis= 1, inplace= True)
|
|
19 |
# Preprocess data
|
20 |
df.drop_duplicates(inplace= True)
|
21 |
train_data, test_data= train_test_split(df, test_size=0.15, random_state=42 )
|
22 |
-
|
23 |
-
# Setup class names
|
24 |
-
class_names= {0: 'Acne',
|
25 |
-
1: 'Arthritis',
|
26 |
-
2: 'Bronchial Asthma',
|
27 |
-
3: 'Cervical spondylosis',
|
28 |
-
4: 'Chicken pox',
|
29 |
-
5: 'Common Cold',
|
30 |
-
6: 'Dengue',
|
31 |
-
7: 'Dimorphic Hemorrhoids',
|
32 |
-
8: 'Fungal infection',
|
33 |
-
9: 'Hypertension',
|
34 |
-
10: 'Impetigo',
|
35 |
-
11: 'Jaundice',
|
36 |
-
12: 'Malaria',
|
37 |
-
13: 'Migraine',
|
38 |
-
14: 'Pneumonia',
|
39 |
-
15: 'Psoriasis',
|
40 |
-
16: 'Typhoid',
|
41 |
-
17: 'Varicose Veins',
|
42 |
-
18: 'allergy',
|
43 |
-
19: 'diabetes',
|
44 |
-
20: 'drug reaction',
|
45 |
-
21: 'gastroesophageal reflux disease',
|
46 |
-
22: 'peptic ulcer disease',
|
47 |
-
23: 'urinary tract infection'
|
48 |
-
}
|
49 |
-
|
50 |
-
vectorizer= nltk_u.vectorizer()
|
51 |
-
vectorizer.fit(train_data.text)
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
# Model and transforms preparation
|
56 |
-
model= RNN_model()
|
57 |
-
# Load state dict
|
58 |
-
model.load_state_dict(torch.load(
|
59 |
-
f= 'pretrained_symtom_to_disease_model.pth',
|
60 |
-
map_location= torch.device('cpu')))
|
61 |
-
# Disease Advice
|
62 |
-
disease_advice = {
|
63 |
-
'Acne': "Maintain a proper skincare routine, avoid excessive touching of the affected areas, and consider using over-the-counter topical treatments. If severe, consult a dermatologist.",
|
64 |
-
'Arthritis': "Stay active with gentle exercises, manage weight, and consider pain-relief strategies like hot/cold therapy. Consult a rheumatologist for tailored guidance.",
|
65 |
-
'Bronchial Asthma': "Follow prescribed inhaler and medication regimen, avoid triggers like smoke and allergens, and have an asthma action plan. Regular check-ups with a pulmonologist are important.",
|
66 |
-
'Cervical spondylosis': "Maintain good posture, do neck exercises, and use ergonomic support. Physical therapy and pain management techniques might be helpful.",
|
67 |
-
'Chicken pox': "Rest, maintain hygiene, and avoid scratching. Consult a doctor for appropriate antiviral treatment.",
|
68 |
-
'Common Cold': "Get plenty of rest, stay hydrated, and consider over-the-counter remedies for symptom relief. Seek medical attention if symptoms worsen or last long.",
|
69 |
-
'Dengue': "Stay hydrated, rest, and manage fever with acetaminophen. Seek medical care promptly, as dengue can escalate quickly.",
|
70 |
-
'Dimorphic Hemorrhoids': "Follow a high-fiber diet, maintain good hygiene, and consider stool softeners. Consult a doctor if symptoms persist.",
|
71 |
-
'Fungal infection': "Keep the affected area clean and dry, use antifungal creams, and avoid sharing personal items. Consult a dermatologist if it persists.",
|
72 |
-
'Hypertension': "Follow a balanced diet, exercise regularly, reduce salt intake, and take prescribed medications. Regular check-ups with a healthcare provider are important.",
|
73 |
-
'Impetigo': "Keep the affected area clean, use prescribed antibiotics, and avoid close contact. Consult a doctor for proper treatment.",
|
74 |
-
'Jaundice': "Get plenty of rest, maintain hydration, and follow a doctor's advice for diet and medications. Regular monitoring is important.",
|
75 |
-
'Malaria': "Take prescribed antimalarial medications, rest, and manage fever. Seek medical attention for severe cases.",
|
76 |
-
'Migraine': "Identify triggers, manage stress, and consider pain-relief medications. Consult a neurologist for personalized management.",
|
77 |
-
'Pneumonia': "Follow prescribed antibiotics, rest, stay hydrated, and monitor symptoms. Seek immediate medical attention for severe cases.",
|
78 |
-
'Psoriasis': "Moisturize, use prescribed creams, and avoid triggers. Consult a dermatologist for effective management.",
|
79 |
-
'Typhoid': "Take prescribed antibiotics, rest, and stay hydrated. Dietary precautions are important. Consult a doctor for proper treatment.",
|
80 |
-
'Varicose Veins': "Elevate legs, exercise regularly, and wear compression stockings. Consult a vascular specialist for evaluation and treatment options.",
|
81 |
-
'allergy': "Identify triggers, manage exposure, and consider antihistamines. Consult an allergist for comprehensive management.",
|
82 |
-
'diabetes': "Follow a balanced diet, exercise, monitor blood sugar levels, and take prescribed medications. Regular visits to an endocrinologist are essential.",
|
83 |
-
'drug reaction': "Discontinue the suspected medication, seek medical attention if symptoms are severe, and inform healthcare providers about the reaction.",
|
84 |
-
'gastroesophageal reflux disease': "Follow dietary changes, avoid large meals, and consider medications. Consult a doctor for personalized management.",
|
85 |
-
'peptic ulcer disease': "Avoid spicy and acidic foods, take prescribed medications, and manage stress. Consult a gastroenterologist for guidance.",
|
86 |
-
'urinary tract infection': "Stay hydrated, take prescribed antibiotics, and maintain good hygiene. Consult a doctor for appropriate treatment."
|
87 |
-
}
|
88 |
|
89 |
howto= """Welcome to the <b>Medical Chatbot</b>, powered by Gradio.
|
90 |
Currently, the chatbot can WELCOME YOU, PREDICT DISEASE based on your symptoms and SUGGEST POSSIBLE SOLUTIONS AND RECOMENDATIONS, and BID YOU FAREWELL.
|
@@ -92,8 +85,6 @@ Currently, the chatbot can WELCOME YOU, PREDICT DISEASE based on your symptoms a
|
|
92 |
The bot will respond based on the best possible answers to your messages.
|
93 |
|
94 |
"""
|
95 |
-
|
96 |
-
|
97 |
# Create the gradio demo
|
98 |
with gr.Blocks(css = """#col_container { margin-left: auto; margin-right: auto;} #chatbot {height: 520px; overflow: auto;}""") as demo:
|
99 |
gr.HTML('<h1 align="center">Medical Chatbot: ARIN 7102')
|
@@ -105,87 +96,11 @@ with gr.Blocks(css = """#col_container { margin-left: auto; margin-right: auto;}
|
|
105 |
clear = gr.ClearButton([msg, chatbot])
|
106 |
|
107 |
def respond(message, chat_history):
|
108 |
-
# Random greetings in list format
|
109 |
-
greetings = [
|
110 |
-
"hello!",'hello', 'hii !', 'hi', "hi there!", "hi there!", "heyy", 'good morning', 'good afternoon', 'good evening'
|
111 |
-
"hey", "how are you", "how are you?", "how is it going", "how is it going?",
|
112 |
-
"what's up?", "how are you?",
|
113 |
-
"hey, how are you?", "what is popping"
|
114 |
-
"good to see you!", "howdy!",
|
115 |
-
"hi, nice to meet you.", "hiya!",
|
116 |
-
"hi", "hi, what's new?",
|
117 |
-
"hey, how's your day?", "hi, how have you been?", "greetings",
|
118 |
-
]
|
119 |
-
# Random Greetings responses
|
120 |
-
responses = [
|
121 |
-
"Thank you for using our medical chatbot. Please provide the symptoms you're experiencing, and I'll do my best to predict the possible disease.",
|
122 |
-
"Hello! I'm here to help you with medical predictions based on your symptoms. Please describe your symptoms in as much detail as possible.",
|
123 |
-
"Greetings! I am a specialized medical chatbot trained to predict potential diseases based on the symptoms you provide. Kindly list your symptoms explicitly.",
|
124 |
-
"Welcome to the medical chatbot. To assist you accurately, please share your symptoms in explicit detail.",
|
125 |
-
"Hi there! I'm a medical chatbot specialized in analyzing symptoms to suggest possible diseases. Please provide your symptoms explicitly.",
|
126 |
-
"Hey! I'm your medical chatbot. Describe your symptoms with as much detail as you can, and I'll generate potential disease predictions.",
|
127 |
-
"How can I assist you today? I'm a medical chatbot trained to predict diseases based on symptoms. Please be explicit while describing your symptoms.",
|
128 |
-
"Hello! I'm a medical chatbot capable of predicting diseases based on the symptoms you provide. Your explicit symptom description will help me assist you better.",
|
129 |
-
"Greetings! I'm here to help with medical predictions. Describe your symptoms explicitly, and I'll offer insights into potential diseases.",
|
130 |
-
"Hi, I'm the medical chatbot. I've been trained to predict diseases from symptoms. The more explicit you are about your symptoms, the better I can assist you.",
|
131 |
-
"Hi, I specialize in medical predictions based on symptoms. Kindly provide detailed symptoms for accurate disease predictions.",
|
132 |
-
"Hello! I'm a medical chatbot with expertise in predicting diseases from symptoms. Please describe your symptoms explicitly to receive accurate insights.",
|
133 |
-
]
|
134 |
-
# Random goodbyes
|
135 |
-
goodbyes = [
|
136 |
-
"farewell!",'bye', 'goodbye','good-bye', 'good bye', 'bye', 'thank you', 'later', "take care!",
|
137 |
-
"see you later!", 'see you', 'see ya', 'see-you', 'thanks', 'thank', 'bye bye', 'byebye'
|
138 |
-
"catch you on the flip side!", "adios!",
|
139 |
-
"goodbye for now!", "till we meet again!",
|
140 |
-
"so long!", "hasta la vista!",
|
141 |
-
"bye-bye!", "keep in touch!",
|
142 |
-
"toodles!", "ciao!",
|
143 |
-
"later, gator!", "stay safe and goodbye!",
|
144 |
-
"peace out!", "until next time!", "off I go!",
|
145 |
-
]
|
146 |
-
# Random Goodbyes responses
|
147 |
-
goodbye_replies = [
|
148 |
-
"Take care of yourself! If you have more questions, don't hesitate to reach out.",
|
149 |
-
"Stay well! Remember, I'm here if you need further medical advice.",
|
150 |
-
"Goodbye for now! Don't hesitate to return if you need more information in the future.",
|
151 |
-
"Wishing you good health ahead! Feel free to come back if you have more concerns.",
|
152 |
-
"Farewell! If you have more symptoms or questions, don't hesitate to consult again.",
|
153 |
-
"Take care and stay informed about your health. Feel free to chat anytime.",
|
154 |
-
"Bye for now! Remember, your well-being is a priority. Don't hesitate to ask if needed.",
|
155 |
-
"Have a great day ahead! If you need medical guidance later on, I'll be here.",
|
156 |
-
"Stay well and take it easy! Reach out if you need more medical insights.",
|
157 |
-
"Until next time! Prioritize your health and reach out if you need assistance.",
|
158 |
-
"Goodbye! Your health matters. Feel free to return if you have more health-related queries.",
|
159 |
-
"Stay healthy and stay curious about your health! If you need more info, just ask.",
|
160 |
-
"Wishing you wellness on your journey! If you have more questions, I'm here to help.",
|
161 |
-
"Take care and remember, your health is important. Don't hesitate to reach out if needed.",
|
162 |
-
"Goodbye for now! Stay informed and feel free to consult if you require medical advice.",
|
163 |
-
"Stay well and stay proactive about your health! If you have more queries, feel free to ask.",
|
164 |
-
"Farewell! Remember, I'm here whenever you need reliable medical information.",
|
165 |
-
"Bye for now! Stay vigilant about your health and don't hesitate to return if necessary.",
|
166 |
-
"Take care and keep your well-being a priority! Reach out if you have more health questions.",
|
167 |
-
"Wishing you good health ahead! Don't hesitate to chat if you need medical insights.",
|
168 |
-
"Goodbye! Stay well and remember, I'm here to assist you with medical queries.",
|
169 |
-
]
|
170 |
-
|
171 |
# Create couple of if-else statements to capture/mimick peoples's Interaction
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
bot_message= random.choice(goodbye_replies)
|
177 |
-
else:
|
178 |
-
#bot_message= random.choice(goodbye_replies)
|
179 |
-
|
180 |
-
transform_text= vectorizer.transform([message])
|
181 |
-
transform_text= torch.tensor(transform_text.toarray()).to(torch.float32)
|
182 |
-
model.eval()
|
183 |
-
with torch.inference_mode():
|
184 |
-
y_logits=model(transform_text)
|
185 |
-
pred_prob= torch.argmax(torch.softmax(y_logits, dim=1), dim=1)
|
186 |
-
|
187 |
-
test_pred= class_names[pred_prob.item()]
|
188 |
-
bot_message = f' Based on your symptoms, I believe you are having {test_pred} and I would advice you {disease_advice[test_pred]}'
|
189 |
chat_history.append((message, bot_message))
|
190 |
time.sleep(2)
|
191 |
return "", chat_history
|
|
|
3 |
import os
|
4 |
import torch
|
5 |
import random
|
6 |
+
#import nltk_u
|
7 |
import pandas as pd
|
8 |
from sklearn.model_selection import train_test_split
|
9 |
import time
|
10 |
|
11 |
+
#from model import RNN_model
|
12 |
from timeit import default_timer as timer
|
13 |
from typing import Tuple, Dict
|
14 |
|
15 |
+
################################################################################
|
16 |
+
import argparse
|
17 |
+
import numpy as np
|
18 |
+
import pprint
|
19 |
+
import os
|
20 |
+
import copy
|
21 |
+
from str2bool import str2bool
|
22 |
+
from typing import Dict, Sequence
|
23 |
+
from sentence_transformers import SentenceTransformer
|
24 |
+
|
25 |
+
import torch
|
26 |
+
import json
|
27 |
+
|
28 |
+
import transformers
|
29 |
+
from modeling_phi import PhiForCausalLM
|
30 |
+
from tokenization_codegen import CodeGenTokenizer
|
31 |
+
################################################################################
|
32 |
+
|
33 |
+
parser = argparse.ArgumentParser()
|
34 |
+
#############################################################################################################################
|
35 |
+
|
36 |
+
parser.add_argument('--device_id', type=str, default="0")
|
37 |
+
parser.add_argument('--model', type=str, default="microsoft/phi-2", help="") ## /phi-1.5
|
38 |
+
parser.add_argument('--embedder', type=str, default="BAAI/bge-small-en-v1.5") ## /bge-small-en-v1.5 # bge-m3
|
39 |
+
parser.add_argument('--output_path', type=str, default="/home/henry/Desktop/HKU-DASC7606-A2/Outputs/ARC-Challenge-test", help="") ## -bge-m3
|
40 |
+
parser.add_argument('--start_index', type=int, default=0, help="")
|
41 |
+
parser.add_argument('--end_index', type=int, default=9999, help="")
|
42 |
+
parser.add_argument('--N', type=int, default=8, help="")
|
43 |
+
parser.add_argument('--max_len', type=int, default=1024, help="")
|
44 |
+
parser.add_argument('--prompt_type', type=str, default="v2.0", help="")
|
45 |
+
parser.add_argument('--top_k', type=str2bool, default=True, help="")
|
46 |
+
#############################################################################################################################
|
47 |
+
args = parser.parse_args()
|
48 |
+
|
49 |
+
if torch.cuda.is_available():
|
50 |
+
device = "cuda"
|
51 |
+
print(f'################################################################# device: {device}#################################################################')
|
52 |
+
else:
|
53 |
+
device = "cpu"
|
54 |
+
|
55 |
+
def get_model(base_model: str = "bigcode/starcoder",):
|
56 |
+
tokenizer = CodeGenTokenizer.from_pretrained(base_model)
|
57 |
+
tokenizer.pad_token_id = tokenizer.eos_token_id
|
58 |
+
tokenizer.pad_token = tokenizer.eos_token
|
59 |
+
|
60 |
+
model = PhiForCausalLM.from_pretrained(
|
61 |
+
base_model,
|
62 |
+
device_map="auto",
|
63 |
+
)
|
64 |
+
model.config.pad_token_id = tokenizer.pad_token_id
|
65 |
+
|
66 |
+
model.eval()
|
67 |
+
|
68 |
+
return tokenizer, model
|
69 |
+
|
70 |
+
################################################################################
|
71 |
+
|
72 |
+
'''
|
73 |
# Import data
|
74 |
df= pd.read_csv('Symptom2Disease.csv')
|
75 |
df.drop('Unnamed: 0', axis= 1, inplace= True)
|
|
|
77 |
# Preprocess data
|
78 |
df.drop_duplicates(inplace= True)
|
79 |
train_data, test_data= train_test_split(df, test_size=0.15, random_state=42 )
|
80 |
+
'''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
howto= """Welcome to the <b>Medical Chatbot</b>, powered by Gradio.
|
83 |
Currently, the chatbot can WELCOME YOU, PREDICT DISEASE based on your symptoms and SUGGEST POSSIBLE SOLUTIONS AND RECOMENDATIONS, and BID YOU FAREWELL.
|
|
|
85 |
The bot will respond based on the best possible answers to your messages.
|
86 |
|
87 |
"""
|
|
|
|
|
88 |
# Create the gradio demo
|
89 |
with gr.Blocks(css = """#col_container { margin-left: auto; margin-right: auto;} #chatbot {height: 520px; overflow: auto;}""") as demo:
|
90 |
gr.HTML('<h1 align="center">Medical Chatbot: ARIN 7102')
|
|
|
96 |
clear = gr.ClearButton([msg, chatbot])
|
97 |
|
98 |
def respond(message, chat_history):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
# Create couple of if-else statements to capture/mimick peoples's Interaction
|
100 |
+
embedder = SentenceTransformer(args.embedder, device=device)
|
101 |
+
tokenizer, model = get_model(base_model=args.model)
|
102 |
+
message_embeddings = embedder.encode(message)
|
103 |
+
bot_message = model(message_embeddings)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
chat_history.append((message, bot_message))
|
105 |
time.sleep(2)
|
106 |
return "", chat_history
|
cache_utils.py
ADDED
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, List, Optional, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
class Cache:
|
7 |
+
"""
|
8 |
+
Base, abstract class for all caches. The actual data structure is specific to each subclass.
|
9 |
+
"""
|
10 |
+
|
11 |
+
def update(
|
12 |
+
self,
|
13 |
+
key_states: torch.Tensor,
|
14 |
+
value_states: torch.Tensor,
|
15 |
+
layer_idx: int,
|
16 |
+
cache_kwargs: Optional[Dict[str, Any]] = None,
|
17 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
18 |
+
"""
|
19 |
+
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
|
20 |
+
|
21 |
+
Parameters:
|
22 |
+
key_states (`torch.Tensor`):
|
23 |
+
The new key states to cache.
|
24 |
+
value_states (`torch.Tensor`):
|
25 |
+
The new value states to cache.
|
26 |
+
layer_idx (`int`):
|
27 |
+
The index of the layer to cache the states for.
|
28 |
+
cache_kwargs (`Dict[str, Any]`, `optional`):
|
29 |
+
Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
|
30 |
+
cache to be created.
|
31 |
+
|
32 |
+
Return:
|
33 |
+
A tuple containing the updated key and value states.
|
34 |
+
"""
|
35 |
+
raise NotImplementedError("Make sure to implement `update` in a subclass.")
|
36 |
+
|
37 |
+
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
38 |
+
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
39 |
+
raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
|
40 |
+
|
41 |
+
def get_max_length(self) -> Optional[int]:
|
42 |
+
"""Returns the maximum sequence length of the cached states, if there is any."""
|
43 |
+
raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")
|
44 |
+
|
45 |
+
def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
|
46 |
+
"""Given the sequence length of the new inputs, returns the usable length of the cache."""
|
47 |
+
# Cache without size limit -> all cache is usable
|
48 |
+
# Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
|
49 |
+
# length, we will need to evict part of the cache (and thus not all cache is usable)
|
50 |
+
max_length = self.get_max_length()
|
51 |
+
previous_seq_length = self.get_seq_length(layer_idx)
|
52 |
+
if max_length is not None and previous_seq_length + new_seq_length > max_length:
|
53 |
+
return max_length - new_seq_length
|
54 |
+
return previous_seq_length
|
55 |
+
|
56 |
+
|
57 |
+
class DynamicCache(Cache):
|
58 |
+
"""
|
59 |
+
A cache that grows dynamically as more tokens are generated. This is the default for generative models.
|
60 |
+
|
61 |
+
It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
|
62 |
+
`[batch_size, num_heads, seq_len, head_dim]`.
|
63 |
+
"""
|
64 |
+
|
65 |
+
def __init__(self) -> None:
|
66 |
+
self.key_cache: List[torch.Tensor] = []
|
67 |
+
self.value_cache: List[torch.Tensor] = []
|
68 |
+
self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
|
69 |
+
|
70 |
+
def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
|
71 |
+
"""
|
72 |
+
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
|
73 |
+
sequence length.
|
74 |
+
"""
|
75 |
+
if layer_idx < len(self):
|
76 |
+
return (self.key_cache[layer_idx], self.value_cache[layer_idx])
|
77 |
+
else:
|
78 |
+
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
|
79 |
+
|
80 |
+
def __iter__(self):
|
81 |
+
"""
|
82 |
+
Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
|
83 |
+
keys and values
|
84 |
+
"""
|
85 |
+
for layer_idx in range(len(self)):
|
86 |
+
yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
|
87 |
+
|
88 |
+
def __len__(self):
|
89 |
+
"""
|
90 |
+
Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
|
91 |
+
to the number of layers in the model.
|
92 |
+
"""
|
93 |
+
return len(self.key_cache)
|
94 |
+
|
95 |
+
def update(
|
96 |
+
self,
|
97 |
+
key_states: torch.Tensor,
|
98 |
+
value_states: torch.Tensor,
|
99 |
+
layer_idx: int,
|
100 |
+
cache_kwargs: Optional[Dict[str, Any]] = None,
|
101 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
102 |
+
"""
|
103 |
+
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
|
104 |
+
|
105 |
+
Parameters:
|
106 |
+
key_states (`torch.Tensor`):
|
107 |
+
The new key states to cache.
|
108 |
+
value_states (`torch.Tensor`):
|
109 |
+
The new value states to cache.
|
110 |
+
layer_idx (`int`):
|
111 |
+
The index of the layer to cache the states for.
|
112 |
+
cache_kwargs (`Dict[str, Any]`, `optional`):
|
113 |
+
Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
|
114 |
+
|
115 |
+
Return:
|
116 |
+
A tuple containing the updated key and value states.
|
117 |
+
"""
|
118 |
+
# Update the number of seen tokens
|
119 |
+
if layer_idx == 0:
|
120 |
+
self.seen_tokens += key_states.shape[-2]
|
121 |
+
|
122 |
+
# Update the cache
|
123 |
+
if len(self.key_cache) <= layer_idx:
|
124 |
+
self.key_cache.append(key_states)
|
125 |
+
self.value_cache.append(value_states)
|
126 |
+
else:
|
127 |
+
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
|
128 |
+
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
|
129 |
+
|
130 |
+
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
131 |
+
|
132 |
+
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
133 |
+
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
134 |
+
if len(self.key_cache) <= layer_idx:
|
135 |
+
return 0
|
136 |
+
return self.key_cache[layer_idx].shape[-2]
|
137 |
+
|
138 |
+
def get_max_length(self) -> Optional[int]:
|
139 |
+
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
|
140 |
+
return None
|
141 |
+
|
142 |
+
def reorder_cache(self, beam_idx: torch.LongTensor):
|
143 |
+
"""Reorders the cache for beam search, given the selected beam indices."""
|
144 |
+
for layer_idx in range(len(self.key_cache)):
|
145 |
+
device = self.key_cache[layer_idx].device
|
146 |
+
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
|
147 |
+
device = self.value_cache[layer_idx].device
|
148 |
+
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
|
149 |
+
|
150 |
+
def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
|
151 |
+
"""Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
|
152 |
+
legacy_cache = ()
|
153 |
+
for layer_idx in range(len(self)):
|
154 |
+
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
|
155 |
+
return legacy_cache
|
156 |
+
|
157 |
+
@classmethod
|
158 |
+
def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
|
159 |
+
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
|
160 |
+
cache = cls()
|
161 |
+
if past_key_values is not None:
|
162 |
+
for layer_idx in range(len(past_key_values)):
|
163 |
+
key_states, value_states = past_key_values[layer_idx]
|
164 |
+
cache.update(key_states, value_states, layer_idx)
|
165 |
+
return cache
|
166 |
+
|
167 |
+
|
168 |
+
class SinkCache(Cache):
|
169 |
+
"""
|
170 |
+
A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
|
171 |
+
generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
|
172 |
+
tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.
|
173 |
+
|
174 |
+
It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
|
175 |
+
`[batch_size, num_heads, seq_len, head_dim]`.
|
176 |
+
|
177 |
+
Parameters:
|
178 |
+
window_length (`int`):
|
179 |
+
The length of the context window.
|
180 |
+
num_sink_tokens (`int`):
|
181 |
+
The number of sink tokens. See the original paper for more information.
|
182 |
+
"""
|
183 |
+
|
184 |
+
def __init__(self, window_length: int, num_sink_tokens: int) -> None:
|
185 |
+
self.key_cache: List[torch.Tensor] = []
|
186 |
+
self.value_cache: List[torch.Tensor] = []
|
187 |
+
self.window_length = window_length
|
188 |
+
self.num_sink_tokens = num_sink_tokens
|
189 |
+
self.cos_sin_cache = {}
|
190 |
+
self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
|
191 |
+
|
192 |
+
@staticmethod
|
193 |
+
def _rotate_half(x):
|
194 |
+
x1 = x[..., : x.shape[-1] // 2]
|
195 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
196 |
+
return torch.cat((-x2, x1), dim=-1)
|
197 |
+
|
198 |
+
def _apply_key_rotary_pos_emb(
|
199 |
+
self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
200 |
+
) -> torch.Tensor:
|
201 |
+
rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
|
202 |
+
return rotated_key_states
|
203 |
+
|
204 |
+
def _get_rerotation_cos_sin(
|
205 |
+
self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
206 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
207 |
+
if key_states.shape[-2] not in self.cos_sin_cache:
|
208 |
+
# Upcast to float32 temporarily for better accuracy
|
209 |
+
cos = cos.to(torch.float32)
|
210 |
+
sin = sin.to(torch.float32)
|
211 |
+
|
212 |
+
# Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
|
213 |
+
original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
|
214 |
+
shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
|
215 |
+
original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
|
216 |
+
shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
|
217 |
+
rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
|
218 |
+
rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
|
219 |
+
|
220 |
+
self.cos_sin_cache[key_states.shape[-2]] = (
|
221 |
+
rerotation_cos.to(key_states.dtype).unsqueeze(0),
|
222 |
+
rerotation_sin.to(key_states.dtype).unsqueeze(0),
|
223 |
+
)
|
224 |
+
return self.cos_sin_cache[key_states.shape[-2]]
|
225 |
+
|
226 |
+
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
227 |
+
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
228 |
+
# Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
|
229 |
+
if len(self.key_cache) <= layer_idx:
|
230 |
+
return 0
|
231 |
+
return self.key_cache[layer_idx].shape[-2]
|
232 |
+
|
233 |
+
def get_max_length(self) -> Optional[int]:
|
234 |
+
"""Returns the maximum sequence length of the cached states."""
|
235 |
+
return self.window_length
|
236 |
+
|
237 |
+
def update(
|
238 |
+
self,
|
239 |
+
key_states: torch.Tensor,
|
240 |
+
value_states: torch.Tensor,
|
241 |
+
layer_idx: int,
|
242 |
+
cache_kwargs: Optional[Dict[str, Any]] = None,
|
243 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
244 |
+
"""
|
245 |
+
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
|
246 |
+
|
247 |
+
Parameters:
|
248 |
+
key_states (`torch.Tensor`):
|
249 |
+
The new key states to cache.
|
250 |
+
value_states (`torch.Tensor`):
|
251 |
+
The new value states to cache.
|
252 |
+
layer_idx (`int`):
|
253 |
+
The index of the layer to cache the states for.
|
254 |
+
cache_kwargs (`Dict[str, Any]`, `optional`):
|
255 |
+
Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
|
256 |
+
`cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
|
257 |
+
rotation as the tokens are shifted.
|
258 |
+
|
259 |
+
Return:
|
260 |
+
A tuple containing the updated key and value states.
|
261 |
+
"""
|
262 |
+
# Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
|
263 |
+
# with partially rotated position embeddings, like Phi or Persimmon.
|
264 |
+
sin = cache_kwargs.get("sin")
|
265 |
+
cos = cache_kwargs.get("cos")
|
266 |
+
partial_rotation_size = cache_kwargs.get("partial_rotation_size")
|
267 |
+
using_rope = cos is not None and sin is not None
|
268 |
+
|
269 |
+
# Update the number of seen tokens
|
270 |
+
if layer_idx == 0:
|
271 |
+
self.seen_tokens += key_states.shape[-2]
|
272 |
+
|
273 |
+
# [bsz, num_heads, seq_len, head_dim]
|
274 |
+
if len(self.key_cache) <= layer_idx:
|
275 |
+
# Empty cache
|
276 |
+
self.key_cache.append(key_states)
|
277 |
+
self.value_cache.append(value_states)
|
278 |
+
|
279 |
+
elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
|
280 |
+
# Growing cache
|
281 |
+
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
|
282 |
+
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
|
283 |
+
|
284 |
+
else:
|
285 |
+
# Shifting cache
|
286 |
+
keys_to_keep = self.key_cache[layer_idx][
|
287 |
+
:, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
|
288 |
+
]
|
289 |
+
|
290 |
+
# On RoPE models, we need to recompute the Key rotation as the tokens are shifted
|
291 |
+
if using_rope:
|
292 |
+
rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
|
293 |
+
key_states, cos[: self.window_length], sin[: self.window_length]
|
294 |
+
)
|
295 |
+
if partial_rotation_size is not None:
|
296 |
+
keys_to_keep, keys_pass = (
|
297 |
+
keys_to_keep[..., :partial_rotation_size],
|
298 |
+
keys_to_keep[..., partial_rotation_size:],
|
299 |
+
)
|
300 |
+
keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
|
301 |
+
if partial_rotation_size is not None:
|
302 |
+
keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)
|
303 |
+
|
304 |
+
# Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
|
305 |
+
sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
|
306 |
+
self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)
|
307 |
+
|
308 |
+
sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
|
309 |
+
values_to_keep = self.value_cache[layer_idx][
|
310 |
+
:, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
|
311 |
+
]
|
312 |
+
self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)
|
313 |
+
|
314 |
+
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
315 |
+
|
316 |
+
def reorder_cache(self, beam_idx: torch.LongTensor):
|
317 |
+
"""Reorders the cache for beam search, given the selected beam indices."""
|
318 |
+
for layer_idx in range(len(self.key_cache)):
|
319 |
+
device = self.key_cache[layer_idx].device
|
320 |
+
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
|
321 |
+
device = self.value_cache[layer_idx].device
|
322 |
+
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
|
configuration_phi.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
""" Phi model configuration"""
|
17 |
+
|
18 |
+
|
19 |
+
from transformers.configuration_utils import PretrainedConfig
|
20 |
+
from transformers.utils import logging
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.get_logger(__name__)
|
24 |
+
|
25 |
+
PHI_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
26 |
+
"microsoft/phi-1": "https://huggingface.co/microsoft/phi-1/resolve/main/config.json",
|
27 |
+
"microsoft/phi-1_5": "https://huggingface.co/microsoft/phi-1_5/resolve/main/config.json",
|
28 |
+
"microsoft/phi-2": "https://huggingface.co/microsoft/phi-2/resolve/main/config.json",
|
29 |
+
}
|
30 |
+
|
31 |
+
|
32 |
+
class PhiConfig(PretrainedConfig):
|
33 |
+
r"""
|
34 |
+
This is the configuration class to store the configuration of a [`PhiModel`]. It is used to instantiate an Phi
|
35 |
+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
36 |
+
defaults will yield a similar configuration to that of the Phi
|
37 |
+
[microsoft/phi-1](https://huggingface.co/microsoft/phi-1).
|
38 |
+
|
39 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
40 |
+
documentation from [`PretrainedConfig`] for more information.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
vocab_size (`int`, *optional*, defaults to 51200):
|
44 |
+
Vocabulary size of the Phi model. Defines the number of different tokens that can be represented by the
|
45 |
+
`inputs_ids` passed when calling [`PhiModel`].
|
46 |
+
hidden_size (`int`, *optional*, defaults to 2048):
|
47 |
+
Dimension of the hidden representations.
|
48 |
+
intermediate_size (`int`, *optional*, defaults to 8192):
|
49 |
+
Dimension of the MLP representations.
|
50 |
+
num_hidden_layers (`int`, *optional*, defaults to 24):
|
51 |
+
Number of hidden layers in the Transformer decoder.
|
52 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
53 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
54 |
+
num_key_value_heads (`int`, *optional*):
|
55 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
56 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
57 |
+
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
58 |
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
59 |
+
by meanpooling all the original heads within that group. For more details checkout [this
|
60 |
+
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
61 |
+
`num_attention_heads`.
|
62 |
+
resid_pdrop (`float`, *optional*, defaults to 0.0):
|
63 |
+
Dropout probability for mlp outputs.
|
64 |
+
embd_pdrop (`int`, *optional*, defaults to 0.0):
|
65 |
+
The dropout ratio for the embeddings.
|
66 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
67 |
+
The dropout ratio after computing the attention scores.
|
68 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_new"`):
|
69 |
+
The non-linear activation function (function or string) in the decoder.
|
70 |
+
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
71 |
+
The maximum sequence length that this model might ever be used with. Phi-1 and Phi-1.5 supports up to 2048
|
72 |
+
tokens.
|
73 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
74 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
75 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
|
76 |
+
The epsilon used by the rms normalization layers.
|
77 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
78 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
79 |
+
relevant if `config.is_decoder=True`. Whether to tie weight embeddings or not.
|
80 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
81 |
+
Whether to tie weight embeddings
|
82 |
+
rope_theta (`float`, *optional*, defaults to 10000.0):
|
83 |
+
The base period of the RoPE embeddings.
|
84 |
+
rope_scaling (`Dict`, *optional*):
|
85 |
+
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
86 |
+
strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format
|
87 |
+
is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
88 |
+
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
89 |
+
these scaling strategies behave:
|
90 |
+
https://www.reddit.com/r/LocalPersimmon/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This
|
91 |
+
is an experimental feature, subject to breaking API changes in future versions.
|
92 |
+
partial_rotary_factor (`float`, *optional*, defaults to 0.5):
|
93 |
+
Percentage of the query and keys which will have rotary embedding.
|
94 |
+
qk_layernorm (`bool`, *optional*, defaults to `False`):
|
95 |
+
Whether or not to normalize the Queries and Keys after projecting the hidden states.
|
96 |
+
bos_token_id (`int`, *optional*, defaults to 1):
|
97 |
+
Denotes beginning of sequences token id.
|
98 |
+
eos_token_id (`int`, *optional*, defaults to 2):
|
99 |
+
Denotes end of sequences token id.
|
100 |
+
|
101 |
+
Example:
|
102 |
+
|
103 |
+
```python
|
104 |
+
>>> from transformers import PhiModel, PhiConfig
|
105 |
+
|
106 |
+
>>> # Initializing a Phi-1 style configuration
|
107 |
+
>>> configuration = PhiConfig.from_pretrained("microsoft/phi-1")
|
108 |
+
|
109 |
+
>>> # Initializing a model from the configuration
|
110 |
+
>>> model = PhiModel(configuration)
|
111 |
+
|
112 |
+
>>> # Accessing the model configuration
|
113 |
+
>>> configuration = model.config
|
114 |
+
```"""
|
115 |
+
|
116 |
+
model_type = "phi"
|
117 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
118 |
+
|
119 |
+
def __init__(
|
120 |
+
self,
|
121 |
+
vocab_size=51200,
|
122 |
+
hidden_size=2048,
|
123 |
+
intermediate_size=8192,
|
124 |
+
num_hidden_layers=24,
|
125 |
+
num_attention_heads=32,
|
126 |
+
num_key_value_heads=None,
|
127 |
+
resid_pdrop=0.0,
|
128 |
+
embd_pdrop=0.0,
|
129 |
+
attention_dropout=0.0,
|
130 |
+
hidden_act="gelu_new",
|
131 |
+
max_position_embeddings=2048,
|
132 |
+
initializer_range=0.02,
|
133 |
+
layer_norm_eps=1e-5,
|
134 |
+
use_cache=True,
|
135 |
+
tie_word_embeddings=False,
|
136 |
+
rope_theta=10000.0,
|
137 |
+
rope_scaling=None,
|
138 |
+
partial_rotary_factor=0.5,
|
139 |
+
qk_layernorm=False,
|
140 |
+
bos_token_id=1,
|
141 |
+
eos_token_id=2,
|
142 |
+
**kwargs,
|
143 |
+
):
|
144 |
+
self.vocab_size = vocab_size
|
145 |
+
self.hidden_size = hidden_size
|
146 |
+
self.intermediate_size = intermediate_size
|
147 |
+
self.num_hidden_layers = num_hidden_layers
|
148 |
+
self.num_attention_heads = num_attention_heads
|
149 |
+
|
150 |
+
if num_key_value_heads is None:
|
151 |
+
num_key_value_heads = num_attention_heads
|
152 |
+
|
153 |
+
self.num_key_value_heads = num_key_value_heads
|
154 |
+
self.resid_pdrop = resid_pdrop
|
155 |
+
self.embd_pdrop = embd_pdrop
|
156 |
+
self.attention_dropout = attention_dropout
|
157 |
+
self.hidden_act = hidden_act
|
158 |
+
self.max_position_embeddings = max_position_embeddings
|
159 |
+
self.initializer_range = initializer_range
|
160 |
+
self.layer_norm_eps = layer_norm_eps
|
161 |
+
self.use_cache = use_cache
|
162 |
+
self.rope_theta = rope_theta
|
163 |
+
self.rope_scaling = rope_scaling
|
164 |
+
self.partial_rotary_factor = partial_rotary_factor
|
165 |
+
self.qk_layernorm = qk_layernorm
|
166 |
+
self._rope_scaling_validation()
|
167 |
+
|
168 |
+
super().__init__(
|
169 |
+
bos_token_id=bos_token_id,
|
170 |
+
eos_token_id=eos_token_id,
|
171 |
+
tie_word_embeddings=tie_word_embeddings,
|
172 |
+
**kwargs,
|
173 |
+
)
|
174 |
+
|
175 |
+
# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
|
176 |
+
def _rope_scaling_validation(self):
|
177 |
+
"""
|
178 |
+
Validate the `rope_scaling` configuration.
|
179 |
+
"""
|
180 |
+
if self.rope_scaling is None:
|
181 |
+
return
|
182 |
+
|
183 |
+
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
184 |
+
raise ValueError(
|
185 |
+
"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
|
186 |
+
f"got {self.rope_scaling}"
|
187 |
+
)
|
188 |
+
rope_scaling_type = self.rope_scaling.get("type", None)
|
189 |
+
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
190 |
+
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
191 |
+
raise ValueError(
|
192 |
+
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
193 |
+
)
|
194 |
+
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
195 |
+
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
|
modeling_attn_mask_utils.py
ADDED
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import List, Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
|
19 |
+
|
20 |
+
@dataclass
|
21 |
+
class AttentionMaskConverter:
|
22 |
+
"""
|
23 |
+
A utility attention mask class that allows one to:
|
24 |
+
- Create a causal 4d mask
|
25 |
+
- Create a causal 4d mask with slided window
|
26 |
+
- Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
|
27 |
+
key_value_length) that can be multiplied with attention scores
|
28 |
+
|
29 |
+
Examples:
|
30 |
+
|
31 |
+
```python
|
32 |
+
>>> import torch
|
33 |
+
>>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
34 |
+
|
35 |
+
>>> converter = AttentionMaskConverter(True)
|
36 |
+
>>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
|
37 |
+
tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
|
38 |
+
[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
|
39 |
+
[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
|
40 |
+
[-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38],
|
41 |
+
[-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]])
|
42 |
+
```
|
43 |
+
|
44 |
+
Parameters:
|
45 |
+
is_causal (`bool`):
|
46 |
+
Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
|
47 |
+
|
48 |
+
sliding_window (`int`, *optional*):
|
49 |
+
Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
|
50 |
+
"""
|
51 |
+
|
52 |
+
is_causal: bool
|
53 |
+
sliding_window: int
|
54 |
+
|
55 |
+
def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
|
56 |
+
self.is_causal = is_causal
|
57 |
+
self.sliding_window = sliding_window
|
58 |
+
|
59 |
+
if self.sliding_window is not None and self.sliding_window <= 0:
|
60 |
+
raise ValueError(
|
61 |
+
f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
|
62 |
+
)
|
63 |
+
|
64 |
+
def to_causal_4d(
|
65 |
+
self,
|
66 |
+
batch_size: int,
|
67 |
+
query_length: int,
|
68 |
+
key_value_length: int,
|
69 |
+
dtype: torch.dtype,
|
70 |
+
device: Union[torch.device, "str"] = "cpu",
|
71 |
+
) -> Optional[torch.Tensor]:
|
72 |
+
"""
|
73 |
+
Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
|
74 |
+
bias to upper right hand triangular matrix (causal mask).
|
75 |
+
"""
|
76 |
+
if not self.is_causal:
|
77 |
+
raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
|
78 |
+
|
79 |
+
# If shape is not cached, create a new causal mask and cache it
|
80 |
+
input_shape = (batch_size, query_length)
|
81 |
+
past_key_values_length = key_value_length - query_length
|
82 |
+
|
83 |
+
# create causal mask
|
84 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
85 |
+
causal_4d_mask = None
|
86 |
+
if input_shape[-1] > 1 or self.sliding_window is not None:
|
87 |
+
causal_4d_mask = self._make_causal_mask(
|
88 |
+
input_shape,
|
89 |
+
dtype,
|
90 |
+
device=device,
|
91 |
+
past_key_values_length=past_key_values_length,
|
92 |
+
sliding_window=self.sliding_window,
|
93 |
+
)
|
94 |
+
|
95 |
+
return causal_4d_mask
|
96 |
+
|
97 |
+
def to_4d(
|
98 |
+
self,
|
99 |
+
attention_mask_2d: torch.Tensor,
|
100 |
+
query_length: int,
|
101 |
+
dtype: torch.dtype,
|
102 |
+
key_value_length: Optional[int] = None,
|
103 |
+
) -> torch.Tensor:
|
104 |
+
"""
|
105 |
+
Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
|
106 |
+
key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
|
107 |
+
causal, a causal mask will be added.
|
108 |
+
"""
|
109 |
+
input_shape = (attention_mask_2d.shape[0], query_length)
|
110 |
+
|
111 |
+
# create causal mask
|
112 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
113 |
+
causal_4d_mask = None
|
114 |
+
if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
|
115 |
+
if key_value_length is None:
|
116 |
+
raise ValueError(
|
117 |
+
"This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
|
118 |
+
)
|
119 |
+
|
120 |
+
past_key_values_length = key_value_length - query_length
|
121 |
+
causal_4d_mask = self._make_causal_mask(
|
122 |
+
input_shape,
|
123 |
+
dtype,
|
124 |
+
device=attention_mask_2d.device,
|
125 |
+
past_key_values_length=past_key_values_length,
|
126 |
+
sliding_window=self.sliding_window,
|
127 |
+
)
|
128 |
+
elif self.sliding_window is not None:
|
129 |
+
raise NotImplementedError("Sliding window is currently only implemented for causal masking")
|
130 |
+
|
131 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
132 |
+
expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
|
133 |
+
attention_mask_2d.device
|
134 |
+
)
|
135 |
+
|
136 |
+
if causal_4d_mask is not None:
|
137 |
+
expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
|
138 |
+
|
139 |
+
# expanded_attn_mask + causal_4d_mask can cause some overflow
|
140 |
+
expanded_4d_mask = expanded_attn_mask
|
141 |
+
|
142 |
+
return expanded_4d_mask
|
143 |
+
|
144 |
+
@staticmethod
|
145 |
+
def _make_causal_mask(
|
146 |
+
input_ids_shape: torch.Size,
|
147 |
+
dtype: torch.dtype,
|
148 |
+
device: torch.device,
|
149 |
+
past_key_values_length: int = 0,
|
150 |
+
sliding_window: Optional[int] = None,
|
151 |
+
):
|
152 |
+
"""
|
153 |
+
Make causal mask used for bi-directional self-attention.
|
154 |
+
"""
|
155 |
+
bsz, tgt_len = input_ids_shape
|
156 |
+
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
|
157 |
+
mask_cond = torch.arange(mask.size(-1), device=device)
|
158 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
159 |
+
|
160 |
+
mask = mask.to(dtype)
|
161 |
+
|
162 |
+
if past_key_values_length > 0:
|
163 |
+
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
|
164 |
+
|
165 |
+
# add lower triangular sliding window mask if necessary
|
166 |
+
if sliding_window is not None:
|
167 |
+
diagonal = past_key_values_length - sliding_window + 1
|
168 |
+
|
169 |
+
context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
|
170 |
+
mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)
|
171 |
+
|
172 |
+
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
173 |
+
|
174 |
+
@staticmethod
|
175 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
176 |
+
"""
|
177 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
178 |
+
"""
|
179 |
+
bsz, src_len = mask.size()
|
180 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
181 |
+
|
182 |
+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
183 |
+
|
184 |
+
inverted_mask = 1.0 - expanded_mask
|
185 |
+
|
186 |
+
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
187 |
+
|
188 |
+
@staticmethod
|
189 |
+
def _unmask_unattended(
|
190 |
+
expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float]
|
191 |
+
):
|
192 |
+
# fmt: off
|
193 |
+
"""
|
194 |
+
Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
|
195 |
+
using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
196 |
+
Details: https://github.com/pytorch/pytorch/issues/110213
|
197 |
+
|
198 |
+
`expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
|
199 |
+
`attention_mask` is [bsz, src_seq_len].
|
200 |
+
|
201 |
+
The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.
|
202 |
+
|
203 |
+
For example, if `attention_mask` is
|
204 |
+
```
|
205 |
+
[[0, 0, 1],
|
206 |
+
[1, 1, 1],
|
207 |
+
[0, 1, 1]]
|
208 |
+
```
|
209 |
+
and `expanded_mask` is (e.g. here left-padding case)
|
210 |
+
```
|
211 |
+
[[[[0, 0, 0],
|
212 |
+
[0, 0, 0],
|
213 |
+
[0, 0, 1]]],
|
214 |
+
[[[1, 0, 0],
|
215 |
+
[1, 1, 0],
|
216 |
+
[1, 1, 1]]],
|
217 |
+
[[[0, 0, 0],
|
218 |
+
[0, 1, 0],
|
219 |
+
[0, 1, 1]]]]
|
220 |
+
```
|
221 |
+
then the modified `expanded_mask` will be
|
222 |
+
```
|
223 |
+
[[[[1, 1, 1], <-- modified
|
224 |
+
[1, 1, 1], <-- modified
|
225 |
+
[0, 0, 1]]],
|
226 |
+
[[[1, 0, 0],
|
227 |
+
[1, 1, 0],
|
228 |
+
[1, 1, 1]]],
|
229 |
+
[[[1, 1, 1], <-- modified
|
230 |
+
[0, 1, 0],
|
231 |
+
[0, 1, 1]]]]
|
232 |
+
```
|
233 |
+
"""
|
234 |
+
# fmt: on
|
235 |
+
|
236 |
+
# Get the index of the first non-zero value for every sample in the batch.
|
237 |
+
# In the above example, indices = [[2], [0], [1]]]
|
238 |
+
tmp = torch.arange(attention_mask.shape[1], 0, -1)
|
239 |
+
indices = torch.argmax(attention_mask.cpu() * tmp, 1, keepdim=True)
|
240 |
+
|
241 |
+
# Find the batch indexes that have unattended tokens on the leftmost side (e.g. [0, 0, 1, 1, 1]), for which the first rows of the
|
242 |
+
# expanded mask will be completely unattended.
|
243 |
+
left_masked_rows = torch.where(indices > 0)[0]
|
244 |
+
|
245 |
+
if left_masked_rows.shape[0] == 0:
|
246 |
+
return expanded_mask
|
247 |
+
indices = indices[left_masked_rows]
|
248 |
+
|
249 |
+
max_len = torch.max(indices)
|
250 |
+
range_tensor = torch.arange(max_len).unsqueeze(0)
|
251 |
+
range_tensor = range_tensor.repeat(indices.size(0), 1)
|
252 |
+
|
253 |
+
# Avoid unmasking tokens at relevant target positions (on the row axis), by rather unmasking possibly several times the first row that should always be unmasked as we filtered out the batch above.
|
254 |
+
range_tensor[range_tensor >= indices] = 0
|
255 |
+
|
256 |
+
# TODO: we may drop support for 3D attention mask as the refactor from Patrick maybe dropped this case
|
257 |
+
if expanded_mask.dim() == 4:
|
258 |
+
num_masks = expanded_mask.shape[1]
|
259 |
+
if num_masks == 1:
|
260 |
+
# Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
|
261 |
+
mask_slice = (left_masked_rows[:, None], 0, range_tensor)
|
262 |
+
else:
|
263 |
+
# Broadcast [left_masked_rows, 1, 1], [1, num_masks, 1], [left_masked_rows, 1, max_len]
|
264 |
+
mask_slice = (
|
265 |
+
left_masked_rows[:, None, None],
|
266 |
+
torch.arange(num_masks)[None, :, None],
|
267 |
+
range_tensor[:, None, :],
|
268 |
+
)
|
269 |
+
else:
|
270 |
+
# Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
|
271 |
+
mask_slice = (left_masked_rows[:, None], range_tensor)
|
272 |
+
|
273 |
+
expanded_mask[mask_slice] = unmasked_value
|
274 |
+
|
275 |
+
return expanded_mask
|
276 |
+
|
277 |
+
|
278 |
+
def _prepare_4d_causal_attention_mask(
|
279 |
+
attention_mask: Optional[torch.Tensor],
|
280 |
+
input_shape: Union[torch.Size, Tuple, List],
|
281 |
+
inputs_embeds: torch.Tensor,
|
282 |
+
past_key_values_length: int,
|
283 |
+
sliding_window: Optional[int] = None,
|
284 |
+
):
|
285 |
+
"""
|
286 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
287 |
+
`(batch_size, key_value_length)`
|
288 |
+
|
289 |
+
Args:
|
290 |
+
attention_mask (`torch.Tensor` or `None`):
|
291 |
+
A 2D attention mask of shape `(batch_size, key_value_length)`
|
292 |
+
input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
|
293 |
+
The input shape should be a tuple that defines `(batch_size, query_length)`.
|
294 |
+
inputs_embeds (`torch.Tensor`):
|
295 |
+
The embedded inputs as a torch Tensor.
|
296 |
+
past_key_values_length (`int`):
|
297 |
+
The length of the key value cache.
|
298 |
+
sliding_window (`int`, *optional*):
|
299 |
+
If the model uses windowed attention, a sliding window should be passed.
|
300 |
+
"""
|
301 |
+
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
|
302 |
+
|
303 |
+
key_value_length = input_shape[-1] + past_key_values_length
|
304 |
+
|
305 |
+
# 4d mask is passed through the layers
|
306 |
+
if attention_mask is not None and len(attention_mask.shape) == 2:
|
307 |
+
attention_mask = attn_mask_converter.to_4d(
|
308 |
+
attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
|
309 |
+
)
|
310 |
+
elif attention_mask is not None and len(attention_mask.shape) == 4:
|
311 |
+
expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
|
312 |
+
if tuple(attention_mask.shape) != expected_shape:
|
313 |
+
raise ValueError(
|
314 |
+
f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
|
315 |
+
)
|
316 |
+
else:
|
317 |
+
# if the 4D mask has correct shape - invert it and fill with negative infinity
|
318 |
+
inverted_mask = 1.0 - attention_mask
|
319 |
+
attention_mask = inverted_mask.masked_fill(
|
320 |
+
inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
|
321 |
+
)
|
322 |
+
else:
|
323 |
+
attention_mask = attn_mask_converter.to_causal_4d(
|
324 |
+
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
325 |
+
)
|
326 |
+
|
327 |
+
return attention_mask
|
328 |
+
|
329 |
+
|
330 |
+
# Adapted from _prepare_4d_causal_attention_mask
|
331 |
+
def _prepare_4d_causal_attention_mask_for_sdpa(
|
332 |
+
attention_mask: Optional[torch.Tensor],
|
333 |
+
input_shape: Union[torch.Size, Tuple, List],
|
334 |
+
inputs_embeds: torch.Tensor,
|
335 |
+
past_key_values_length: int,
|
336 |
+
sliding_window: Optional[int] = None,
|
337 |
+
):
|
338 |
+
"""
|
339 |
+
Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.
|
340 |
+
|
341 |
+
In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and
|
342 |
+
`key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks,
|
343 |
+
allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
|
344 |
+
"""
|
345 |
+
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
|
346 |
+
|
347 |
+
key_value_length = input_shape[-1] + past_key_values_length
|
348 |
+
batch_size, query_length = input_shape
|
349 |
+
|
350 |
+
# torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
|
351 |
+
# used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
|
352 |
+
# TODO: Fix this as well when using torchdynamo with fullgraph=True.
|
353 |
+
is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy)
|
354 |
+
|
355 |
+
if attention_mask is not None:
|
356 |
+
# 4d mask is passed through
|
357 |
+
if len(attention_mask.shape) == 4:
|
358 |
+
expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
|
359 |
+
if tuple(attention_mask.shape) != expected_shape:
|
360 |
+
raise ValueError(
|
361 |
+
f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
|
362 |
+
)
|
363 |
+
else:
|
364 |
+
# if the 4D mask has correct shape - invert it and fill with negative infinity
|
365 |
+
inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)
|
366 |
+
attention_mask = inverted_mask.masked_fill(
|
367 |
+
inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
|
368 |
+
)
|
369 |
+
return attention_mask
|
370 |
+
|
371 |
+
elif not is_tracing and torch.all(attention_mask == 1):
|
372 |
+
if query_length == 1:
|
373 |
+
# For query_length == 1, causal attention and bi-directional attention are the same.
|
374 |
+
attention_mask = None
|
375 |
+
elif key_value_length == query_length:
|
376 |
+
attention_mask = None
|
377 |
+
else:
|
378 |
+
# Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
|
379 |
+
# may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
|
380 |
+
# Reference: https://github.com/pytorch/pytorch/issues/108108
|
381 |
+
pass
|
382 |
+
elif query_length > 1 and key_value_length != query_length:
|
383 |
+
# See the comment above (https://github.com/pytorch/pytorch/issues/108108).
|
384 |
+
# Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`.
|
385 |
+
attention_mask = True
|
386 |
+
elif is_tracing:
|
387 |
+
raise ValueError(
|
388 |
+
'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.'
|
389 |
+
)
|
390 |
+
|
391 |
+
if attention_mask is None:
|
392 |
+
expanded_4d_mask = None
|
393 |
+
elif attention_mask is True:
|
394 |
+
expanded_4d_mask = attn_mask_converter.to_causal_4d(
|
395 |
+
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
396 |
+
)
|
397 |
+
else:
|
398 |
+
expanded_4d_mask = attn_mask_converter.to_4d(
|
399 |
+
attention_mask,
|
400 |
+
input_shape[-1],
|
401 |
+
dtype=inputs_embeds.dtype,
|
402 |
+
key_value_length=key_value_length,
|
403 |
+
)
|
404 |
+
|
405 |
+
# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
|
406 |
+
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
|
407 |
+
#
|
408 |
+
# This fix is not applied in case we are tracing with torch.jit.trace or symbolic_trace, as _unmask_unattended has a data-dependent
|
409 |
+
# controlflow that can not be captured properly.
|
410 |
+
# TODO: _unmask_unattended does not work either with torch.compile when using fullgraph=True. We should find a way to detect this case.
|
411 |
+
if query_length > 1 and not is_tracing:
|
412 |
+
expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
|
413 |
+
expanded_4d_mask, attention_mask, unmasked_value=0.0
|
414 |
+
)
|
415 |
+
|
416 |
+
return expanded_4d_mask
|
417 |
+
|
418 |
+
|
419 |
+
def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
420 |
+
"""
|
421 |
+
Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
422 |
+
`(batch_size, key_value_length)`
|
423 |
+
|
424 |
+
Args:
|
425 |
+
mask (`torch.Tensor` or `None`):
|
426 |
+
A 2D attention mask of shape `(batch_size, key_value_length)`
|
427 |
+
dtype (`torch.dtype`):
|
428 |
+
The torch dtype the created mask shall have.
|
429 |
+
tgt_len (`int`):
|
430 |
+
The target length or query length the created mask shall have.
|
431 |
+
"""
|
432 |
+
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
|
433 |
+
|
434 |
+
|
435 |
+
def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
436 |
+
"""
|
437 |
+
Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
438 |
+
`(batch_size, key_value_length)`
|
439 |
+
|
440 |
+
Args:
|
441 |
+
mask (`torch.Tensor` or `None`):
|
442 |
+
A 2D attention mask of shape `(batch_size, key_value_length)`
|
443 |
+
dtype (`torch.dtype`):
|
444 |
+
The torch dtype the created mask shall have.
|
445 |
+
tgt_len (`int`):
|
446 |
+
The target length or query length the created mask shall have.
|
447 |
+
"""
|
448 |
+
batch_size, key_value_length = mask.shape
|
449 |
+
tgt_len = tgt_len if tgt_len is not None else key_value_length
|
450 |
+
|
451 |
+
# torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
|
452 |
+
# used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
|
453 |
+
# TODO: Fix this as well when using torchdynamo with fullgraph=True.
|
454 |
+
is_tracing = torch.jit.is_tracing()
|
455 |
+
|
456 |
+
if torch.all(mask == 1):
|
457 |
+
if is_tracing:
|
458 |
+
pass
|
459 |
+
elif tgt_len == 1:
|
460 |
+
# For query_length == 1, causal attention and bi-directional attention are the same.
|
461 |
+
return None
|
462 |
+
elif key_value_length == tgt_len:
|
463 |
+
return None
|
464 |
+
else:
|
465 |
+
# Unfortunately, for query_length > 1 and key_value_length != query_length, we can not generally ignore the attention mask, as SDPA causal mask generation
|
466 |
+
# may be wrong. We will set is_causal=False in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
|
467 |
+
# Reference: https://github.com/pytorch/pytorch/issues/108108
|
468 |
+
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
|
469 |
+
else:
|
470 |
+
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
|
471 |
+
|
472 |
+
|
473 |
+
def _create_4d_causal_attention_mask(
|
474 |
+
input_shape: Union[torch.Size, Tuple, List],
|
475 |
+
dtype: torch.dtype,
|
476 |
+
device: torch.device,
|
477 |
+
past_key_values_length: int = 0,
|
478 |
+
sliding_window: Optional[int] = None,
|
479 |
+
) -> Optional[torch.Tensor]:
|
480 |
+
"""
|
481 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
|
482 |
+
|
483 |
+
Args:
|
484 |
+
input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
|
485 |
+
The input shape should be a tuple that defines `(batch_size, query_length)`.
|
486 |
+
dtype (`torch.dtype`):
|
487 |
+
The torch dtype the created mask shall have.
|
488 |
+
device (`int`):
|
489 |
+
The torch device the created mask shall have.
|
490 |
+
sliding_window (`int`, *optional*):
|
491 |
+
If the model uses windowed attention, a sliding window should be passed.
|
492 |
+
"""
|
493 |
+
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
|
494 |
+
|
495 |
+
key_value_length = past_key_values_length + input_shape[-1]
|
496 |
+
attention_mask = attn_mask_converter.to_causal_4d(
|
497 |
+
input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
|
498 |
+
)
|
499 |
+
|
500 |
+
return attention_mask
|
modeling_phi.py
ADDED
@@ -0,0 +1,836 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
""" PyTorch Phi model."""
|
17 |
+
|
18 |
+
|
19 |
+
import math
|
20 |
+
from typing import List, Optional, Tuple, Union
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import torch.nn.functional as F
|
24 |
+
import torch.utils.checkpoint
|
25 |
+
from torch import nn
|
26 |
+
from torch.nn import CrossEntropyLoss
|
27 |
+
|
28 |
+
from transformers.activations import ACT2FN
|
29 |
+
from cache_utils import Cache, DynamicCache
|
30 |
+
from modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
31 |
+
from transformers.modeling_outputs import (
|
32 |
+
BaseModelOutputWithPast,
|
33 |
+
CausalLMOutputWithPast,
|
34 |
+
)
|
35 |
+
from transformers.modeling_utils import PreTrainedModel
|
36 |
+
from transformers.utils import (
|
37 |
+
add_start_docstrings,
|
38 |
+
add_start_docstrings_to_model_forward,
|
39 |
+
logging,
|
40 |
+
replace_return_docstrings,
|
41 |
+
)
|
42 |
+
|
43 |
+
from configuration_phi import PhiConfig
|
44 |
+
|
45 |
+
logger = logging.get_logger(__name__)
|
46 |
+
|
47 |
+
_CHECKPOINT_FOR_DOC = "microsoft/phi-1"
|
48 |
+
_CONFIG_FOR_DOC = "PhiConfig"
|
49 |
+
|
50 |
+
PHI_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
51 |
+
"microsoft/phi-1",
|
52 |
+
"microsoft/phi-1_5",
|
53 |
+
"microsoft/phi-2",
|
54 |
+
# See all Phi models at https://huggingface.co/models?filter=phi
|
55 |
+
]
|
56 |
+
|
57 |
+
|
58 |
+
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
59 |
+
def _get_unpad_data(attention_mask):
|
60 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
61 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
62 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
63 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
|
64 |
+
return (
|
65 |
+
indices,
|
66 |
+
cu_seqlens,
|
67 |
+
max_seqlen_in_batch,
|
68 |
+
)
|
69 |
+
|
70 |
+
|
71 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Phi
|
72 |
+
class PhiRotaryEmbedding(nn.Module):
|
73 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
74 |
+
super().__init__()
|
75 |
+
|
76 |
+
self.dim = dim
|
77 |
+
self.max_position_embeddings = max_position_embeddings
|
78 |
+
self.base = base
|
79 |
+
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
80 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
81 |
+
|
82 |
+
# Build here to make `torch.jit.trace` work.
|
83 |
+
self._set_cos_sin_cache(
|
84 |
+
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
85 |
+
)
|
86 |
+
|
87 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
88 |
+
self.max_seq_len_cached = seq_len
|
89 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
90 |
+
|
91 |
+
freqs = torch.outer(t, self.inv_freq)
|
92 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
93 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
94 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
95 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
96 |
+
|
97 |
+
def forward(self, x, seq_len=None):
|
98 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
99 |
+
if seq_len > self.max_seq_len_cached:
|
100 |
+
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
101 |
+
|
102 |
+
return (
|
103 |
+
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
104 |
+
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
105 |
+
)
|
106 |
+
|
107 |
+
|
108 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Phi
|
109 |
+
class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
|
110 |
+
"""PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
111 |
+
|
112 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
113 |
+
self.scaling_factor = scaling_factor
|
114 |
+
super().__init__(dim, max_position_embeddings, base, device)
|
115 |
+
|
116 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
117 |
+
self.max_seq_len_cached = seq_len
|
118 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
119 |
+
t = t / self.scaling_factor
|
120 |
+
|
121 |
+
freqs = torch.outer(t, self.inv_freq)
|
122 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
123 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
124 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
125 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
126 |
+
|
127 |
+
|
128 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Phi
|
129 |
+
class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
|
130 |
+
"""PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
131 |
+
|
132 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
133 |
+
self.scaling_factor = scaling_factor
|
134 |
+
super().__init__(dim, max_position_embeddings, base, device)
|
135 |
+
|
136 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
137 |
+
self.max_seq_len_cached = seq_len
|
138 |
+
|
139 |
+
if seq_len > self.max_position_embeddings:
|
140 |
+
base = self.base * (
|
141 |
+
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
|
142 |
+
) ** (self.dim / (self.dim - 2))
|
143 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
144 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
145 |
+
|
146 |
+
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
147 |
+
|
148 |
+
freqs = torch.outer(t, self.inv_freq)
|
149 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
150 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
151 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
152 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
153 |
+
|
154 |
+
|
155 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
156 |
+
def rotate_half(x):
|
157 |
+
"""Rotates half the hidden dims of the input."""
|
158 |
+
x1 = x[..., : x.shape[-1] // 2]
|
159 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
160 |
+
return torch.cat((-x2, x1), dim=-1)
|
161 |
+
|
162 |
+
|
163 |
+
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
164 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
165 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
166 |
+
|
167 |
+
Args:
|
168 |
+
q (`torch.Tensor`): The query tensor.
|
169 |
+
k (`torch.Tensor`): The key tensor.
|
170 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
171 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
172 |
+
position_ids (`torch.Tensor`):
|
173 |
+
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
174 |
+
used to pass offsetted position ids when working with a KV-cache.
|
175 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
176 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
177 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
178 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
179 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
180 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
181 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
182 |
+
Returns:
|
183 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
184 |
+
"""
|
185 |
+
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
186 |
+
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
187 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
188 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
189 |
+
return q_embed, k_embed
|
190 |
+
|
191 |
+
|
192 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Phi
|
193 |
+
class PhiMLP(nn.Module):
|
194 |
+
def __init__(self, config):
|
195 |
+
super().__init__()
|
196 |
+
self.config = config
|
197 |
+
self.activation_fn = ACT2FN[config.hidden_act]
|
198 |
+
#############################################################################################################################
|
199 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
200 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
201 |
+
#############################################################################################################################
|
202 |
+
|
203 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
204 |
+
hidden_states = self.fc1(hidden_states)
|
205 |
+
hidden_states = self.activation_fn(hidden_states)
|
206 |
+
hidden_states = self.fc2(hidden_states)
|
207 |
+
return hidden_states
|
208 |
+
|
209 |
+
|
210 |
+
# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
|
211 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
212 |
+
"""
|
213 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
214 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
215 |
+
"""
|
216 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
217 |
+
if n_rep == 1:
|
218 |
+
return hidden_states
|
219 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
220 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
221 |
+
|
222 |
+
|
223 |
+
class PhiAttention(nn.Module):
|
224 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
225 |
+
|
226 |
+
def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None):
|
227 |
+
super().__init__()
|
228 |
+
self.config = config
|
229 |
+
self.layer_idx = layer_idx
|
230 |
+
if layer_idx is None:
|
231 |
+
logger.warning_once(
|
232 |
+
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
|
233 |
+
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
|
234 |
+
"when creating this class."
|
235 |
+
)
|
236 |
+
|
237 |
+
self.attention_dropout = config.attention_dropout
|
238 |
+
self.hidden_size = config.hidden_size
|
239 |
+
self.num_heads = config.num_attention_heads
|
240 |
+
self.head_dim = self.hidden_size // self.num_heads
|
241 |
+
self.num_key_value_heads = config.num_key_value_heads
|
242 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
243 |
+
self.max_position_embeddings = config.max_position_embeddings
|
244 |
+
self.rope_theta = config.rope_theta
|
245 |
+
self.partial_rotary_factor = config.partial_rotary_factor
|
246 |
+
self.is_causal = True
|
247 |
+
|
248 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
249 |
+
raise ValueError(
|
250 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
251 |
+
f" and `num_heads`: {self.num_heads})."
|
252 |
+
)
|
253 |
+
#############################################################################################################################
|
254 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
|
255 |
+
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
256 |
+
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
257 |
+
self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
|
258 |
+
#############################################################################################################################
|
259 |
+
self.qk_layernorm = config.qk_layernorm
|
260 |
+
if self.qk_layernorm:
|
261 |
+
self.q_layernorm = nn.LayerNorm(
|
262 |
+
config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
|
263 |
+
)
|
264 |
+
self.k_layernorm = nn.LayerNorm(
|
265 |
+
config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
|
266 |
+
)
|
267 |
+
|
268 |
+
self._init_rope()
|
269 |
+
|
270 |
+
def _init_rope(self):
|
271 |
+
if self.config.rope_scaling is None:
|
272 |
+
self.rotary_emb = PhiRotaryEmbedding(
|
273 |
+
int(self.partial_rotary_factor * self.head_dim),
|
274 |
+
max_position_embeddings=self.max_position_embeddings,
|
275 |
+
base=self.rope_theta,
|
276 |
+
)
|
277 |
+
else:
|
278 |
+
scaling_type = self.config.rope_scaling["type"]
|
279 |
+
scaling_factor = self.config.rope_scaling["factor"]
|
280 |
+
if scaling_type == "linear":
|
281 |
+
self.rotary_emb = PhiLinearScalingRotaryEmbedding(
|
282 |
+
int(self.partial_rotary_factor * self.head_dim),
|
283 |
+
max_position_embeddings=self.max_position_embeddings,
|
284 |
+
scaling_factor=scaling_factor,
|
285 |
+
base=self.rope_theta,
|
286 |
+
)
|
287 |
+
elif scaling_type == "dynamic":
|
288 |
+
self.rotary_emb = PhiDynamicNTKScalingRotaryEmbedding(
|
289 |
+
int(self.partial_rotary_factor * self.head_dim),
|
290 |
+
max_position_embeddings=self.max_position_embeddings,
|
291 |
+
scaling_factor=scaling_factor,
|
292 |
+
base=self.rope_theta,
|
293 |
+
)
|
294 |
+
else:
|
295 |
+
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
296 |
+
|
297 |
+
def forward(
|
298 |
+
self,
|
299 |
+
hidden_states: torch.Tensor,
|
300 |
+
attention_mask: Optional[torch.Tensor] = None,
|
301 |
+
position_ids: Optional[torch.LongTensor] = None,
|
302 |
+
past_key_value: Optional[Cache] = None,
|
303 |
+
output_attentions: bool = False,
|
304 |
+
use_cache: bool = False,
|
305 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
306 |
+
bsz, q_len, _ = hidden_states.size()
|
307 |
+
|
308 |
+
query_states = self.q_proj(hidden_states)
|
309 |
+
key_states = self.k_proj(hidden_states)
|
310 |
+
value_states = self.v_proj(hidden_states)
|
311 |
+
|
312 |
+
if self.qk_layernorm:
|
313 |
+
query_states = self.q_layernorm(query_states)
|
314 |
+
key_states = self.k_layernorm(key_states)
|
315 |
+
|
316 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
317 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
318 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
319 |
+
|
320 |
+
kv_seq_len = key_states.shape[-2]
|
321 |
+
if past_key_value is not None:
|
322 |
+
if self.layer_idx is None:
|
323 |
+
raise ValueError(
|
324 |
+
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
325 |
+
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
326 |
+
"with a layer index."
|
327 |
+
)
|
328 |
+
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
329 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
330 |
+
|
331 |
+
# Partial rotary embedding
|
332 |
+
query_rot, query_pass = (
|
333 |
+
query_states[..., : self.rotary_emb.dim],
|
334 |
+
query_states[..., self.rotary_emb.dim :],
|
335 |
+
)
|
336 |
+
key_rot, key_pass = (
|
337 |
+
key_states[..., : self.rotary_emb.dim],
|
338 |
+
key_states[..., self.rotary_emb.dim :],
|
339 |
+
)
|
340 |
+
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
|
341 |
+
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
|
342 |
+
|
343 |
+
# [batch_size, seq_length, num_heads, head_dim]
|
344 |
+
query_states = torch.cat((query_rot, query_pass), dim=-1)
|
345 |
+
key_states = torch.cat((key_rot, key_pass), dim=-1)
|
346 |
+
|
347 |
+
if past_key_value is not None:
|
348 |
+
cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
|
349 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
350 |
+
|
351 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
352 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
353 |
+
#############################################################################################################################
|
354 |
+
# Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
|
355 |
+
attn_weights = torch.matmul(
|
356 |
+
query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
|
357 |
+
) / math.sqrt(self.head_dim)
|
358 |
+
#############################################################################################################################
|
359 |
+
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
360 |
+
raise ValueError(
|
361 |
+
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
362 |
+
f" {attn_weights.size()}"
|
363 |
+
)
|
364 |
+
|
365 |
+
if attention_mask is not None:
|
366 |
+
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
367 |
+
raise ValueError(
|
368 |
+
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
369 |
+
)
|
370 |
+
attn_weights = attn_weights + attention_mask
|
371 |
+
|
372 |
+
# upcast attention to fp32
|
373 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
|
374 |
+
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
375 |
+
#############################################################################################################################
|
376 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
377 |
+
#############################################################################################################################
|
378 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
379 |
+
raise ValueError(
|
380 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
381 |
+
f" {attn_output.size()}"
|
382 |
+
)
|
383 |
+
|
384 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
385 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
386 |
+
|
387 |
+
attn_output = self.dense(attn_output)
|
388 |
+
|
389 |
+
if not output_attentions:
|
390 |
+
attn_weights = None
|
391 |
+
|
392 |
+
return attn_output, attn_weights, past_key_value
|
393 |
+
|
394 |
+
|
395 |
+
class PhiDecoderLayer(nn.Module):
|
396 |
+
def __init__(self, config: PhiConfig, layer_idx: int):
|
397 |
+
super().__init__()
|
398 |
+
self.self_attn = PhiAttention(config, layer_idx=layer_idx)
|
399 |
+
self.mlp = PhiMLP(config)
|
400 |
+
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
401 |
+
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
402 |
+
|
403 |
+
def forward(
|
404 |
+
self,
|
405 |
+
hidden_states: torch.Tensor,
|
406 |
+
attention_mask: Optional[torch.Tensor] = None,
|
407 |
+
position_ids: Optional[torch.LongTensor] = None,
|
408 |
+
output_attentions: Optional[bool] = False,
|
409 |
+
use_cache: Optional[bool] = False,
|
410 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
411 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
412 |
+
"""
|
413 |
+
Args:
|
414 |
+
hidden_states (`torch.FloatTensor`):
|
415 |
+
input to the layer of shape `(batch, seq_len, embed_dim)`
|
416 |
+
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
417 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
418 |
+
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
419 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
|
420 |
+
`[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
|
421 |
+
output_attentions (`bool`, *optional*):
|
422 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
423 |
+
returned tensors for more detail.
|
424 |
+
use_cache (`bool`, *optional*):
|
425 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
426 |
+
(see `past_key_values`).
|
427 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
428 |
+
"""
|
429 |
+
|
430 |
+
residual = hidden_states
|
431 |
+
|
432 |
+
hidden_states = self.input_layernorm(hidden_states)
|
433 |
+
|
434 |
+
# Self Attention
|
435 |
+
attn_outputs, self_attn_weights, present_key_value = self.self_attn(
|
436 |
+
hidden_states=hidden_states,
|
437 |
+
attention_mask=attention_mask,
|
438 |
+
position_ids=position_ids,
|
439 |
+
past_key_value=past_key_value,
|
440 |
+
output_attentions=output_attentions,
|
441 |
+
use_cache=use_cache,
|
442 |
+
)
|
443 |
+
attn_outputs = self.resid_dropout(attn_outputs)
|
444 |
+
|
445 |
+
feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
|
446 |
+
#############################################################################################################################
|
447 |
+
hidden_states = attn_outputs + feed_forward_hidden_states + residual
|
448 |
+
#############################################################################################################################
|
449 |
+
outputs = (hidden_states,)
|
450 |
+
|
451 |
+
if output_attentions:
|
452 |
+
outputs += (self_attn_weights,)
|
453 |
+
|
454 |
+
if use_cache:
|
455 |
+
outputs += (present_key_value,)
|
456 |
+
|
457 |
+
return outputs
|
458 |
+
|
459 |
+
|
460 |
+
PHI_START_DOCSTRING = r"""
|
461 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
462 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
463 |
+
etc.)
|
464 |
+
|
465 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
466 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
467 |
+
and behavior.
|
468 |
+
|
469 |
+
Parameters:
|
470 |
+
config ([`PhiConfig`]):
|
471 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
472 |
+
load the weights associated with the model, only the configuration. Check out the
|
473 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
474 |
+
"""
|
475 |
+
|
476 |
+
|
477 |
+
@add_start_docstrings(
|
478 |
+
"The bare Phi Model outputting raw hidden-states without any specific head on top.",
|
479 |
+
PHI_START_DOCSTRING,
|
480 |
+
)
|
481 |
+
class PhiPreTrainedModel(PreTrainedModel):
|
482 |
+
config_class = PhiConfig
|
483 |
+
base_model_prefix = "model"
|
484 |
+
supports_gradient_checkpointing = True
|
485 |
+
_no_split_modules = ["PhiDecoderLayer"]
|
486 |
+
_skip_keys_device_placement = "past_key_values"
|
487 |
+
_supports_flash_attn_2 = True
|
488 |
+
_supports_cache_class = True
|
489 |
+
|
490 |
+
def _init_weights(self, module):
|
491 |
+
std = self.config.initializer_range
|
492 |
+
if isinstance(module, nn.Linear):
|
493 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
494 |
+
if module.bias is not None:
|
495 |
+
module.bias.data.zero_()
|
496 |
+
elif isinstance(module, nn.Embedding):
|
497 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
498 |
+
if module.padding_idx is not None:
|
499 |
+
module.weight.data[module.padding_idx].zero_()
|
500 |
+
|
501 |
+
|
502 |
+
PHI_INPUTS_DOCSTRING = r"""
|
503 |
+
Args:
|
504 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
505 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
506 |
+
it.
|
507 |
+
|
508 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
509 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
510 |
+
|
511 |
+
[What are input IDs?](../glossary#input-ids)
|
512 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
513 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
514 |
+
|
515 |
+
- 1 for tokens that are **not masked**,
|
516 |
+
- 0 for tokens that are **masked**.
|
517 |
+
|
518 |
+
[What are attention masks?](../glossary#attention-mask)
|
519 |
+
|
520 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
521 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
522 |
+
|
523 |
+
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
|
524 |
+
`past_key_values`).
|
525 |
+
|
526 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
527 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
528 |
+
information on the default strategy.
|
529 |
+
|
530 |
+
- 1 indicates the head is **not masked**,
|
531 |
+
- 0 indicates the head is **masked**.
|
532 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
533 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
534 |
+
config.n_positions - 1]`.
|
535 |
+
|
536 |
+
[What are position IDs?](../glossary#position-ids)
|
537 |
+
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
538 |
+
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
539 |
+
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
540 |
+
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
541 |
+
|
542 |
+
Two formats are allowed:
|
543 |
+
- a [`~cache_utils.Cache`] instance;
|
544 |
+
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
545 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
546 |
+
cache format.
|
547 |
+
|
548 |
+
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
549 |
+
legacy cache format will be returned.
|
550 |
+
|
551 |
+
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
552 |
+
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
553 |
+
of shape `(batch_size, sequence_length)`.
|
554 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
555 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
556 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
557 |
+
model's internal embedding lookup matrix.
|
558 |
+
use_cache (`bool`, *optional*):
|
559 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
560 |
+
`past_key_values`).
|
561 |
+
output_attentions (`bool`, *optional*):
|
562 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
563 |
+
tensors for more detail.
|
564 |
+
output_hidden_states (`bool`, *optional*):
|
565 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
566 |
+
more detail.
|
567 |
+
return_dict (`bool`, *optional*):
|
568 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
569 |
+
"""
|
570 |
+
|
571 |
+
|
572 |
+
@add_start_docstrings(
|
573 |
+
"The bare Phi Model outputting raw hidden-states without any specific head on top.",
|
574 |
+
PHI_START_DOCSTRING,
|
575 |
+
)
|
576 |
+
class PhiModel(PhiPreTrainedModel):
|
577 |
+
"""
|
578 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PhiDecoderLayer`]
|
579 |
+
|
580 |
+
Args:
|
581 |
+
config: PhiConfig
|
582 |
+
"""
|
583 |
+
|
584 |
+
def __init__(self, config: PhiConfig):
|
585 |
+
super().__init__(config)
|
586 |
+
self.padding_idx = config.pad_token_id
|
587 |
+
self.vocab_size = config.vocab_size
|
588 |
+
|
589 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
590 |
+
self.embed_dropout = nn.Dropout(config.embd_pdrop)
|
591 |
+
self.layers = nn.ModuleList(
|
592 |
+
[PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
593 |
+
)
|
594 |
+
self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
595 |
+
# self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
596 |
+
|
597 |
+
self.gradient_checkpointing = False
|
598 |
+
# Initialize weights and apply final processing
|
599 |
+
self.post_init()
|
600 |
+
|
601 |
+
def get_input_embeddings(self):
|
602 |
+
return self.embed_tokens
|
603 |
+
|
604 |
+
def set_input_embeddings(self, value):
|
605 |
+
self.embed_tokens = value
|
606 |
+
|
607 |
+
@add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
|
608 |
+
def forward(
|
609 |
+
self,
|
610 |
+
input_ids: torch.LongTensor = None,
|
611 |
+
attention_mask: Optional[torch.Tensor] = None,
|
612 |
+
position_ids: Optional[torch.LongTensor] = None,
|
613 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
614 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
615 |
+
use_cache: Optional[bool] = None,
|
616 |
+
output_attentions: Optional[bool] = None,
|
617 |
+
output_hidden_states: Optional[bool] = None,
|
618 |
+
return_dict: Optional[bool] = None,
|
619 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
620 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
621 |
+
output_hidden_states = (
|
622 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
623 |
+
)
|
624 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
625 |
+
|
626 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
627 |
+
|
628 |
+
# retrieve input_ids and inputs_embeds
|
629 |
+
if input_ids is not None and inputs_embeds is not None:
|
630 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
631 |
+
elif input_ids is not None:
|
632 |
+
batch_size, seq_length = input_ids.shape[:2]
|
633 |
+
elif inputs_embeds is not None:
|
634 |
+
batch_size, seq_length = inputs_embeds.shape[:2]
|
635 |
+
else:
|
636 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
637 |
+
|
638 |
+
past_key_values_length = 0
|
639 |
+
|
640 |
+
if self.gradient_checkpointing and self.training:
|
641 |
+
if use_cache:
|
642 |
+
logger.warning_once(
|
643 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
644 |
+
)
|
645 |
+
use_cache = False
|
646 |
+
|
647 |
+
if use_cache:
|
648 |
+
use_legacy_cache = not isinstance(past_key_values, Cache)
|
649 |
+
if use_legacy_cache:
|
650 |
+
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
651 |
+
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
652 |
+
|
653 |
+
if position_ids is None:
|
654 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
655 |
+
position_ids = torch.arange(
|
656 |
+
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
657 |
+
)
|
658 |
+
position_ids = position_ids.unsqueeze(0)
|
659 |
+
|
660 |
+
if inputs_embeds is None:
|
661 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
662 |
+
|
663 |
+
inputs_embeds = self.embed_dropout(inputs_embeds)
|
664 |
+
|
665 |
+
# 4d mask is passed through the layers
|
666 |
+
attention_mask = _prepare_4d_causal_attention_mask(
|
667 |
+
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
668 |
+
)
|
669 |
+
|
670 |
+
hidden_states = inputs_embeds
|
671 |
+
|
672 |
+
# decoder layers
|
673 |
+
all_hidden_states = () if output_hidden_states else None
|
674 |
+
all_self_attns = () if output_attentions else None
|
675 |
+
next_decoder_cache = None
|
676 |
+
|
677 |
+
for decoder_layer in self.layers:
|
678 |
+
if output_hidden_states:
|
679 |
+
all_hidden_states += (hidden_states,)
|
680 |
+
|
681 |
+
if self.gradient_checkpointing and self.training:
|
682 |
+
layer_outputs = self._gradient_checkpointing_func(
|
683 |
+
decoder_layer.__call__,
|
684 |
+
hidden_states,
|
685 |
+
attention_mask,
|
686 |
+
position_ids,
|
687 |
+
past_key_values,
|
688 |
+
output_attentions,
|
689 |
+
)
|
690 |
+
else:
|
691 |
+
layer_outputs = decoder_layer(
|
692 |
+
hidden_states,
|
693 |
+
attention_mask=attention_mask,
|
694 |
+
position_ids=position_ids,
|
695 |
+
past_key_value=past_key_values,
|
696 |
+
output_attentions=output_attentions,
|
697 |
+
use_cache=use_cache,
|
698 |
+
)
|
699 |
+
|
700 |
+
hidden_states = layer_outputs[0]
|
701 |
+
|
702 |
+
if use_cache:
|
703 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
704 |
+
|
705 |
+
if output_attentions:
|
706 |
+
all_self_attns += (layer_outputs[1],)
|
707 |
+
|
708 |
+
hidden_states = self.final_layernorm(hidden_states)
|
709 |
+
|
710 |
+
# add hidden states from the last decoder layer
|
711 |
+
if output_hidden_states:
|
712 |
+
all_hidden_states += (hidden_states,)
|
713 |
+
|
714 |
+
next_cache = None
|
715 |
+
if use_cache:
|
716 |
+
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
|
717 |
+
if not return_dict:
|
718 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
719 |
+
return BaseModelOutputWithPast(
|
720 |
+
last_hidden_state=hidden_states,
|
721 |
+
past_key_values=next_cache,
|
722 |
+
hidden_states=all_hidden_states,
|
723 |
+
attentions=all_self_attns,
|
724 |
+
)
|
725 |
+
|
726 |
+
|
727 |
+
class PhiForCausalLM(PhiPreTrainedModel):
|
728 |
+
_tied_weights_keys = ["lm_head.weight"]
|
729 |
+
|
730 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi,bias=False->bias=True
|
731 |
+
def __init__(self, config):
|
732 |
+
super().__init__(config)
|
733 |
+
self.model = PhiModel(config)
|
734 |
+
self.vocab_size = config.vocab_size
|
735 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
736 |
+
|
737 |
+
# Initialize weights and apply final processing
|
738 |
+
self.post_init()
|
739 |
+
|
740 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
|
741 |
+
def get_input_embeddings(self):
|
742 |
+
return self.model.embed_tokens
|
743 |
+
|
744 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
|
745 |
+
def set_input_embeddings(self, value):
|
746 |
+
self.model.embed_tokens = value
|
747 |
+
|
748 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
|
749 |
+
def get_output_embeddings(self):
|
750 |
+
return self.lm_head
|
751 |
+
|
752 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
|
753 |
+
def set_output_embeddings(self, new_embeddings):
|
754 |
+
self.lm_head = new_embeddings
|
755 |
+
|
756 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
|
757 |
+
def set_decoder(self, decoder):
|
758 |
+
self.model = decoder
|
759 |
+
|
760 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
|
761 |
+
def get_decoder(self):
|
762 |
+
return self.model
|
763 |
+
|
764 |
+
@add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
|
765 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
766 |
+
def forward(
|
767 |
+
self,
|
768 |
+
input_ids: torch.LongTensor = None,
|
769 |
+
attention_mask: Optional[torch.Tensor] = None,
|
770 |
+
position_ids: Optional[torch.LongTensor] = None,
|
771 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
772 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
773 |
+
labels: Optional[torch.LongTensor] = None,
|
774 |
+
use_cache: Optional[bool] = None,
|
775 |
+
output_attentions: Optional[bool] = None,
|
776 |
+
output_hidden_states: Optional[bool] = None,
|
777 |
+
return_dict: Optional[bool] = None,
|
778 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
779 |
+
"""
|
780 |
+
Args:
|
781 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
782 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
783 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
784 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
785 |
+
|
786 |
+
Returns:
|
787 |
+
"""
|
788 |
+
|
789 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
790 |
+
output_hidden_states = (
|
791 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
792 |
+
)
|
793 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
794 |
+
|
795 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
796 |
+
outputs = self.model(
|
797 |
+
input_ids=input_ids,
|
798 |
+
attention_mask=attention_mask,
|
799 |
+
position_ids=position_ids,
|
800 |
+
past_key_values=past_key_values,
|
801 |
+
inputs_embeds=inputs_embeds,
|
802 |
+
use_cache=use_cache,
|
803 |
+
output_attentions=output_attentions,
|
804 |
+
output_hidden_states=output_hidden_states,
|
805 |
+
return_dict=return_dict,
|
806 |
+
)
|
807 |
+
|
808 |
+
hidden_states = outputs[0]
|
809 |
+
logits = self.lm_head(hidden_states)
|
810 |
+
logits = logits.float()
|
811 |
+
|
812 |
+
loss = None
|
813 |
+
if labels is not None:
|
814 |
+
# Shift so that tokens < n predict n
|
815 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
816 |
+
shift_labels = labels[..., 1:].contiguous()
|
817 |
+
# Flatten the tokens
|
818 |
+
loss_fct = CrossEntropyLoss()
|
819 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
820 |
+
shift_labels = shift_labels.view(-1)
|
821 |
+
# Enable model parallelism
|
822 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
823 |
+
loss = loss_fct(shift_logits, shift_labels)
|
824 |
+
|
825 |
+
if not return_dict:
|
826 |
+
output = (logits,) + outputs[1:]
|
827 |
+
return (loss,) + output if loss is not None else output
|
828 |
+
|
829 |
+
return CausalLMOutputWithPast(
|
830 |
+
loss=loss,
|
831 |
+
logits=logits,
|
832 |
+
past_key_values=outputs.past_key_values,
|
833 |
+
hidden_states=outputs.hidden_states,
|
834 |
+
attentions=outputs.attentions,
|
835 |
+
)
|
836 |
+
|
tokenization_codegen.py
ADDED
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The Salesforce authors, The Open AI Team Authors and The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Tokenization classes for CodeGen"""
|
16 |
+
|
17 |
+
|
18 |
+
import json
|
19 |
+
import os
|
20 |
+
from functools import lru_cache
|
21 |
+
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
22 |
+
|
23 |
+
import numpy as np
|
24 |
+
import regex as re
|
25 |
+
|
26 |
+
from transformers.utils import is_tf_available, is_torch_available, logging
|
27 |
+
|
28 |
+
|
29 |
+
if TYPE_CHECKING:
|
30 |
+
if is_torch_available():
|
31 |
+
import torch
|
32 |
+
if is_tf_available():
|
33 |
+
import tensorflow as tf
|
34 |
+
|
35 |
+
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
|
36 |
+
|
37 |
+
|
38 |
+
logger = logging.get_logger(__name__)
|
39 |
+
|
40 |
+
VOCAB_FILES_NAMES = {
|
41 |
+
"vocab_file": "vocab.json",
|
42 |
+
"merges_file": "merges.txt",
|
43 |
+
}
|
44 |
+
|
45 |
+
PRETRAINED_VOCAB_FILES_MAP = {
|
46 |
+
"vocab_file": {
|
47 |
+
"Salesforce/codegen-350M-mono": "https://huggingface.co/Salesforce/codegen-350M-mono/resolve/main/vocab.json",
|
48 |
+
},
|
49 |
+
"merges_file": {
|
50 |
+
"Salesforce/codegen-350M-mono": "https://huggingface.co/Salesforce/codegen-350M-mono/resolve/main/merges.txt",
|
51 |
+
},
|
52 |
+
}
|
53 |
+
|
54 |
+
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
55 |
+
"Salesforce/codegen-350M-mono": 2048,
|
56 |
+
}
|
57 |
+
|
58 |
+
|
59 |
+
@lru_cache()
|
60 |
+
def bytes_to_unicode():
|
61 |
+
"""
|
62 |
+
Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
|
63 |
+
characters the bpe code barfs on.
|
64 |
+
|
65 |
+
The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
|
66 |
+
if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
|
67 |
+
decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
|
68 |
+
tables between utf-8 bytes and unicode strings.
|
69 |
+
"""
|
70 |
+
bs = (
|
71 |
+
list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
|
72 |
+
)
|
73 |
+
cs = bs[:]
|
74 |
+
n = 0
|
75 |
+
for b in range(2**8):
|
76 |
+
if b not in bs:
|
77 |
+
bs.append(b)
|
78 |
+
cs.append(2**8 + n)
|
79 |
+
n += 1
|
80 |
+
cs = [chr(n) for n in cs]
|
81 |
+
return dict(zip(bs, cs))
|
82 |
+
|
83 |
+
|
84 |
+
def get_pairs(word):
|
85 |
+
"""
|
86 |
+
Return set of symbol pairs in a word.
|
87 |
+
|
88 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
89 |
+
"""
|
90 |
+
pairs = set()
|
91 |
+
prev_char = word[0]
|
92 |
+
for char in word[1:]:
|
93 |
+
pairs.add((prev_char, char))
|
94 |
+
prev_char = char
|
95 |
+
return pairs
|
96 |
+
|
97 |
+
|
98 |
+
class CodeGenTokenizer(PreTrainedTokenizer):
|
99 |
+
"""
|
100 |
+
Construct a CodeGen tokenizer. Based on byte-level Byte-Pair-Encoding.
|
101 |
+
|
102 |
+
This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
|
103 |
+
be encoded differently whether it is at the beginning of the sentence (without space) or not:
|
104 |
+
|
105 |
+
```python
|
106 |
+
>>> from transformers import CodeGenTokenizer
|
107 |
+
|
108 |
+
>>> tokenizer = CodeGenTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
|
109 |
+
>>> tokenizer("Hello world")["input_ids"]
|
110 |
+
[15496, 995]
|
111 |
+
|
112 |
+
>>> tokenizer(" Hello world")["input_ids"]
|
113 |
+
[18435, 995]
|
114 |
+
```
|
115 |
+
|
116 |
+
You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
|
117 |
+
call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
|
118 |
+
|
119 |
+
<Tip>
|
120 |
+
|
121 |
+
When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one).
|
122 |
+
|
123 |
+
</Tip>
|
124 |
+
|
125 |
+
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
|
126 |
+
this superclass for more information regarding those methods.
|
127 |
+
|
128 |
+
Args:
|
129 |
+
vocab_file (`str`):
|
130 |
+
Path to the vocabulary file.
|
131 |
+
merges_file (`str`):
|
132 |
+
Path to the merges file.
|
133 |
+
errors (`str`, *optional*, defaults to `"replace"`):
|
134 |
+
Paradigm to follow when decoding bytes to UTF-8. See
|
135 |
+
[bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
|
136 |
+
unk_token (`str`, *optional*, defaults to `<|endoftext|>`):
|
137 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
138 |
+
token instead.
|
139 |
+
bos_token (`str`, *optional*, defaults to `<|endoftext|>`):
|
140 |
+
The beginning of sequence token.
|
141 |
+
eos_token (`str`, *optional*, defaults to `<|endoftext|>`):
|
142 |
+
The end of sequence token.
|
143 |
+
add_prefix_space (`bool`, *optional*, defaults to `False`):
|
144 |
+
Whether or not to add an initial space to the input. This allows to treat the leading word just as any
|
145 |
+
other word. (CodeGen tokenizer detect beginning of words by the preceding space).
|
146 |
+
"""
|
147 |
+
|
148 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
149 |
+
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
150 |
+
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
151 |
+
model_input_names = ["input_ids", "attention_mask"]
|
152 |
+
|
153 |
+
def __init__(
|
154 |
+
self,
|
155 |
+
vocab_file,
|
156 |
+
merges_file,
|
157 |
+
errors="replace",
|
158 |
+
unk_token="<|endoftext|>",
|
159 |
+
bos_token="<|endoftext|>",
|
160 |
+
eos_token="<|endoftext|>",
|
161 |
+
pad_token=None,
|
162 |
+
add_prefix_space=False,
|
163 |
+
add_bos_token=False,
|
164 |
+
**kwargs,
|
165 |
+
):
|
166 |
+
bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
|
167 |
+
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
|
168 |
+
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
|
169 |
+
pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
|
170 |
+
super().__init__(
|
171 |
+
errors=errors,
|
172 |
+
unk_token=unk_token,
|
173 |
+
bos_token=bos_token,
|
174 |
+
eos_token=eos_token,
|
175 |
+
pad_token=pad_token,
|
176 |
+
add_prefix_space=add_prefix_space,
|
177 |
+
add_bos_token=add_bos_token,
|
178 |
+
**kwargs,
|
179 |
+
)
|
180 |
+
self.add_bos_token = add_bos_token
|
181 |
+
|
182 |
+
with open(vocab_file, encoding="utf-8") as vocab_handle:
|
183 |
+
self.encoder = json.load(vocab_handle)
|
184 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
185 |
+
self.errors = errors # how to handle errors in decoding
|
186 |
+
self.byte_encoder = bytes_to_unicode()
|
187 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
188 |
+
with open(merges_file, encoding="utf-8") as merges_handle:
|
189 |
+
bpe_merges = merges_handle.read().split("\n")[1:-1]
|
190 |
+
bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
|
191 |
+
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
|
192 |
+
self.cache = {}
|
193 |
+
self.add_prefix_space = add_prefix_space
|
194 |
+
|
195 |
+
# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
|
196 |
+
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
|
197 |
+
|
198 |
+
@property
|
199 |
+
def vocab_size(self):
|
200 |
+
return len(self.encoder)
|
201 |
+
|
202 |
+
def get_vocab(self):
|
203 |
+
return dict(self.encoder, **self.added_tokens_encoder)
|
204 |
+
|
205 |
+
def bpe(self, token):
|
206 |
+
if token in self.cache:
|
207 |
+
return self.cache[token]
|
208 |
+
word = tuple(token)
|
209 |
+
pairs = get_pairs(word)
|
210 |
+
|
211 |
+
if not pairs:
|
212 |
+
return token
|
213 |
+
|
214 |
+
while True:
|
215 |
+
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
|
216 |
+
if bigram not in self.bpe_ranks:
|
217 |
+
break
|
218 |
+
first, second = bigram
|
219 |
+
new_word = []
|
220 |
+
i = 0
|
221 |
+
while i < len(word):
|
222 |
+
try:
|
223 |
+
j = word.index(first, i)
|
224 |
+
except ValueError:
|
225 |
+
new_word.extend(word[i:])
|
226 |
+
break
|
227 |
+
else:
|
228 |
+
new_word.extend(word[i:j])
|
229 |
+
i = j
|
230 |
+
|
231 |
+
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
|
232 |
+
new_word.append(first + second)
|
233 |
+
i += 2
|
234 |
+
else:
|
235 |
+
new_word.append(word[i])
|
236 |
+
i += 1
|
237 |
+
new_word = tuple(new_word)
|
238 |
+
word = new_word
|
239 |
+
if len(word) == 1:
|
240 |
+
break
|
241 |
+
else:
|
242 |
+
pairs = get_pairs(word)
|
243 |
+
word = " ".join(word)
|
244 |
+
self.cache[token] = word
|
245 |
+
return word
|
246 |
+
|
247 |
+
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
248 |
+
if self.add_bos_token:
|
249 |
+
bos_token_ids = [self.bos_token_id]
|
250 |
+
else:
|
251 |
+
bos_token_ids = []
|
252 |
+
|
253 |
+
output = bos_token_ids + token_ids_0
|
254 |
+
|
255 |
+
if token_ids_1 is None:
|
256 |
+
return output
|
257 |
+
|
258 |
+
return output + bos_token_ids + token_ids_1
|
259 |
+
|
260 |
+
def _tokenize(self, text):
|
261 |
+
"""Tokenize a string."""
|
262 |
+
bpe_tokens = []
|
263 |
+
for token in re.findall(self.pat, text):
|
264 |
+
token = "".join(
|
265 |
+
self.byte_encoder[b] for b in token.encode("utf-8")
|
266 |
+
) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
|
267 |
+
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
|
268 |
+
return bpe_tokens
|
269 |
+
|
270 |
+
def _convert_token_to_id(self, token):
|
271 |
+
"""Converts a token (str) in an id using the vocab."""
|
272 |
+
return self.encoder.get(token, self.encoder.get(self.unk_token))
|
273 |
+
|
274 |
+
def _convert_id_to_token(self, index):
|
275 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
276 |
+
return self.decoder.get(index)
|
277 |
+
|
278 |
+
def convert_tokens_to_string(self, tokens):
|
279 |
+
"""Converts a sequence of tokens (string) in a single string."""
|
280 |
+
text = "".join(tokens)
|
281 |
+
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
|
282 |
+
return text
|
283 |
+
|
284 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
285 |
+
if not os.path.isdir(save_directory):
|
286 |
+
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
287 |
+
return
|
288 |
+
vocab_file = os.path.join(
|
289 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
290 |
+
)
|
291 |
+
merge_file = os.path.join(
|
292 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
|
293 |
+
)
|
294 |
+
|
295 |
+
with open(vocab_file, "w", encoding="utf-8") as f:
|
296 |
+
f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
|
297 |
+
|
298 |
+
index = 0
|
299 |
+
with open(merge_file, "w", encoding="utf-8") as writer:
|
300 |
+
writer.write("#version: 0.2\n")
|
301 |
+
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
|
302 |
+
if index != token_index:
|
303 |
+
logger.warning(
|
304 |
+
f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
|
305 |
+
" Please check that the tokenizer is not corrupted!"
|
306 |
+
)
|
307 |
+
index = token_index
|
308 |
+
writer.write(" ".join(bpe_tokens) + "\n")
|
309 |
+
index += 1
|
310 |
+
|
311 |
+
return vocab_file, merge_file
|
312 |
+
|
313 |
+
def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
|
314 |
+
add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
|
315 |
+
if is_split_into_words or add_prefix_space:
|
316 |
+
text = " " + text
|
317 |
+
return (text, kwargs)
|
318 |
+
|
319 |
+
def decode(
|
320 |
+
self,
|
321 |
+
token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
|
322 |
+
skip_special_tokens: bool = False,
|
323 |
+
clean_up_tokenization_spaces: bool = None,
|
324 |
+
truncate_before_pattern: Optional[List[str]] = None,
|
325 |
+
**kwargs,
|
326 |
+
) -> str:
|
327 |
+
"""
|
328 |
+
Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
|
329 |
+
tokens and clean up tokenization spaces.
|
330 |
+
|
331 |
+
Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.
|
332 |
+
|
333 |
+
Args:
|
334 |
+
token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
|
335 |
+
List of tokenized input ids. Can be obtained using the `__call__` method.
|
336 |
+
skip_special_tokens (`bool`, *optional*, defaults to `False`):
|
337 |
+
Whether or not to remove special tokens in the decoding.
|
338 |
+
clean_up_tokenization_spaces (`bool`, *optional*):
|
339 |
+
Whether or not to clean up the tokenization spaces. If `None`, will default to
|
340 |
+
`self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
|
341 |
+
truncate_before_pattern (`List[str]`, *optional*, defaults to `None`):
|
342 |
+
A list of regular expression strings that will be used to truncate the returned string. This can be
|
343 |
+
used to remove extra pieces of code (e.g. truncate if observing a comment symbol "#" at the beginning
|
344 |
+
of a new line). An example pattern could be `["^#", re.escape("<|endoftext|>"), "^'''", "\n\n\n"]`.
|
345 |
+
kwargs (additional keyword arguments, *optional*):
|
346 |
+
Will be passed to the underlying model specific decode method.
|
347 |
+
|
348 |
+
Returns:
|
349 |
+
`str`: The decoded sentence.
|
350 |
+
"""
|
351 |
+
decoded_text = super()._decode(
|
352 |
+
token_ids=token_ids,
|
353 |
+
skip_special_tokens=skip_special_tokens,
|
354 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
355 |
+
**kwargs,
|
356 |
+
)
|
357 |
+
|
358 |
+
if truncate_before_pattern is not None and len(truncate_before_pattern) > 0:
|
359 |
+
decoded_text = self.truncate(decoded_text, truncate_before_pattern)
|
360 |
+
|
361 |
+
return decoded_text
|
362 |
+
|
363 |
+
def truncate(self, completion, truncate_before_pattern):
|
364 |
+
def find_re(string, pattern, start_pos):
|
365 |
+
m = pattern.search(string, start_pos)
|
366 |
+
return m.start() if m else -1
|
367 |
+
|
368 |
+
terminals = [re.compile(pattern, re.MULTILINE) for pattern in truncate_before_pattern]
|
369 |
+
|
370 |
+
prints = list(re.finditer("^print", completion, re.MULTILINE))
|
371 |
+
|
372 |
+
if len(prints) > 1:
|
373 |
+
completion = completion[: prints[1].start()]
|
374 |
+
|
375 |
+
defs = list(re.finditer("^def", completion, re.MULTILINE))
|
376 |
+
|
377 |
+
if len(defs) > 1:
|
378 |
+
completion = completion[: defs[1].start()]
|
379 |
+
|
380 |
+
start_pos = 0
|
381 |
+
|
382 |
+
terminals_pos = [
|
383 |
+
pos for pos in [find_re(completion, terminal, start_pos) for terminal in terminals] if pos != -1
|
384 |
+
]
|
385 |
+
|
386 |
+
if len(terminals_pos) > 0:
|
387 |
+
return completion[: min(terminals_pos)]
|
388 |
+
else:
|
389 |
+
return completion
|