import re
import warnings
from typing import Optional

import torch
from accelerate.utils import extract_model_from_parallel
from transformers import StoppingCriteria, StoppingCriteriaList

from ..import_utils import is_rich_available


if is_rich_available():
    from rich import print
    from rich.text import Text


class StringStoppingCriteria(StoppingCriteria):
    """Custom `StoppingCriteria` which checks if all generations in the batch are completed."""

    def __init__(self, stop_strings, tokenizer):
        self.stop_strings = stop_strings
        self.tokenizer = tokenizer
        self.first_call = True

    def __call__(self, input_ids, scores, **kwargs):
        """Returns true if all generated sequences contain any of the stop strings."""
        if self.first_call:
            # On the first call one token has already been generated, so start
            # counting at 1 and remember where the generated text begins.
            self.generated_tokens = [1 for _ in range(input_ids.shape[0])]
            self.start_length = input_ids.shape[-1] - 1
            self.first_call = False
        decoded_generations = self.tokenizer.batch_decode(input_ids[:, self.start_length :])
        done = []

        for i, decoded_generation in enumerate(decoded_generations):
            sequence_complete = any(stop_string in decoded_generation for stop_string in self.stop_strings)
            done.append(sequence_complete)
            if not sequence_complete:
                self.generated_tokens[i] += 1

        if all(done):
            # Reset so the same instance can be reused for the next batch.
            self.first_call = True

        return all(done)
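

# Illustrative usage sketch (`model`, `tokenizer`, and `input_ids` are assumed to be
# a Hugging Face causal LM, its tokenizer, and a batch of prompt token ids; they are
# not defined in this module):
#
#     criteria = StoppingCriteriaList([StringStoppingCriteria(["<call>", "<submit>"], tokenizer)])
#     generations = model.generate(input_ids, stopping_criteria=criteria)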


class TextHistory:
    """The TextHistory class keeps track of the history of an interaction between the language model and the environment."""

    def __init__(self, text, tokens, system=True):
        """
        Initialize TextHistory.

        Args:
            text (`str`): The text of the first segment.
            tokens (`torch.LongTensor`): The tokens of the first segment.
            system (`bool`, *optional*): Whether the first segment is a system or user segment.
        """
        self.system_spans = []
        self.text_spans = []
        self.token_spans = []
        self.token_masks = torch.tensor([], dtype=torch.long).to(tokens.device)
        self.text = ""
        self.tokens = torch.tensor([], dtype=torch.long).to(tokens.device)
        self.completed = False
        self.truncated = False
        self.reward = 0.0

        self.prompt_color = "black on grey85"
        self.system_color = "black on cyan3"
        self.model_color = "black on deep_sky_blue1"
        self.reward_color = "black on plum1"

        self.append_segment(text, tokens, system=system)

    def append_segment(self, text, tokens, system=True):
        """
        Append a new segment to the history.

        Args:
            text (`str`): The text of the new segment.
            tokens (`torch.LongTensor`): The tokens of the new segment.
            system (`bool`, *optional*): Whether the new segment is a system or user segment.
        """
        if len(text) == 0 or len(tokens) == 0:
            raise ValueError("Can't append empty text or token list to history.")

        original_text_length = len(self.text)

        self.text += text
        self.text_spans.append((original_text_length, len(self.text)))
        self.system_spans.append(system)

        original_token_length = len(self.tokens)

        self.tokens = torch.cat((self.tokens, tokens))
        # System/user tokens are masked out (0); model-generated tokens are marked (1)
        # so downstream training can attribute loss only to the model's own tokens.
        if system:
            self.token_masks = torch.cat((self.token_masks, torch.zeros_like(tokens)))
        else:
            self.token_masks = torch.cat((self.token_masks, torch.ones_like(tokens)))
        self.token_spans.append((original_token_length, len(self.tokens)))
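
    # Illustrative sketch of the bookkeeping (the token ids here are hypothetical):
    #
    #     history = TextHistory("Q: 2+2?", torch.tensor([1, 2, 3]), system=True)
    #     history.append_segment("A: 4", torch.tensor([4, 5]), system=False)
    #     history.token_masks  # tensor([0, 0, 0, 1, 1]) -- only model tokens unmasked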

    def complete(self, truncated=False):
        """
        Mark the history as completed.

        Args:
            truncated (`bool`, *optional*): Whether the episode ended by truncation rather than naturally.
        """
        self.completed = True
        self.truncated = truncated

    @property
    def last_text_segment(self):
        """
        Get the last text segment.
        """
        start, end = self.text_spans[-1]
        return self.text[start:end]

    def split_query_response_tokens(self):
        """
        Split the tokens into query and response tokens.
        """
        # The first segment is always the query (prompt); everything after it is response.
        split_index = self.token_spans[0][1]
        query = self.tokens[:split_index]
        response = self.tokens[split_index:]
        mask = self.token_masks[split_index:]

        return query, response, mask
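
    # Continuing the sketch above, `split_query_response_tokens` would return
    # (tensor([1, 2, 3]), tensor([4, 5]), tensor([1, 1])) for that history.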

    def show_text(self, show_legend=False):
        """
        Print the text history.
        """
        if not is_rich_available():
            warnings.warn("install rich to display text")
            return

        text = Text(self.text)
        text.stylize(self.prompt_color, self.text_spans[0][0], self.text_spans[1][0])
        for i, (start, end) in enumerate(self.text_spans[1:]):
            if self.system_spans[i + 1]:
                text.stylize(self.system_color, start, end)
            else:
                text.stylize(self.model_color, start, end)

        text.append(f"\n\nReward: {self.reward}", style=self.reward_color)
        print(text)

        if show_legend:
            self.show_colour_legend()

    def show_tokens(self, tokenizer, show_legend=False):
        """
        Print the history tokens.
        """
        if not is_rich_available():
            warnings.warn("install rich to display tokens")
            return

        text = Text()
        prompt_end = self.token_spans[0][1]
        for i, (token, mask) in enumerate(zip(self.tokens, self.token_masks)):
            if i < prompt_end:
                style = self.prompt_color
            elif mask == 0:
                style = self.system_color
            else:
                style = self.model_color
            text.append(tokenizer.convert_ids_to_tokens(token.item()), style=style)
            text.append(" ")
        text.append(f"\n\nReward: {self.reward}", style=self.reward_color)
        print(text)
        if show_legend:
            self.show_colour_legend()

    def show_colour_legend(self):
        """
        Print the colour legend.
        """
        if not is_rich_available():
            warnings.warn("install rich to display colour legend")
            return
        text = Text("\n\n(Colour Legend: ")
        text.append("Prompt", style=self.prompt_color)
        text.append("|")
        text.append("System", style=self.system_color)
        text.append("|")
        text.append("Model", style=self.model_color)
        text.append("|")
        text.append("Reward", style=self.reward_color)
        text.append(")")
        print(text)


class TextEnvironment:
    """
    The TextEnvironment enables interaction of a LLM with an environment using tools.
    """

    def __init__(
        self,
        model=None,
        tokenizer=None,
        tools=None,
        reward_fn=None,
        prompt=None,
        max_turns=4,
        max_tool_response=100,
        max_length=None,
        generation_kwargs=None,
    ):
        """
        Initialize TextEnvironment.

        Args:
            model (`PreTrainedModelWrapper`): The model to use for generation.
            tokenizer (`transformers.PreTrainedTokenizer`): The tokenizer to use for generation.
            tools (list): A list of tools to use for interaction.
            reward_fn (function): A function that takes a string and returns a reward.
            prompt (str): The base prompt to use for generation. Is prepended to the tasks.
            max_turns (Optional[int]): The maximum number of turns to allow.
            max_tool_response (Optional[int]): The maximum number of characters to allow in a tool response.
            max_length (Optional[int]): The maximum number of tokens to allow in an episode.
            generation_kwargs (Optional[dict]): A dictionary of keyword arguments to pass to the model's generate method.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.prompt = prompt
        if isinstance(tools, dict):
            self.tools = tools
        else:
            self.tools = {tool.__class__.__name__: tool for tool in tools}
        self.reward_fn = reward_fn
        self.max_length = max_length
        self.request_token = "<request>"
        self.call_token = "<call>"
        self.response_token = "<response>"
        self.submit_token = "<submit>"
        self.max_turns = max_turns
        self.max_tool_response = max_tool_response

        if generation_kwargs is None:
            self.generation_kwargs = dict()
        else:
            self.generation_kwargs = generation_kwargs

        self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder")
        self.current_device = extract_model_from_parallel(self.model).pretrained_model.device
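
    # Illustrative setup sketch (the tool and reward function here are hypothetical
    # stand-ins, not part of this module):
    #
    #     def exact_match_reward(responses):
    #         return [1.0 if "4" in r else 0.0 for r in responses]
    #
    #     env = TextEnvironment(
    #         model=model,                      # a PreTrainedModelWrapper
    #         tokenizer=tokenizer,
    #         tools=[Calculator()],             # hypothetical tool: __call__(query) -> str
    #         reward_fn=exact_match_reward,
    #         prompt="Answer using tools:\n",
    #         generation_kwargs={"max_new_tokens": 32},
    #     )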

    def run(self, queries, **rewards_kwargs):
        """
        Run the environment on a list of queries.

        Args:
            queries (list[str]): A list of queries to run the model in the environment on.
            rewards_kwargs: Additional keyword arguments passed through to the reward function.
        """
        turns = 0

        queries = [self.prompt + task for task in queries]
        queries_tokens = [
            self.tokenizer(query, return_tensors="pt").input_ids[0].to(self.model.pretrained_model.device)
            for query in queries
        ]

        histories = [TextHistory(q, qt, system=True) for q, qt in zip(queries, queries_tokens)]

        while any(not history.completed for history in histories) and turns < self.max_turns:
            histories = self.generate(histories)
            histories = self.tasks_end_check(histories)
            for i in range(len(histories)):
                histories[i] = self.step(histories[i])
            histories = self.tasks_end_check(histories, model_turn=False)
            turns += 1
        self.compute_reward(histories, **rewards_kwargs)

        # convert a list of (query, response, mask) tuples into three parallel lists
        queries, responses, masks = map(list, zip(*[history.split_query_response_tokens() for history in histories]))

        rewards = [history.reward for history in histories]
        return queries, responses, masks, rewards, histories
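
    # Continuing the sketch above (query and outputs are illustrative):
    #
    #     queries, responses, masks, rewards, histories = env.run(["Q: 2+2?"])
    #     histories[0].show_text()  # pretty-print the episode if rich is installed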

    def step(self, history):
        """
        Step the environment forward one turn.

        Args:
            history (`TextHistory`): The history to step forward.
        """
        truncated, ended = self.task_end_check(history)
        if ended:
            history.complete(truncated=truncated)
        if history.completed:
            return history

        tool, query = self.parse_tool_call(history.last_text_segment)
        if tool is None or query is None:
            response = f"Unknown tool call: {history.last_text_segment}"
        elif tool not in self.tools:
            response = f"Unknown tool {tool}."
        else:
            try:
                response = self.tools[tool](query)
            except Exception as error:
                response = f"Tool error: {str(error)}"

        if len(response) > self.max_tool_response:
            # Truncate overly long tool responses, keeping room for the ellipsis.
            response = response[: (self.max_tool_response - 3)] + "..."

        history.append_segment(
            response + self.response_token,
            self.tokenizer(response + self.response_token, return_tensors="pt")
            .input_ids[0]
            .to(self.model.pretrained_model.device),
            system=True,
        )

        return history

    def parse_tool_call(self, text):
        """
        Parse request string. Expected format: <request><tool_name>query<call>
        """
        result = re.search(f"(?<={self.request_token}).*?(?={self.call_token})", text, re.DOTALL)

        # if we can't find a <request>/<call> span we return None
        if result is None:
            return None, None
        else:
            extracted_text = result.group()

        result = re.search(r"<(.*?)>", extracted_text)

        # if we can't find a tool name we return None
        if result is None:
            return None, None
        else:
            tool = result.group(1)

        # everything after the tool name is the query
        query = ">".join(extracted_text.split(">")[1:])

        return tool, query
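
    # Example of the expected request format (tool name and query are illustrative):
    #
    #     env.parse_tool_call("<request><Calculator>2 + 2<call>")
    #     # -> ("Calculator", "2 + 2")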

    def compute_reward(self, histories, **reward_kwargs):
        """
        Compute the reward for a list of histories.
        """
        rewards = self.reward_fn([history.last_text_segment for history in histories], **reward_kwargs)
        for history, reward in zip(histories, rewards):
            history.reward = reward
        return histories

    def generate(self, histories):
        """
        Generate responses for a list of histories.
        """
        # Only generate for histories that haven't completed yet.
        active_histories = [i for i, history in enumerate(histories) if not history.completed]

        query_tensors = [histories[i].tokens for i in active_histories]
        response_tensors = self._generate_batched(query_tensors)
        response_texts = self.tokenizer.batch_decode(response_tensors)

        for i, response_text, response_tensor in zip(active_histories, response_texts, response_tensors):
            histories[i].append_segment(response_text, response_tensor, system=False)

        return histories

    def tasks_end_check(self, histories, model_turn=True):
        """
        Check if the current generation sequences have finished.
        """
        for history in histories:
            if not history.completed:
                truncated, ended = self.task_end_check(history, model_turn=model_turn)
                if ended:
                    history.complete(truncated=truncated)
        return histories

    def task_end_check(self, history, model_turn=True):
        """
        Check if the current generation sequence has finished.
        """
        truncated = False
        ended = False
        if history.completed:
            return truncated, ended
        # `input_ids` is a flat list of ids for a single un-batched encoding, so its
        # length (not `input_ids[0]`) is the episode's token count.
        if self.max_length is not None and len(self.tokenizer(history.text).input_ids) > self.max_length:
            truncated = True
            ended = True
        elif self.tokenizer.eos_token in history.text:
            ended = True
        elif model_turn and not (
            (self.request_token in history.last_text_segment and self.call_token in history.last_text_segment)
            or self.submit_token in history.last_text_segment
        ):
            # The model neither requested a tool nor submitted: treat the episode as finished.
            ended = True
        elif self.submit_token in history.last_text_segment:
            ended = True
        return truncated, ended

    def _generate_batched(
        self,
        query_tensors,
        batch_size: int = 16,
        pad_to_multiple_of: Optional[int] = None,
    ):
        """
        Generate responses for a list of query tensors.

        Args:
            query_tensors (list[torch.Tensor]): A list of query tensors to generate responses for.
            batch_size (int): The batch size to use for generation.
            pad_to_multiple_of (Optional[int]): The padding length to use for generation.
        """
        outputs = []
        padding_side_default = self.tokenizer.padding_side
        if not self.is_encoder_decoder:
            # Decoder-only models need left padding so generation continues from the prompt.
            self.tokenizer.padding_side = "left"

        batch_size = min(len(query_tensors), batch_size)

        for i in range(0, len(query_tensors), batch_size):
            end_index = min(len(query_tensors), i + batch_size)

            batch = query_tensors[i:end_index]
            batch_mask = [torch.ones_like(element) for element in batch]
            inputs = {"input_ids": batch, "attention_mask": batch_mask}

            padded_inputs = self.tokenizer.pad(
                inputs,
                padding=True,
                max_length=None,
                pad_to_multiple_of=pad_to_multiple_of,
                return_tensors="pt",
            ).to(self.current_device)

            stopping_criteria = StringStoppingCriteria([self.call_token, self.submit_token], self.tokenizer)

            self.generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stopping_criteria])

            generations = extract_model_from_parallel(self.model).generate(**padded_inputs, **self.generation_kwargs)

            for generation, mask, generated_tokens in zip(
                generations, padded_inputs["attention_mask"], stopping_criteria.generated_tokens
            ):
                if not self.is_encoder_decoder:
                    output = generation[(1 - mask).sum() :]  # remove left padding
                    output = output[mask.sum() :]  # remove the prompt
                else:
                    output = generation

                # remove any chunk generated after the stopping string was reached
                outputs.append(output[:generated_tokens])
        self.tokenizer.padding_side = padding_side_default
        return outputs
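

# End-to-end illustrative sketch (everything here is hypothetical -- `model`,
# `tokenizer`, `Calculator`, and `reward_fn` must be provided by the surrounding
# training code):
#
#     env = TextEnvironment(model, tokenizer, [Calculator()], reward_fn, prompt)
#     queries, responses, masks, rewards, histories = env.run(["What is 13 + 29?"])
#     # queries/responses/masks/rewards are shaped to feed a trainer step (e.g. a
#     # trl PPOTrainer); histories keep the full text for history.show_text().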