Korean Reranker Training on Amazon SageMaker
This is a fine-tuning guide for developing a Korean reranker.
ko-reranker is a model fine-tuned on Korean data on top of BAAI/bge-reranker-large.
For more details, see korean-reranker-git and the AWS blog post "Improving Retrieval-Augmented Generation (RAG) performance with a Korean Reranker".
0. Features
- Unlike an embedding model, a reranker takes a question and a document as input and directly outputs a similarity score instead of an embedding.
- Feeding a question and a passage to the reranker yields a relevance score.
- The reranker is optimized with cross-entropy loss, so the relevance score is not bounded to any particular range.
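Because the raw output is an unbounded logit, it can be mapped into a bounded range when a normalized score is needed. Below is a minimal sketch (the logit values are made up for illustration) showing a per-pair sigmoid and a softmax over candidates; the latter is what the exp_normalize helper in the usage section does.

import numpy as np

# hypothetical raw reranker logits for three (query, passage) pairs
logits = np.array([3.2, -1.7, 0.4])

# option 1: a sigmoid maps each logit independently into (0, 1)
sigmoid_scores = 1.0 / (1.0 + np.exp(-logits))

# option 2: a softmax turns the logits into relative weights over the candidates
softmax_scores = np.exp(logits - logits.max())
softmax_scores = softmax_scores / softmax_scores.sum()

print(sigmoid_scores, softmax_scores)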
1. Usage
- using Transformers
import numpy as np
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

def exp_normalize(x):
    # numerically stable softmax over the raw relevance logits
    b = x.max()
    y = np.exp(x - b)
    return y / y.sum()

model_path = "Dongjin-kr/ko-reranker"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval()

pairs = [["나는 너를 싫어해", "나는 너를 사랑해"],                        # "I hate you" / "I love you"
         ["나는 너를 좋아해", "너에 대한 나의 감정은 사랑 일 수도 있어"]]  # "I like you" / "My feelings for you may be love"

with torch.no_grad():
    inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
    scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
    scores = exp_normalize(scores.numpy())
    print(f'first: {scores[0]}, second: {scores[1]}')
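The same loaded model can also rescore a list of retrieved passages for one query and reorder them before they are passed to an LLM. A minimal sketch, where the query and passages below are hypothetical placeholders:

query = "한국의 수도는 어디인가?"           # "Where is the capital of Korea?" (hypothetical query)
passages = [
    "서울은 대한민국의 수도이다.",           # "Seoul is the capital of South Korea."
    "도쿄는 일본의 수도이다.",               # "Tokyo is the capital of Japan."
    "부산은 대한민국의 항구 도시이다.",       # "Busan is a port city in South Korea."
]

pairs = [[query, passage] for passage in passages]
with torch.no_grad():
    inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
    scores = model(**inputs, return_dict=True).logits.view(-1, ).float().numpy()

# sort passages by descending relevance score
reranked = sorted(zip(passages, scores), key=lambda x: x[1], reverse=True)
for passage, score in reranked:
    print(f'{score:.3f}\t{passage}')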
- using SageMaker
import json
import sagemaker
import boto3
from sagemaker.huggingface import HuggingFaceModel

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

# Hub Model configuration. https://huggingface.co/models
hub = {
    'HF_MODEL_ID': 'Dongjin-kr/ko-reranker',
    'HF_TASK': 'text-classification'
}

# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
    transformers_version='4.28.1',
    pytorch_version='2.0.0',
    py_version='py310',
    env=hub,
    role=role,
)

# deploy model to SageMaker Inference
predictor = huggingface_model.deploy(
    initial_instance_count=1,    # number of instances
    instance_type='ml.g5.large'  # ec2 instance type
)

runtime_client = boto3.Session().client('sagemaker-runtime')
payload = json.dumps(
    {
        "inputs": [
            {"text": "나는 너를 싫어해", "text_pair": "나는 너를 사랑해"},
            {"text": "나는 너를 좋아해", "text_pair": "너에 대한 나의 감정은 사랑 일 수도 있어"}
        ]
    }
)

response = runtime_client.invoke_endpoint(
    EndpointName="<endpoint-name>",
    ContentType="application/json",
    Accept="application/json",
    Body=payload
)

## deserialization
out = json.loads(response['Body'].read().decode())  ## for json
print(f'Response: {out}')
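The predictor returned by deploy() can be used instead of the low-level runtime client, and the endpoint should be deleted when it is no longer needed to avoid ongoing charges. A minimal follow-on sketch:

# the endpoint name can be read from the predictor instead of hard-coding "<endpoint-name>"
print(predictor.endpoint_name)

# the predictor handles JSON serialization/deserialization itself
result = predictor.predict({
    "inputs": [
        {"text": "나는 너를 싫어해", "text_pair": "나는 너를 사랑해"}
    ]
})
print(result)

# clean up the model and endpoint
predictor.delete_model()
predictor.delete_endpoint()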
2. Background
The order of contexts affects accuracy ("Lost in the Middle", Liu et al., 2023).
Why a reranker is needed:
- Simply putting more context into an LLM does not help; the relevant passages have to be near the top for the model to give the right answer.
- The similarity (relevance) score used in semantic search is not precise. (In other words, is a higher-ranked passage really always more relevant to the question than a lower-ranked one?)
- Embeddings are specialized in capturing the meaning behind a document.
- A question and its answer are not semantically identical. (cf. Hypothetical Document Embeddings)
- There is an accuracy penalty from using ANNs (Approximate Nearest Neighbors).
3. Reranker models
[Cohere] Reranker
[BAAI] bge-reranker-large
[BAAI] bge-reranker-base
4. Dataset
msmarco-triplets
- (Question, Answer, Negative)-Triplets from MS MARCO Passages dataset, 499,184 samples
- This dataset consists of English text.
- It was translated with Amazon Translate before being used for training.
Format
{"query": str, "pos": List[str], "neg": List[str]}
query is the question, pos is a list of positive texts, and neg is a list of negative texts. If a query has no negative texts, you can randomly sample passages from the whole corpus and use them as negatives (see the sketch after the example below).
Example
{"query": "λνλ―Όκ΅μ μλλ?", "pos": ["λ―Έκ΅μ μλλ μμ±ν΄μ΄κ³ , μΌλ³Έμ λμΏμ΄λ©° νκ΅μ μμΈμ΄λ€."], "neg": ["λ―Έκ΅μ μλλ μμ±ν΄μ΄κ³ , μΌλ³Έμ λμΏμ΄λ©° λΆνμ νμμ΄λ€."]}
5. Performance
| Model | has-right-in-contexts | mrr (mean reciprocal rank) |
|---|---|---|
| without-reranker (default) | 0.93 | 0.80 |
| with-reranker (bge-reranker-large) | 0.95 | 0.84 |
| with-reranker (fine-tuned using korean) | 0.96 | 0.87 |
- evaluation set:
./dataset/evaluation/eval_dataset.csv
- training parameters:
{
"learning_rate": 5e-6,
"fp16": True,
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 32,
"train_group_size": 3,
"max_len": 512,
"weight_decay": 0.01,
}
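For reference, the mrr column above is the mean reciprocal rank: for each evaluation question, take the reciprocal of the rank at which the first relevant context appears in the returned list, then average over all questions. A minimal sketch of the computation, where the ranked_contexts structure and is_relevant check are hypothetical:

def mean_reciprocal_rank(ranked_contexts, is_relevant):
    # ranked_contexts: one ranked list of contexts per question
    # is_relevant(question_idx, context) -> bool (hypothetical relevance check)
    reciprocal_ranks = []
    for q_idx, contexts in enumerate(ranked_contexts):
        rr = 0.0
        for rank, context in enumerate(contexts, start=1):
            if is_relevant(q_idx, context):
                rr = 1.0 / rank  # reciprocal rank of the first relevant context
                break
        reciprocal_ranks.append(rr)
    return sum(reciprocal_ranks) / len(reciprocal_ranks)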
6. Acknowledgement
- Part of the code is developed based on FlagEmbedding and KoSimCSE-SageMaker.
7. Citation
- If you find this repository useful, please consider giving it a like and a citation.
8. Contributors:
9. License
- FlagEmbedding is licensed under the MIT License.
10. Analytics