ad_categorizer / train.py
win2win's picture
Create train.py
784bcca verified
raw
history blame contribute delete
835 Bytes
from sentence_transformers import SentenceTransformer, InputExample, losses, evaluation
from torch.utils.data import DataLoader
import json
import numpy as np
# Fine-tune a sentence-embedding model so that listings sharing a category
# are embedded close together, then save the result to disk.

# 1. Load data
# JSON is UTF-8 by specification; pinning the encoding keeps non-ASCII
# listing text from depending on the platform's default locale.
with open('data/listings.json', encoding='utf-8') as f:
    train_data = json.load(f)

# 2. Prepare examples
# One single-text example per listing, labelled with its category id.
# NOTE(review): presumably `category_id` is an integer class id, which is
# what label-based losses expect -- confirm against the data file.
train_examples = [
    InputExample(texts=[item['text']], label=item['category_id'])
    for item in train_data
]

# 3. Initialize model
model = SentenceTransformer('all-MiniLM-L6-v2')

# 4. Train with a batch triplet loss
# ContrastiveLoss needs a *pair* of texts per example plus a
# similar/dissimilar label, so it cannot consume the single-text,
# class-labelled examples built above. BatchAllTripletLoss is designed for
# exactly this (sentence, class label) format: it forms triplets inside each
# batch from examples that share / do not share a label.
# NOTE(review): the library docs recommend batches containing at least two
# examples per label; with plain shuffling and many categories some batches
# may yield few valid triplets -- consider a label-grouping sampler.
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
loss = losses.BatchAllTripletLoss(model=model)
model.fit(
    train_objectives=[(train_dataloader, loss)],
    epochs=3,
    warmup_steps=100,
)

# 5. Save model
model.save('models/ad_categorizer')
print("Training complete! Model saved.")