import torch
import torch.nn.functional as F


def generate(model, tokenizer, prompt: str, n_tokens_to_gen: int = 200, sample: bool = True, top_k: int = 40):
    model.eval()

    # Encode the prompt and move it to the GPU the model lives on.
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to("cuda")

    for _ in range(n_tokens_to_gen):
        with torch.no_grad():
            # Feed the full sequence generated so far and keep only the
            # logits for the final position: shape (batch, vocab_size).
            # (The original called `mamba_model` here; the model is the
            # `model` argument passed in.)
            next_token_logits = model(input_ids)[:, -1]

        probs = F.softmax(next_token_logits, dim=-1)

        if top_k is not None:
            # Top-k filtering: zero out every token whose probability is
            # below the k-th largest, then renormalize to sum to 1.
            (values, _) = torch.topk(probs, k=top_k)
            probs[probs < values[:, -1, None]] = 0
            probs = probs / probs.sum(dim=1, keepdim=True)

        if sample:
            # Draw the next token from the (filtered) distribution.
            next_indices = torch.multinomial(probs, num_samples=1)
        else:
            # Greedy decoding: always take the most probable token.
            next_indices = torch.argmax(probs, dim=-1)[:, None]

        # Append the new token and feed the extended sequence back in.
        input_ids = torch.cat([input_ids, next_indices], dim=1)

    output_completions = [tokenizer.decode(output.tolist()) for output in input_ids][0]

    return output_completions
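

# Example usage (a minimal sketch, not part of the original): the checkpoint
# and tokenizer names below are assumptions. The official Mamba checkpoints
# pair with the EleutherAI/gpt-neox-20b tokenizer, but any HuggingFace-style
# tokenizer exposing `__call__(..., return_tensors='pt')` and `decode` works.
#
# from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
# model = ...  # a Mamba language model already loaded onto "cuda"
# print(generate(model, tokenizer, "The capital of France is", n_tokens_to_gen=20))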