bertugmirasyedi committed on
Commit
eb15969
·
1 Parent(s): 4021af8

Took the model definitions outside the functions

Browse files
Files changed (1) hide show
  1. app.py +40 -19
app.py CHANGED
@@ -1,7 +1,13 @@
1
  from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
3
  import os
4
-
 
 
 
 
 
 
5
 
6
  # Define the FastAPI app
7
  app = FastAPI(docs_url="/")
@@ -15,8 +21,32 @@ app.add_middleware(
15
  allow_headers=["*"],
16
  )
17
 
 
18
  key = os.environ.get("GOOGLE_BOOKS_API_KEY")
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  @app.get("/search")
22
  async def search(
@@ -329,23 +359,14 @@ async def classify(data: list, runtime: str = "normal"):
329
  pipeline,
330
  )
331
  from optimum.onnxruntime import ORTModelForSequenceClassification
332
- from optimum.bettertransformer import BetterTransformer
333
 
334
  if runtime == "normal":
335
  # Define the zero-shot classifier
336
- tokenizer = AutoTokenizer.from_pretrained(
337
- "sileod/deberta-v3-base-tasksource-nli"
338
- )
339
- model = AutoModelForSequenceClassification.from_pretrained(
340
- "sileod/deberta-v3-base-tasksource-nli"
341
- )
342
  elif runtime == "onnxruntime":
343
- tokenizer = AutoTokenizer.from_pretrained(
344
- "optimum/distilbert-base-uncased-mnli"
345
- )
346
- model = ORTModelForSequenceClassification.from_pretrained(
347
- "optimum/distilbert-base-uncased-mnli"
348
- )
349
 
350
  classifier_pipe = pipeline(
351
  "zero-shot-classification",
@@ -396,7 +417,7 @@ async def find_similar(data: list, top_k: int = 5):
396
  for title, description, publisher in zip(titles, descriptions, publishers)
397
  ]
398
 
399
- sentence_transformer = SentenceTransformer("all-MiniLM-L6-v2")
400
  book_embeddings = sentence_transformer.encode(combined_data, convert_to_tensor=True)
401
 
402
  # Make sure that the top_k value is not greater than the number of books
@@ -438,12 +459,12 @@ async def summarize(descriptions: list, runtime="normal"):
438
 
439
  # Define the summarizer model and tokenizer
440
  if runtime == "normal":
441
- tokenizer = AutoTokenizer.from_pretrained("lidiya/bart-base-samsum")
442
- model = AutoModelForSeq2SeqLM.from_pretrained("lidiya/bart-base-samsum")
443
  model = BetterTransformer.transform(model)
444
  elif runtime == "onnxruntime":
445
- tokenizer = AutoTokenizer.from_pretrained("optimum/t5-small")
446
- model = ORTModelForSeq2SeqLM.from_pretrained("optimum/t5-small")
447
 
448
  # Create the summarizer pipeline
449
  summarizer_pipe = pipeline("summarization", model=model, tokenizer=tokenizer)
 
1
  from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
3
  import os
4
+ from transformers import (
5
+ AutoModelForSeq2SeqLM,
6
+ AutoTokenizer,
7
+ AutoModelForSequenceClassification,
8
+ )
9
+ from optimum.onnxruntime import ORTModelForSeq2SeqLM, ORTModelForSequenceClassification
10
+ from sentence_transformers import SentenceTransformer
11
 
12
  # Define the FastAPI app
13
  app = FastAPI(docs_url="/")
 
21
  allow_headers=["*"],
22
  )
23
 
24
+ # Define the Google Books API key
25
  key = os.environ.get("GOOGLE_BOOKS_API_KEY")
26
 
27
+ # Define summarization models
28
+ summary_tokenizer_normal = AutoTokenizer.from_pretrained("lidiya/bart-base-samsum")
29
+ summary_model_normal = AutoModelForSeq2SeqLM.from_pretrained("lidiya/bart-base-samsum")
30
+ summary_tokenizer_onnx = AutoTokenizer.from_pretrained("optimum/t5-small")
31
+ summary_model_onnx = ORTModelForSeq2SeqLM.from_pretrained("optimum/t5-small")
32
+
33
+ # Define classification models
34
+ classification_tokenizer_normal = AutoTokenizer.from_pretrained(
35
+ "sileod/deberta-v3-base-tasksource-nli"
36
+ )
37
+ classification_model_normal = AutoModelForSequenceClassification.from_pretrained(
38
+ "sileod/deberta-v3-base-tasksource-nli"
39
+ )
40
+ classification_tokenizer_onnx = AutoTokenizer.from_pretrained(
41
+ "optimum/distilbert-base-uncased-mnli"
42
+ )
43
+ classification_model_onnx = ORTModelForSequenceClassification.from_pretrained(
44
+ "optimum/distilbert-base-uncased-mnli"
45
+ )
46
+
47
+ # Define similarity model
48
+ similarity_model = SentenceTransformer("all-MiniLM-L6-v2")
49
+
50
 
51
  @app.get("/search")
52
  async def search(
 
359
  pipeline,
360
  )
361
  from optimum.onnxruntime import ORTModelForSequenceClassification
 
362
 
363
  if runtime == "normal":
364
  # Define the zero-shot classifier
365
+ tokenizer = classification_tokenizer_normal
366
+ model = classification_model_normal
 
 
 
 
367
  elif runtime == "onnxruntime":
368
+ tokenizer = classification_tokenizer_onnx
369
+ model = classification_model_onnx
 
 
 
 
370
 
371
  classifier_pipe = pipeline(
372
  "zero-shot-classification",
 
417
  for title, description, publisher in zip(titles, descriptions, publishers)
418
  ]
419
 
420
+ sentence_transformer = similarity_model
421
  book_embeddings = sentence_transformer.encode(combined_data, convert_to_tensor=True)
422
 
423
  # Make sure that the top_k value is not greater than the number of books
 
459
 
460
  # Define the summarizer model and tokenizer
461
  if runtime == "normal":
462
+ tokenizer = summary_tokenizer_normal
463
+ model = summary_model_normal
464
  model = BetterTransformer.transform(model)
465
  elif runtime == "onnxruntime":
466
+ tokenizer = summary_tokenizer_onnx
467
+ model = summary_model_onnx
468
 
469
  # Create the summarizer pipeline
470
  summarizer_pipe = pipeline("summarization", model=model, tokenizer=tokenizer)