Korean Reranker Training on Amazon SageMaker

This repository presents a fine-tuning guide for developing a Korean reranker.

ko-reranker is a model fine-tuned on Korean data, based on BAAI/bge-reranker-large.
For more details, see korean-reranker-git and the AWS blog post "ν•œκ΅­μ–΄ Rerankerλ₯Ό ν™œμš©ν•œ 검색 증강 생성(RAG) μ„±λŠ₯ μ˜¬λ¦¬κΈ°" (Improving retrieval-augmented generation (RAG) performance with a Korean reranker).


0. Features

  • RerankerλŠ” μž„λ² λ”© λͺ¨λΈκ³Ό 달리 질문과 λ¬Έμ„œλ₯Ό μž…λ ₯으둜 μ‚¬μš©ν•˜λ©° μž„λ² λ”© λŒ€μ‹  μœ μ‚¬λ„λ₯Ό 직접 좜λ ₯ν•©λ‹ˆλ‹€.

  • Reranker에 질문과 κ΅¬μ ˆμ„ μž…λ ₯ν•˜λ©΄ μ—°κ΄€μ„± 점수λ₯Ό 얻을 수 μžˆμŠ΅λ‹ˆλ‹€.

  • RerankerλŠ” CrossEntropy lossλ₯Ό 기반으둜 μ΅œμ ν™”λ˜λ―€λ‘œ κ΄€λ ¨μ„± μ μˆ˜κ°€ νŠΉμ • λ²”μœ„μ— κ΅­ν•œλ˜μ§€ μ•ŠμŠ΅λ‹ˆλ‹€.

1. Usage

  • using Transformers
    def exp_normalize(x):
      b = x.max()
      y = np.exp(x - b)
      return y / y.sum()
    
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    model.eval()

    pairs = [["λ‚˜λŠ” λ„ˆλ₯Ό μ‹«μ–΄ν•΄", "λ‚˜λŠ” λ„ˆλ₯Ό μ‚¬λž‘ν•΄"], \
             ["λ‚˜λŠ” λ„ˆλ₯Ό μ’‹μ•„ν•΄", "λ„ˆμ— λŒ€ν•œ λ‚˜μ˜ 감정은 μ‚¬λž‘ 일 μˆ˜λ„ μžˆμ–΄"]]

    with torch.no_grad():
        inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
        scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
        scores = exp_normalize(scores.numpy())
        print (f'first: {scores[0]}, second: {scores[1]}')
  • using SageMaker
import sagemaker
import boto3
from sagemaker.huggingface import HuggingFaceModel

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

# Hub Model configuration. https://huggingface.co/models
hub = {
    'HF_MODEL_ID':'Dongjin-kr/ko-reranker',
    'HF_TASK':'text-classification'
}

# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
    transformers_version='4.28.1',
    pytorch_version='2.0.0',
    py_version='py310',
    env=hub,
    role=role, 
)

# deploy model to SageMaker Inference
predictor = huggingface_model.deploy(
    initial_instance_count=1, # number of instances
    instance_type='ml.g5.xlarge' # ec2 instance type
)

runtime_client = boto3.Session().client('sagemaker-runtime')
payload = json.dumps(
    {
        "inputs": [
            {"text": "λ‚˜λŠ” λ„ˆλ₯Ό μ‹«μ–΄ν•΄", "text_pair": "λ‚˜λŠ” λ„ˆλ₯Ό μ‚¬λž‘ν•΄"},
            {"text": "λ‚˜λŠ” λ„ˆλ₯Ό μ’‹μ•„ν•΄", "text_pair": "λ„ˆμ— λŒ€ν•œ λ‚˜μ˜ 감정은 μ‚¬λž‘ 일 μˆ˜λ„ μžˆμ–΄"}
        ]
    }
)

response = runtime_client.invoke_endpoint(
    EndpointName=predictor.endpoint_name,  # endpoint created by deploy() above
    ContentType="application/json",
    Accept="application/json",
    Body=payload
)

## deserialization
out = json.loads(response['Body'].read().decode()) ## for json
print (f'Response: {out}')
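
When testing is finished, the endpoint can be deleted to stop incurring charges. A minimal cleanup sketch using the predictor and model objects created above:

# delete the real-time endpoint and the model registered for it
predictor.delete_endpoint()
huggingface_model.delete_model()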

2. Background

  • μ»¨νƒμŠ€νŠΈ μˆœμ„œκ°€ 정확도에 영ν–₯ μ€€λ‹€(Lost in Middle, Liu et al., 2023)

  • Reranker μ‚¬μš©ν•΄μ•Ό ν•˜λŠ” 이유

    • ν˜„μž¬ LLM은 context 많이 λ„£λŠ”λ‹€κ³  쒋은거 μ•„λ‹˜, relevantν•œκ²Œ μƒμœ„μ— μžˆμ–΄μ•Ό 정닡을 잘 말해쀀닀
    • Semantic searchμ—μ„œ μ‚¬μš©ν•˜λŠ” similarity(relevant) scoreκ°€ μ •κ΅ν•˜μ§€ μ•Šλ‹€. (즉, μƒμœ„ 랭컀면 ν•˜μœ„ λž­μ»€λ³΄λ‹€ 항상 더 μ§ˆλ¬Έμ— μœ μ‚¬ν•œ 정보가 λ§žμ•„?)
      • Embedding은 meaning behind documentλ₯Ό κ°€μ§€λŠ” 것에 νŠΉν™”λ˜μ–΄ μžˆλ‹€.
      • 질문과 정닡이 μ˜λ―Έμƒ 같은건 μ•„λ‹ˆλ‹€. (Hypothetical Document Embeddings)
      • ANNs(Approximate Nearest Neighbors) μ‚¬μš©μ— λ”°λ₯Έ νŒ¨λ„ν‹°

3. Reranker models


4. Dataset

  • msmarco-triplets

    • (Question, Answer, Negative)-Triplets from MS MARCO Passages dataset, 499,184 samples
    • ν•΄λ‹Ή 데이터 셋은 영문으둜 κ΅¬μ„±λ˜μ–΄ μžˆμŠ΅λ‹ˆλ‹€.
    • Amazon Translate 기반으둜 λ²ˆμ—­ν•˜μ—¬ ν™œμš©ν•˜μ˜€μŠ΅λ‹ˆλ‹€.
  • Format

{"query": str, "pos": List[str], "neg": List[str]}
  • QueryλŠ” 질문이고, posλŠ” 긍정 ν…μŠ€νŠΈ λͺ©λ‘, negλŠ” λΆ€μ • ν…μŠ€νŠΈ λͺ©λ‘μž…λ‹ˆλ‹€. 쿼리에 λŒ€ν•œ λΆ€μ • ν…μŠ€νŠΈκ°€ μ—†λŠ” 경우 전체 λ§λ­‰μΉ˜μ—μ„œ 일뢀λ₯Ό λ¬΄μž‘μœ„λ‘œ μΆ”μΆœν•˜μ—¬ λΆ€μ • ν…μŠ€νŠΈλ‘œ μ‚¬μš©ν•  수 μžˆμŠ΅λ‹ˆλ‹€.

  • Example

{"query": "λŒ€ν•œλ―Όκ΅­μ˜ μˆ˜λ„λŠ”?", "pos": ["미ꡭ의 μˆ˜λ„λŠ” μ›Œμ‹±ν„΄μ΄κ³ , 일본은 도쿄이며 ν•œκ΅­μ€ μ„œμšΈμ΄λ‹€."], "neg": ["미ꡭ의 μˆ˜λ„λŠ” μ›Œμ‹±ν„΄μ΄κ³ , 일본은 도쿄이며 λΆν•œμ€ 평양이닀."]}

5. Performance

Model                                       has-right-in-contexts   mrr (mean reciprocal rank)
without-reranker (default)                  0.93                    0.80
with-reranker (bge-reranker-large)          0.95                    0.84
with-reranker (fine-tuned using Korean)     0.96                    0.87
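
MRR here is the mean reciprocal rank of the first relevant passage across the evaluation queries. A minimal sketch of the computation (function name is illustrative):

    def mean_reciprocal_rank(ranked_relevance):
        # ranked_relevance: one list of 0/1 relevance labels per query, in ranked order
        total = 0.0
        for labels in ranked_relevance:
            rank = next((i + 1 for i, rel in enumerate(labels) if rel), None)
            total += 1.0 / rank if rank else 0.0
        return total / len(ranked_relevance)

    # mean_reciprocal_rank([[0, 1, 0], [1, 0, 0]]) == (1/2 + 1/1) / 2 == 0.75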
  • evaluation set:
./dataset/evaluation/eval_dataset.csv
  • training parameters:
{
    "learning_rate": 5e-6,
    "fp16": True,
    "num_train_epochs": 3,
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 32,
    "train_group_size": 3,
    "max_len": 512,
    "weight_decay": 0.01,
}
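
Since the guide targets training on Amazon SageMaker, these hyperparameters could be passed to a SageMaker Hugging Face training job. A minimal sketch, assuming a fine-tuning script run.py in a local src/ directory and a prepared JSONL dataset on S3 (entry point, source directory, instance type, and S3 path are assumptions, not values from this repository):

    import sagemaker
    from sagemaker.huggingface import HuggingFace

    role = sagemaker.get_execution_role()  # SageMaker execution role, as in the deployment example

    hyperparameters = {
        "learning_rate": 5e-6,
        "fp16": True,
        "num_train_epochs": 3,
        "per_device_train_batch_size": 1,
        "gradient_accumulation_steps": 32,
        "train_group_size": 3,
        "max_len": 512,
        "weight_decay": 0.01,
    }

    estimator = HuggingFace(
        entry_point="run.py",            # assumed fine-tuning script
        source_dir="./src",              # assumed source directory
        instance_type="ml.g5.12xlarge",  # assumed training instance
        instance_count=1,
        role=role,
        transformers_version="4.28.1",
        pytorch_version="2.0.0",
        py_version="py310",
        hyperparameters=hyperparameters,
    )

    estimator.fit({"training": "s3://<bucket>/ko-reranker/train/"})  # assumed S3 prefix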

6. Acknowledgement


7. Citation

  • If you find this repository useful, please consider giving a like ⭐ and citation

8. Contributors:

  • Dongjin Jang, Ph.D. (AWS AI/ML Specislist Solutions Architect) | Mail | Linkedin | Git |

9. License

10. Analytics

  • Hits