dejanseo committed on
Commit
bf4f788
·
verified ·
1 Parent(s): c2e4a3d

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +84 -0
  2. config.toml +6 -0
  3. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from transformers import AlbertTokenizer, AlbertForSequenceClassification, AlbertConfig
4
+ import plotly.graph_objects as go
5
+
6
# URL of the logo
logo_url = "https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png"

# Display the logo at the top using st.logo
st.logo(logo_url, link="https://dejan.ai")

# Streamlit app title and description
st.title("Search Query Form Classifier")
st.write("Ambiguous search queries are candidates for query expansion. Our model identifies such queries with an 80 percent accuracy and is deployed in a batch processing pipeline directly connected with Google Search Console API. In this demo you can test the model capability by testing individual queries.")
st.write("Enter a query to check if it's well-formed:")

# Hugging Face Model Hub repository that holds the classifier weights
model_name = 'dejanseo/Query-Quality-Classifier'


@st.cache_resource
def load_classifier(name):
    """Load the tokenizer and model once and reuse them across reruns.

    Streamlit re-executes this script from top to bottom on every user
    interaction. Without caching, the tokenizer and model would be
    re-instantiated for every submitted query.

    Args:
        name: Model Hub repository id (or local path) to load from.

    Returns:
        A ``(tokenizer, model, device)`` tuple. The model is already in
        evaluation mode and has been moved to ``device``.
    """
    loaded_tokenizer = AlbertTokenizer.from_pretrained(name)
    loaded_model = AlbertForSequenceClassification.from_pretrained(name)

    # Set the model to evaluation mode
    loaded_model.eval()

    # Use the GPU when one is available, otherwise fall back to the CPU
    target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    loaded_model.to(target_device)
    return loaded_tokenizer, loaded_model, target_device


# Keep the original module-level names so the rest of the script is unchanged
tokenizer, model, device = load_classifier(model_name)

# User input
user_input = st.text_input("Query:", "What is?")
st.write("Developed by [Dejan AI](https://dejan.ai/blog/search-query-quality-classifier/)")
30
+
31
def classify_query(query):
    """Return the confidence, in percent, that ``query`` is well-formed.

    The query is tokenized to a fixed length of 32 tokens (padded or
    truncated), run through the module-level ``model`` on ``device``, and
    the softmax probability of class index 1 is scaled to the 0-100 range.

    Args:
        query: The search query text to classify.

    Returns:
        A builtin ``float`` between 0 and 100: the probability of class
        index 1 (the well-formed class) expressed as a percentage.
    """
    # Call the tokenizer directly; `encode_plus` is deprecated in favour of
    # `__call__`, which takes the same keyword arguments.
    inputs = tokenizer(
        query,
        add_special_tokens=True,
        max_length=32,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )

    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    # Inference only: no gradients are needed
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)

    # Only one query is passed in, so take row 0 of the batch
    probabilities = torch.softmax(outputs.logits, dim=1)[0]

    # `.item()` yields a plain Python float rather than a NumPy scalar
    return probabilities[1].item() * 100
54
+
55
def _build_gauge(score):
    """Build the Plotly gauge that visualises a 0-100 confidence score.

    Args:
        score: Confidence percentage to show as the gauge value.

    Returns:
        A ``plotly.graph_objects.Figure`` containing a single gauge indicator
        with a red 0-50 band, a green 50-100 band and a marker line at
        ``score``.
    """
    return go.Figure(go.Indicator(
        mode="gauge+number",
        value=score,
        title={'text': "Well-formedness Confidence"},
        gauge={
            'axis': {'range': [0, 100]},
            'bar': {'color': "darkblue"},
            'steps': [
                {'range': [0, 50], 'color': "red"},
                {'range': [50, 100], 'color': "green"},
            ],
            'threshold': {
                'line': {'color': "black", 'width': 4},
                'thickness': 0.75,
                'value': score,
            },
        },
    ))


# Check and display classification.
# Strip first so a whitespace-only entry is treated as "no query" instead of
# being scored and reported as if it were real input.
query_text = (user_input or "").strip()
if query_text:
    confidence = classify_query(query_text)

    # Plotly gauge
    st.plotly_chart(_build_gauge(confidence))

    # 50% is the decision boundary between the two verdicts
    if confidence >= 50:
        st.success(f"The query is likely well-formed with {confidence:.2f}% confidence.")
    else:
        st.error(f"The query is likely not well-formed with {100 - confidence:.2f}% confidence.")
config.toml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ [client]
2
+ toolbarMode = "minimal"
3
+ [server]
4
+ headless = true
5
+ enableCORS = false
6
+ port = 8501
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ streamlit
2
+ torch
3
+ transformers
4
+ datasets
5
+ plotly
6
+ sentencepiece