kithangw committed on
Commit
66019c8
·
verified ·
1 Parent(s): 5f23af7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -28
app.py CHANGED
@@ -7,39 +7,27 @@ from transformers import pipeline, AutoModelForSequenceClassification, AutoToken
7
  @st.cache(allow_output_mutation=True)
8
  def load_models():
9
  image_pipeline = pipeline("image-to-text", model="microsoft/trocr-large-printed")
10
- phishing_model = AutoModelForSequenceClassification.from_pretrained("kithangw/phishing_link_detection", num_labels=2)
11
  phishing_tokenizer = AutoTokenizer.from_pretrained("google/bert_uncased_L-2_H-128_A-2")
12
  return image_pipeline, phishing_model, phishing_tokenizer
13
 
14
  image_pipeline, phishing_model, phishing_tokenizer = load_models()
15
 
16
- # Define the main function
17
- def main(image_input):
18
- # Convert image to URL text
19
- def image2url(image_input):
20
- url_for_recognise = image_pipeline(image_input)[0]['generated_text'].replace(" ", "").lower()
21
- st.write(f"Recognized URL: {url_for_recognise}")
22
- return url_for_recognise
23
-
24
- # Check if the URL text is a phishing link
25
- def checkphishing(url_for_recognise):
26
- link_token = phishing_tokenizer(url_for_recognise, max_length=512, padding=True, truncation=True, return_tensors='pt')
27
 
28
- with torch.no_grad(): # Disable gradient calculation for inference
29
- output = phishing_model(**link_token)
30
 
31
- probabilities = torch.nn.functional.softmax(output.logits, dim=-1)
32
- predicted_class = torch.argmax(probabilities, dim=-1).item()
33
- predicted_prob = probabilities[0, predicted_class].item()
34
 
35
- labels = ['Not Phishing', 'Phishing']
36
- prediction_label = labels[predicted_class]
37
- sentence = f"The URL '{url_for_recognise}' is classified as '{prediction_label}' with a probability of {predicted_prob:.2f}."
38
- return sentence
39
-
40
- url_text = image2url(image_input)
41
- result_sentence = checkphishing(url_text)
42
- return result_sentence
43
 
44
  # Streamlit interface
45
  st.title("Phishing URL Detection from Image")
@@ -48,6 +36,16 @@ uploaded_image = st.file_uploader("Upload an image of the URL", type=["png", "jp
48
  if uploaded_image is not None:
49
  image = Image.open(uploaded_image)
50
  st.image(image, caption='Uploaded URL Image', use_column_width=True)
51
- if st.button('Detect'):
52
- result = main(uploaded_image)
53
- st.write(result)
 
 
 
 
 
 
 
 
 
 
 
7
  @st.cache(allow_output_mutation=True)
8
  def load_models():
9
  image_pipeline = pipeline("image-to-text", model="microsoft/trocr-large-printed")
10
+ phishing_model = AutoModelForSequenceClassification.from_pretrained("kithangw/phishing_link_detection", num_labels=2)
11
  phishing_tokenizer = AutoTokenizer.from_pretrained("google/bert_uncased_L-2_H-128_A-2")
12
  return image_pipeline, phishing_model, phishing_tokenizer
13
 
14
  image_pipeline, phishing_model, phishing_tokenizer = load_models()
15
 
16
+ # Define the phishing check function
17
+ def check_phishing(url_for_recognize):
18
+ link_token = phishing_tokenizer(url_for_recognize, max_length=512, padding=True, truncation=True, return_tensors='pt')
 
 
 
 
 
 
 
 
19
 
20
+ with torch.no_grad(): # Disable gradient calculation for inference
21
+ output = phishing_model(**link_token)
22
 
23
+ probabilities = torch.nn.functional.softmax(output.logits, dim=-1)
24
+ predicted_class = torch.argmax(probabilities, dim=-1).item()
25
+ predicted_prob = probabilities[0, predicted_class].item()
26
 
27
+ labels = ['Not Phishing', 'Phishing']
28
+ prediction_label = labels[predicted_class]
29
+ sentence = f"The URL '{url_for_recognize}' is classified as '{prediction_label}' with a probability of {predicted_prob:.2f}."
30
+ return sentence
 
 
 
 
31
 
32
  # Streamlit interface
33
  st.title("Phishing URL Detection from Image")
 
36
  if uploaded_image is not None:
37
  image = Image.open(uploaded_image)
38
  st.image(image, caption='Uploaded URL Image', use_column_width=True)
39
+
40
+ # Convert image to URL text
41
+ url_for_recognize = image_pipeline(uploaded_image)[0]['generated_text'].replace(" ", "").lower()
42
+ st.write("Recognized URL:")
43
+ # Use a text input to let the user verify and possibly edit the recognized URL
44
+ verified_url = st.text_input("Verify or edit the recognized URL if necessary:", value=url_for_recognize)
45
+
46
+ if st.button('Detect Phishing'):
47
+ if verified_url:
48
+ result = check_phishing(verified_url)
49
+ st.write(result)
50
+ else:
51
+ st.write("Please enter a URL to check for phishing.")