# Source: amigov1 / medalpaca / handler.py
# Repository page residue: "asach's picture" (avatar alt text)
# Last commit message: "Upload folder using huggingface_hub"
# Commit: d727a17
import json
import logging
from typing import Dict, Optional
# Configure the root logger once at import time so warnings emitted by this
# module are visible even when the host application has not set up logging.
logging.basicConfig(level=logging.WARNING)

# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)
def load_json(fn: str):
    """Read a JSON file and return the decoded Python object.

    Args:
        fn (str): Path to the JSON file.

    Returns:
        The object produced by ``json.load`` (a dict for the prompt templates
        used by ``DataHandler``).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON is UTF-8 by specification; do not rely on the platform's default
    # encoding, which may differ (e.g. cp1252 on Windows) and corrupt
    # non-ASCII template text.
    with open(fn, "r", encoding="utf-8") as fp:
        return json.load(fp)
class DataHandler:
    """Helper class to handle prompt generation and data tokenization.

    Args:
        tokenizer: The tokenizer to use for tokenization. It is called with the
            prompt plus the keyword arguments ``truncation``, ``max_length``,
            ``padding``, ``return_tensors`` and ``add_special_tokens``, and it
            must expose an ``eos_token_id`` attribute.
        prompt_template (str, optional):
            The path to the JSON file containing the prompt template.
            Defaults to "prompt_templates/medalpaca.json".
        model_max_length (int, optional):
            The maximum length of the tokenized sequence.
            Should not exceed 2048, as LLaMA is trained with this. Defaults to 256.
        train_on_inputs (bool, optional):
            If False, masks out inputs in loss. Defaults to True.

    Methods:
        tokenize(prompt: str, add_eos_token: bool = True) -> Dict:
            Tokenizes the given prompt and optionally adds an end-of-sequence (EOS) token.
        generate_and_tokenize_prompt(data_point: Dict) -> Dict:
            Generates a prompt based on the given data point and tokenizes it.
        generate_prompt(instruction, input, output) -> str:
            Builds the prompt string from the loaded template.
    """

    def __init__(
        self,
        tokenizer,
        prompt_template: str = "prompt_templates/medalpaca.json",
        model_max_length: int = 256,
        train_on_inputs: bool = True,
    ) -> None:
        """Load the prompt template and store the tokenization settings.

        A ``model_max_length`` above 2048 is accepted but logged as a warning.
        """
        if model_max_length > 2048:
            # `Logger.warn` is a deprecated alias; `warning` is the supported name.
            logger.warning(
                "%s exceeds the max token length LLaMA was trained with.",
                model_max_length,
            )
        # The template is a dict read with the keys "primer", "instruction",
        # "input" and "output" (see `generate_prompt`).
        self.prompt_template = load_json(prompt_template)
        self.model_max_length = model_max_length
        self.train_on_inputs = train_on_inputs
        self.tokenizer = tokenizer

    def tokenize(
        self,
        prompt: str,
        add_eos_token: bool = True,
        return_tensors: Optional[str] = None,
        truncation: bool = True,
    ) -> Dict[str, list]:
        """
        Tokenize the given prompt and optionally add an end-of-sequence (EOS) token.

        The prompt is tokenized without special tokens and without padding.
        If `add_eos_token` is True, the tokenized sequence does not already end
        with an EOS token, and it is shorter than `model_max_length`, an EOS
        token is appended (together with a matching attention-mask entry).
        A prompt that tokenizes to an empty sequence also receives the EOS token.

        Args:
            prompt (str): The text to be tokenized.
            add_eos_token (bool, optional): Whether to add an EOS token at the end of
                the tokenized sequence. Defaults to True.
            return_tensors (str, optional): If tensors should be returned (and what type).
                Passed through to the tokenizer. Defaults to None.
            truncation (bool, optional): Whether to truncate the input to
                `model_max_length`. Defaults to True.

        Returns:
            Dict: A dictionary containing the tokenized data:
                - input_ids: The tokenized input IDs of the prompt.
                - attention_mask: The attention mask for the tokenized input IDs.
                - labels: The labels for the tokenized input IDs (a copy of input_ids).

        NOTE(review): the EOS handling and the label copy use list operations
        (`append`, `copy`); this presumably only works when `return_tensors`
        is None -- confirm before relying on tensor output here.
        """
        # TODO: optimize (roll back changes from debugging)
        result: Dict = self.tokenizer(
            prompt,
            truncation=truncation,
            max_length=self.model_max_length,
            padding=False,
            return_tensors=return_tensors,
            add_special_tokens=False,
        )

        if add_eos_token:
            input_ids = result["input_ids"]
            eos_token_id = self.tokenizer.eos_token_id
            # Guard the [-1] lookup: an empty tokenization has no last token
            # and must not raise IndexError.
            ends_with_eos = len(input_ids) > 0 and input_ids[-1] == eos_token_id
            has_room = len(input_ids) < self.model_max_length
            if has_room and not ends_with_eos:
                input_ids.append(eos_token_id)
                result["attention_mask"].append(1)

        # Copy so that later label masking does not modify input_ids.
        result["labels"] = result["input_ids"].copy()
        return result

    def generate_and_tokenize_prompt(self, data_point: Dict):
        """
        Generate a prompt based on the given data point and tokenize it.

        The full prompt (instruction, input and output) is generated and
        tokenized. If the `train_on_inputs` attribute is False, a second prompt
        without the output is generated and tokenized (without EOS) to measure
        how many leading tokens belong to the user part; the labels at those
        positions are then replaced by -100.

        Args:
            data_point (Dict): A dictionary containing the following keys
                (missing keys default to an empty string):
                - instruction: The instruction text for the prompt.
                - input: The input text for the prompt.
                - output: The output text for the prompt.

        Returns:
            Dict: A dictionary containing the tokenized prompt and associated data:
                - input_ids: The tokenized input IDs of the generated prompt.
                - attention_mask: The attention mask for the tokenized input IDs.
                - labels: The labels to be used during model training, with the output
                    part unmasked and the rest masked with -100 if `train_on_inputs`
                    is False.

        Raises:
            ValueError: Propagated from `generate_prompt` when instruction,
                input and output are all empty.
        """
        instruction = data_point.get("instruction", "")
        input_text = data_point.get("input", "")

        prompt: str = self.generate_prompt(
            instruction=instruction,
            input=input_text,
            output=data_point.get("output", ""),
        )
        tokenized_prompt: Dict = self.tokenize(prompt)

        if not self.train_on_inputs:
            user_prompt: str = self.generate_prompt(instruction=instruction, input=input_text)
            # No EOS here: only the length of the user part is needed.
            tokenized_user_prompt: Dict = self.tokenize(user_prompt, add_eos_token=False)
            user_prompt_len = len(tokenized_user_prompt["input_ids"])
            # mask out the inputs
            tokenized_prompt["labels"] = [
                -100 if i < user_prompt_len else label
                for i, label in enumerate(tokenized_prompt["labels"])
            ]
        return tokenized_prompt

    def generate_prompt(
        self,
        instruction: Optional[str] = None,
        input: Optional[str] = None,
        output: Optional[str] = None,
    ) -> str:
        """
        Generate a prompt for the given instruction, input and output using the
        loaded prompt template.

        The result is the concatenation of the template's "primer" text and,
        for each of "instruction", "input" and "output", the template text for
        that section followed by the supplied value (or an empty string when
        the value is None or empty).

        Args:
            instruction (Optional[str]):
                An optional string representing the instruction to be included in the prompt.
            input (Optional[str]):
                An optional string representing the input to be included in the prompt.
            output (Optional[str]):
                An optional string representing the output to be included in the prompt.

        Returns:
            str: The prompt string created using the loaded prompt template.

        Raises:
            ValueError: If none of `instruction`, `input`, and `output` is defined.

        Example (the exact text depends on the template file that was loaded;
        the output below assumes an Alpaca-style template)::

            data_handler = DataHandler(tokenizer, "prompt_templates/medalpaca.json")
            prompt = data_handler.generate_prompt(
                instruction="Provide a short answer to this medical question.",
                input="What to expect if I have Aortic coarctation (Outlook/Prognosis)?",
                output=(
                    "The prognosis of aortic coarctation depends on whether balloon "
                    "angioplasty and stenting or the surgery has been done or not."
                ),
            )
            print(prompt)

            Below is an instruction that describes a task, paired with an input that
            provides further context. Write a response that appropriately completes
            the request.

            ### Instruction:
            Provide a short answer to this medical question.

            ### Input:
            What to expect if I have Aortic coarctation (Outlook/Prognosis)?

            ### Response:
            The prognosis of aortic coarctation depends on whether balloon angioplasty
            and stenting or the surgery has been done or not.
        """
        if not any([instruction, input, output]):
            raise ValueError("At least one of `instruction`, `input`, `output` should be defined")

        template = self.prompt_template
        prompt = (
            f'{template["primer"]}'
            f'{template["instruction"]}{instruction or ""}'
            f'{template["input"]}{input or ""}'
            f'{template["output"]}{output or ""}'
        )
        return prompt

    def resolve_output(self, output: str):
        """Placeholder that is not implemented yet; does nothing and returns None."""
        pass