sdadas commited on
Commit
53da7de
1 Parent(s): b9bed78

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +7 -3
README.md CHANGED
@@ -36,12 +36,16 @@ answers = [
36
  model_name = "sdadas/polish-reranker-roberta-v2"
37
  tokenizer = AutoTokenizer.from_pretrained(model_name)
38
  model = AutoModelForSequenceClassification.from_pretrained(
39
- model_name, trust_remote_code=True, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
 
 
 
 
40
  )
41
  texts = [f"{query}</s></s>{answer}" for answer in answers]
42
- tokens = tokenizer(texts, padding="longest", max_length=512, truncation=True, return_tensors="pt")
43
  output = model(**tokens)
44
- results = output.logits.detach().numpy()
45
  results = np.squeeze(results)
46
  print(results.tolist())
47
  ```
 
36
  model_name = "sdadas/polish-reranker-roberta-v2"
37
  tokenizer = AutoTokenizer.from_pretrained(model_name)
38
  model = AutoModelForSequenceClassification.from_pretrained(
39
+ model_name,
40
+ trust_remote_code=True,
41
+ torch_dtype=torch.bfloat16,
42
+ attn_implementation="flash_attention_2",
43
+ device_map="cuda"
44
  )
45
  texts = [f"{query}</s></s>{answer}" for answer in answers]
46
+ tokens = tokenizer(texts, padding="longest", max_length=512, truncation=True, return_tensors="pt").to("cuda")
47
  output = model(**tokens)
48
+ results = output.logits.detach().cpu().float().numpy()
49
  results = np.squeeze(results)
50
  print(results.tolist())
51
  ```