sanjid commited on
Commit
4546b35
·
1 Parent(s): b43955f

app change new

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -9,19 +9,21 @@ tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
9
  with open("label_types_encoded.json", "r") as fp:
10
  encode_genre_types = json.load(fp)
11
 
 
 
12
  inf_session = rt.InferenceSession('food-classifier-quantized.onnx')
13
  input_name = inf_session.get_inputs()[0].name
14
  output_name = inf_session.get_outputs()[0].name
15
 
16
 
17
- def classify_news_label(article):
18
  input_ids = tokenizer(article)['input_ids'][:512]
19
  logits = inf_session.run([output_name], {input_name: [input_ids]})[0]
20
  logits = torch.FloatTensor(logits)
21
  probs = torch.sigmoid(logits)[0]
22
- return dict(zip(label, map(float, probs)))
23
 
24
 
25
  label = gr.outputs.Label(num_top_classes=6)
26
- iface = gr.Interface(fn=classify_news_label, inputs="text", outputs=label)
27
  iface.launch(inline=False)
 
9
  with open("label_types_encoded.json", "r") as fp:
10
  encode_genre_types = json.load(fp)
11
 
12
+ genres = list(encode_genre_types.keys())
13
+
14
  inf_session = rt.InferenceSession('food-classifier-quantized.onnx')
15
  input_name = inf_session.get_inputs()[0].name
16
  output_name = inf_session.get_outputs()[0].name
17
 
18
 
19
+ def classify_food_Ingredient(article):
20
  input_ids = tokenizer(article)['input_ids'][:512]
21
  logits = inf_session.run([output_name], {input_name: [input_ids]})[0]
22
  logits = torch.FloatTensor(logits)
23
  probs = torch.sigmoid(logits)[0]
24
+ return dict(zip(genres, map(float, probs)))
25
 
26
 
27
  label = gr.outputs.Label(num_top_classes=6)
28
+ iface = gr.Interface(fn=classify_food_Ingredient, inputs="text", outputs=label)
29
  iface.launch(inline=False)