from __future__ import annotations

import torch
import torch.nn.functional as F
from transformers import LlamaModel, LlamaConfig, PreTrainedTokenizer
from transformers.modeling_attn_mask_utils import AttentionMaskConverter


class DramaModel(LlamaModel):
    """
    DramaModel is a modified version of the LlamaModel that supports
    bi-directional attention and provides query and document encoding
    functionalities.
    """

    def __init__(self, config: LlamaConfig):
        """
        Initializes the DramaModel by disabling causal masking in
        self-attention layers.
        """
        super().__init__(config)
        # enable bi-directional attention: every token may attend to all others
        for layer in self.layers:
            layer.self_attn.is_causal = False
        # query prefix
        self.query_prefix = "Query: "
        self.max_seq_len = 8192
        self.hidden_size = config.hidden_size

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_seen_tokens=None,
        output_attentions=False,
    ):
        """
        Updates the causal mask for attention computations.
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None
        if attention_mask is None or attention_mask.dim() == 4:
            return attention_mask
        # expand the 2D padding mask to 4D without applying a causal triangle
        return AttentionMaskConverter._expand_mask(
            mask=attention_mask,
            dtype=input_tensor.dtype,
        )

    def _average_pool(
        self, last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Computes the average pooled representation of the last hidden states.
        """
        last_hidden = last_hidden_states.masked_fill(
            ~attention_mask[..., None].bool(), 0.0
        )
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    def _tokenize(
        self,
        tokenizer: PreTrainedTokenizer,
        texts: list[str],
        max_seq_len: int = None,
    ):
        """
        Tokenizes input text sequences with optional sequence length
        restriction.
        """
        if max_seq_len is None:
            max_seq_len = self.max_seq_len
        tokenized = tokenizer(
            texts,
            padding=False,
            truncation=True,
            # reserve one position for the EOS token appended below
            max_length=max_seq_len - 1,
            return_attention_mask=False,
            return_token_type_ids=False,
            add_special_tokens=True,
        )
        tokenized['input_ids'] = [
            t + [tokenizer.eos_token_id] for t in tokenized['input_ids']
        ]
        tokenized = tokenizer.pad(
            tokenized,
            padding=True,
            return_attention_mask=True,
            return_tensors='pt',
        ).to(self.device)
        return tokenized

    def forward(self, input_ids, attention_mask, dim, *args, **kwargs):
        """
        Forward pass through the model.

        Args:
            input_ids (torch.Tensor): Input token IDs.
            attention_mask (torch.Tensor): Attention mask tensor.
            dim (int): Dimensionality for output embeddings.

        Returns:
            torch.Tensor: Normalized output embeddings.
        """
        outputs = super().forward(
            input_ids,
            attention_mask,
            *args,
            **kwargs
        )
        # keep only the first `dim` hidden dimensions (full size when dim is None),
        # then mean-pool over non-padding tokens
        embeddings = self._average_pool(
            outputs.last_hidden_state[:, :, :dim], attention_mask
        )
        # normalize embeddings
        embeddings = F.normalize(embeddings, p=2, dim=1)
        return embeddings

    def encode_queries(
        self,
        tokenizer: PreTrainedTokenizer,
        queries: list[str],
        max_seq_len: int = None,
        dim: int = None,
    ):
        """
        Encodes a list of queries into embeddings.

        Args:
            tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
            queries (list[str]): List of query texts.
            max_seq_len (int, optional): Maximum sequence length.
            dim (int, optional): Dimensionality for output embeddings.

        Returns:
            torch.Tensor: Encoded query embeddings in shape (num_queries, dim).
""" if not queries: raise ValueError("queries must not be empty.") if not isinstance(queries, list) or not all(isinstance(q, str) for q in queries): raise ValueError("queries must be a list of strings.") if tokenizer is None: raise ValueError("tokenizer must not be None.") if dim is not None and (dim < 1 or dim > self.hidden_size): raise ValueError(f"dim must be in range [1, {self.hidden_size}].") queries = [self.query_prefix + query for query in queries] tokenized_queries = self._tokenize(tokenizer, queries, max_seq_len) embeddings = self(**tokenized_queries, dim=dim) return embeddings def encode_documents( self, tokenizer: PreTrainedTokenizer, documents: list[str], max_seq_len: int = None, dim: int = None, ): """ Encodes a list of documents into embeddings. Args: tokenizer (PreTrainedTokenizer): Tokenizer for text processing. documents (list[str]): List of document texts. max_seq_len (int, optional): Maximum sequence length. dim (int, optional): Dimensionality for output embeddings. Returns: torch.Tensor: Encoded document embeddings in shape (num_documents, dim). """ if not documents: raise ValueError("documents must not be empty.") if not isinstance(documents, list) or not all(isinstance(d, str) for d in documents): raise ValueError("documents must be a list of strings.") if tokenizer is None: raise ValueError("tokenizer must not be None.") if dim is not None and (dim < 1 or dim > self.hidden_size): raise ValueError(f"dim must be in range [1, {self.hidden_size}].") tokenized_documents = self._tokenize(tokenizer, documents, max_seq_len) embeddings = self(**tokenized_documents, dim=dim) return embeddings