---
pipeline_tag: sentence-similarity
tags:
- sentence-transformers
- feature-extraction
- sentence-similarity
license: mit
---

For more details, please refer to our GitHub repository: https://github.com/FlagOpen/FlagEmbedding

# LLaRA ([paper](https://arxiv.org/pdf/2312.15503))

In this project, we introduce LLaRA, which is built on two pretext tasks:

- EBAE: Embedding-Based Auto-Encoding.
- EBAR: Embedding-Based Auto-Regression.

## Usage

```python
import torch
from transformers import AutoModel, AutoTokenizer


def _build_inputs(texts, tokenizer, suffix, max_length=512):
    """Wrap each text as '"<text><suffix>' and pad the batch into tensors.

    Only the text itself is truncated to ``max_length`` tokens; the prompt
    tokens are added afterwards, so a sequence may be slightly longer than
    ``max_length``.
    """
    # The prefix keeps the tokenizer's special tokens (e.g. BOS). The suffix
    # drops its first token, which is the BOS token for LLaMA-style tokenizers.
    prefix_ids = tokenizer('"', return_tensors=None)['input_ids']
    suffix_ids = tokenizer(suffix, return_tensors=None)['input_ids'][1:]

    batch = []
    for text in texts:
        inputs = tokenizer(text, return_tensors=None, max_length=max_length,
                           truncation=True, add_special_tokens=False)
        inputs['input_ids'] = prefix_ids + inputs['input_ids'] + suffix_ids
        inputs['attention_mask'] = [1] * len(inputs['input_ids'])
        batch.append(inputs)
    return tokenizer.pad(
        batch,
        padding=True,
        max_length=max_length,
        pad_to_multiple_of=8,
        return_tensors='pt',
    )


def get_query_inputs(queries, tokenizer, max_length=512):
    """Build model inputs for queries (prompt: predict the following passage)."""
    return _build_inputs(queries, tokenizer,
                         '", predict the following passage within eight words: ',
                         max_length=max_length)


def get_passage_inputs(passages, tokenizer, max_length=512):
    """Build model inputs for passages (prompt: summarize the above passage)."""
    return _build_inputs(passages, tokenizer,
                         '", summarize the above passage within eight words: ',
                         max_length=max_length)


def encode(model, inputs):
    """Mean-pool the last 8 positions of the final hidden layer, then L2-normalize."""
    outputs = model(**inputs, return_dict=True, output_hidden_states=True)
    embeddings = outputs.hidden_states[-1][:, -8:, :]
    embeddings = torch.mean(embeddings, dim=1)
    return torch.nn.functional.normalize(embeddings, dim=-1)


# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained('BAAI/LLARA-pretrain')
# Pooling reads the last 8 positions, so padding must be added on the left;
# otherwise pad tokens would end up in the pooled window.
tokenizer.padding_side = 'left'
model = AutoModel.from_pretrained('BAAI/LLARA-pretrain')

# Define query and passage inputs
query = "What is llama?"
passage = "The llama is a domesticated South American camelid, widely used as a meat and pack animal by Andean cultures since the pre-Columbian era."
query_input = get_query_inputs([query], tokenizer)
passage_input = get_passage_inputs([passage], tokenizer)

with torch.no_grad():
    # compute query and passage embeddings
    query_embedding = encode(model, query_input)
    passage_embeddings = encode(model, passage_input)

    # compute similarity score (cosine similarity, since both sides are normalized)
    score = query_embedding @ passage_embeddings.T
print(score)
```

## Acknowledgement

Thanks to the authors of open-source datasets, including MS MARCO, BEIR, etc. Thanks to open-source libraries such as [Pyserini](https://github.com/castorini/pyserini).

## Citation

If you find this repository useful, please consider giving it a star :star: and citing our paper:

```bibtex
@misc{li2023making,
      title={Making Large Language Models A Better Foundation For Dense Retrieval},
      author={Chaofan Li and Zheng Liu and Shitao Xiao and Yingxia Shao},
      year={2023},
      eprint={2312.15503},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```