jonathanjordan21 committed

Commit 00820a7 · verified · 1 Parent(s): 6fb057b

Update app.py

Files changed (1)
  1. app.py +23 -4
app.py CHANGED
@@ -5,6 +5,7 @@ import numpy as np
 
 from sentence_transformers import SentenceTransformer
 from sentence_transformers.util import cos_sim
+from sentence_transformers import CrossEncoder
 
 
 codes = """001 - Vehicle Registration (New)
@@ -361,6 +362,8 @@ model_ids = [
     "sentence-transformers/distiluse-base-multilingual-cased-v2",
     "Alibaba-NLP/gte-multilingual-base",
     "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
+    "BAAI/bge-reranker-v2-m3",
+    "jinaai/jina-reranker-v2-base-multilingual"
 ]
 # model_id = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
 # model_id = "Alibaba-NLP/gte-multilingual-base"
@@ -368,8 +371,18 @@ model_ids = [
 # model_id = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
 # model_id = "intfloat/multilingual-e5-small"
 # model_id = "sentence-transformers/distiluse-base-multilingual-cased-v2"
+
 model_id = model_ids[-1]
-model = SentenceTransformer(model_id, trust_remote_code=True)
+
+if model_id in model_ids[-2:]:
+    model = CrossEncoder(
+        # "jinaai/jina-reranker-v2-base-multilingual",
+        "BAAI/bge-reranker-v2-m3",
+        automodel_args={"torch_dtype": "auto"},
+        trust_remote_code=True,
+    )
+else:
+    model = SentenceTransformer(model_id, trust_remote_code=True)
 
 # codes_emb = model.encode([x[6:] for x in codes])
 codes_emb = model.encode([x["examples"] for x in examples])#.mean(axis=1)
@@ -497,9 +510,15 @@ def respond(
     plates = [" ".join(x).upper() for i,x in enumerate(matches)]
 
     plate_numbers = ", ".join(plates)
-
-    text_emb = model.encode(message)
-    scores = cos_sim(codes_emb, text_emb).mean(axis=-1)#[:,0]
+
+    if model.config._name_or_path in model_ids[-2:]:
+        # documents = [v["name"] for v in detail_perhitungan.values()]
+        sentence_pairs = [[message, v["name"]] for v in detail_perhitungan.values()]
+        scores = model.predict(sentence_pairs, convert_to_tensor=True)
+        # scores = [x["score"] for x in model.rank(message, documents)]
+    else:
+        text_emb = model.encode(message)
+        scores = cos_sim(codes_emb, text_emb).mean(axis=-1)#[:,0]
 
     scores_argsort = scores.argsort(descending=True)
     weights = [18,8,7,6,5,4,3,2,1]
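The model-loading branch in the third hunk selects between a bi-encoder and a cross-encoder reranker: IDs in model_ids[-2:] (the two reranker checkpoints added in this commit) go through CrossEncoder, everything else through SentenceTransformer. Below is a minimal standalone sketch of that branch, assuming sentence-transformers is installed; load_model and RERANKER_IDS are illustrative names, not part of app.py.

from sentence_transformers import SentenceTransformer, CrossEncoder

# The two reranker checkpoints appended to model_ids in this commit.
RERANKER_IDS = {
    "BAAI/bge-reranker-v2-m3",
    "jinaai/jina-reranker-v2-base-multilingual",
}

def load_model(model_id: str):
    """Return a CrossEncoder for reranker IDs, otherwise a SentenceTransformer."""
    if model_id in RERANKER_IDS:
        # Cross-encoders score (query, document) pairs jointly; they do not
        # produce standalone embeddings the way a bi-encoder's encode() does.
        return CrossEncoder(
            model_id,
            automodel_args={"torch_dtype": "auto"},
            trust_remote_code=True,
        )
    # Bi-encoders embed each text independently; similarity is computed afterwards.
    return SentenceTransformer(model_id, trust_remote_code=True)

Note that the committed hunk hard-codes the CrossEncoder checkpoint to "BAAI/bge-reranker-v2-m3" (the jina reranker is left commented out), whereas this sketch loads whichever reranker ID was passed in.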
 
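Inside respond(), the last hunk applies the same split to scoring: a cross-encoder scores (message, service name) pairs directly with predict(), while a bi-encoder embeds the message and compares it against the precomputed codes_emb with cosine similarity. The sketch below is a rough standalone version under stated assumptions: candidate_names and codes_emb are hypothetical stand-ins for the app's data (the diff suggests the values of detail_perhitungan carry a "name" field), and an isinstance check stands in for the model.config._name_or_path lookup used in the commit.

from sentence_transformers import SentenceTransformer, CrossEncoder
from sentence_transformers.util import cos_sim

def score_candidates(model, message, candidate_names, codes_emb=None):
    """Return one relevance score per candidate, following the two paths in the hunk."""
    if isinstance(model, CrossEncoder):
        # Cross-encoder path: score every (message, candidate name) pair jointly.
        pairs = [[message, name] for name in candidate_names]
        return model.predict(pairs, convert_to_tensor=True)
    # Bi-encoder path: embed the message once and compare it against the
    # precomputed candidate embeddings (codes_emb), averaging over the last axis.
    text_emb = model.encode(message)
    return cos_sim(codes_emb, text_emb).mean(axis=-1)

# Hypothetical usage: higher scores should put the matching service first.
# model = CrossEncoder("BAAI/bge-reranker-v2-m3", trust_remote_code=True)
# scores = score_candidates(model, "I want to renew my vehicle registration",
#                           ["Vehicle Registration (New)", "Vehicle Registration (Renewal)"])
# ranked = scores.argsort(descending=True)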