---
license: mit
datasets:
- AdamLucek/apple-environmental-report-QA-retrieval
base_model: sentence-transformers/all-MiniLM-L6-v2
pipeline_tag: feature-extraction
library_name: peft
---

# all-MiniLM-L6-v2-query-only-linear-adapter-AppleQA

Query-only linear adapter for [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2), trained on the [AdamLucek/apple-environmental-report-QA-retrieval](https://huggingface.co/datasets/AdamLucek/apple-environmental-report-QA-retrieval) dataset.

6 adapters trained at 10, 20, 30, and 40 epochs with:

- Triplet Margin Loss (margin = 1.0, p = 2, i.e. Euclidean distance)
- AdamW optimizer
- Random negative sampling from irrelevant documents
- Learning rate: 0.003
- Batch size: 32
- Max gradient norm: 1.0
- Warmup steps: 100

The training script and model creation code are available in the [GitHub repo](https://github.com/ALucek/linear-adapter-embedding).

# Assessment

**Baseline Hit Rate @10**: 61.860%

**Baseline Mean Reciprocal Rank @10**: 0.31108 (Average Rank 3.2, computed as 1/MRR)

*Best-performing checkpoint: 30 epochs*

**Hit Rate @10**: 66.628%

**Mean Reciprocal Rank @10**: 0.33119 (Average Rank 3.0, computed as 1/MRR)

This is a **7.7% relative improvement** in hit rate and a **6.5% relative improvement** in mean reciprocal rank over the base embedding model.
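The metrics above can be reproduced in outline with a small dependency-free sketch. It assumes each query has exactly one relevant document and that a query scoring no hit in the top 10 contributes a reciprocal rank of 0; the card does not show the evaluation code, so the repo's actual implementation may differ. The toy document IDs are hypothetical; the last two lines recompute the relative improvements from the reported numbers.

```python
def hit_rate_at_k(ranked_ids, relevant_id, k=10):
    """1.0 if the relevant document appears in the top-k results, else 0.0."""
    return 1.0 if relevant_id in ranked_ids[:k] else 0.0


def reciprocal_rank_at_k(ranked_ids, relevant_id, k=10):
    """1/rank of the relevant document within the top-k, or 0.0 if it is missing."""
    for rank, doc_id in enumerate(ranked_ids[:k], start=1):
        if doc_id == relevant_id:
            return 1.0 / rank
    return 0.0


def evaluate(results, k=10):
    """Average Hit Rate@k and MRR@k over (ranked_ids, relevant_id) pairs."""
    n = len(results)
    hit = sum(hit_rate_at_k(r, rel, k) for r, rel in results) / n
    mrr = sum(reciprocal_rank_at_k(r, rel, k) for r, rel in results) / n
    return hit, mrr


def relative_improvement(baseline, adapted):
    """Percentage change of the adapted score relative to the baseline."""
    return (adapted - baseline) / baseline * 100.0


# Toy example with three queries (hypothetical document IDs)
results = [
    (["d3", "d1", "d7"], "d1"),  # relevant doc at rank 2 -> RR = 0.5
    (["d2", "d5", "d4"], "d2"),  # relevant doc at rank 1 -> RR = 1.0
    (["d9", "d8", "d6"], "d0"),  # relevant doc missing   -> RR = 0.0
]
hit, mrr = evaluate(results)
print(round(hit, 4), round(mrr, 4))  # 0.6667 0.5

# Relative improvements recomputed from the reported scores
print(round(relative_improvement(61.860, 66.628), 1))    # 7.7
print(round(relative_improvement(0.31108, 0.33119), 1))  # 6.5
```

Note that 1/MRR is not a true mean of ranks (it is the reciprocal of the mean reciprocal rank), which is why it is best read as a rough "typical rank" figure.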
# Usage

```python
import torch
from torch import nn
from sentence_transformers import SentenceTransformer


class LinearAdapter(nn.Module):
    """Single linear layer applied to query embeddings (same input and output size)."""

    def __init__(self, input_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        return self.linear(x)


# Load the base model
base_model = SentenceTransformer('all-MiniLM-L6-v2')

# Load the 30-epoch adapter weights onto the CPU and switch to inference mode
adapter = LinearAdapter(base_model.get_sentence_embedding_dimension())
adapter.load_state_dict(torch.load('adapters/linear_adapter_30epochs.pth', map_location='cpu'))
adapter.eval()


# Encode a query with the base model, then transform it with the adapter
def encode_query(query, base_model, adapter):
    device = next(adapter.parameters()).device
    query_emb = base_model.encode(query, convert_to_tensor=True).to(device)
    with torch.no_grad():  # inference only, no gradients needed
        adapted_query_emb = adapter(query_emb)
    return adapted_query_emb.cpu().numpy()


emb = encode_query("Hello", base_model, adapter)
print(emb[:5])
```

**output**

```
[-0.13122843  0.02912715  0.07466945  0.09387457  0.13010463]
```
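Because the adapter is query-only, documents are embedded with the unmodified base model and only the query passes through the adapter before ranking. The dependency-free sketch below illustrates that retrieval step with toy 3-dimensional vectors and a hand-picked weight matrix; both are hypothetical stand-ins (the real adapter is a 384×384 `nn.Linear` loaded from the `.pth` file), and ranking by cosine similarity is an assumption, since the card does not state the similarity function used.

```python
import math


def linear(x, weight, bias):
    """Compute W x + b for one vector, as nn.Linear does (weight has shape out x in)."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def cosine(a, b):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def rank_documents(query_emb, doc_embs, weight, bias, k=10):
    """Return indices of the top-k documents for an adapted query."""
    adapted = linear(query_emb, weight, bias)  # only the query goes through the adapter
    scores = [(cosine(adapted, doc), i) for i, doc in enumerate(doc_embs)]
    scores.sort(reverse=True)  # highest similarity first
    return [i for _, i in scores[:k]]


# Toy 3-d embeddings standing in for base-model outputs (hypothetical values)
query = [1.0, 0.0, 0.0]
docs = [
    [0.0, 1.0, 0.0],  # doc 0
    [1.0, 0.0, 0.0],  # doc 1
    [0.7, 0.7, 0.0],  # doc 2
]
zero_bias = [0.0, 0.0, 0.0]
identity = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
# Hypothetical "trained" weights that move the query toward doc 0
adapter_weight = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]

print(rank_documents(query, docs, identity, zero_bias))        # [1, 2, 0]
print(rank_documents(query, docs, adapter_weight, zero_bias))  # [0, 2, 1]
```

With the identity weights the ranking matches the base model; the learned transform changes only the query side, so the document index never needs to be re-embedded when a new adapter checkpoint is swapped in.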