Devishetty100 committed
Commit 622aff4 · verified · 1 parent: f41fa72

Upload app.py with huggingface_hub

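For context, a commit message like this one typically comes from the `huggingface_hub` upload API. A minimal sketch of the call that could produce it (the `repo_id` and `repo_type` below are assumptions inferred from this page, and authentication is assumed to come from a prior `huggingface-cli login`):

```python
# Hedged sketch: how app.py is typically uploaded with huggingface_hub.
# The repo_id is inferred from the usernames in this commit and is an
# assumption, not something recorded in the commit itself.
from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="app.py",               # local file to upload
    path_in_repo="app.py",                  # destination path in the repo
    repo_id="Devishetty100/neoguardianai",  # assumed target repo
    repo_type="space",                      # assumed: this looks like a Gradio Space
    commit_message="Upload app.py with huggingface_hub",
)
```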
Files changed (1): app.py +309 -0
app.py ADDED
@@ -0,0 +1,309 @@
+ import gradio as gr
+ import joblib
+ import json
+ import numpy as np
+ import re
+ from urllib.parse import urlparse
+ from huggingface_hub import hf_hub_download
+
+ # Define the model and username
+ MODEL_NAME = "XGBoost"
+ HF_USERNAME = "Devishetty100"
+ CUSTOM_MODEL_NAME = "NeoGuardianAI"
+ REPO_ID = f"{HF_USERNAME}/{CUSTOM_MODEL_NAME.lower()}"
+
+ # List of trusted domains that should always be considered safe
+ TRUSTED_DOMAINS = [
+     'huggingface.co',
+     'github.com',
+     'google.com',
+     'microsoft.com',
+     'apple.com',
+     'amazon.com',
+     'facebook.com',
+     'twitter.com',
+     'linkedin.com',
+     'youtube.com',
+     'wikipedia.org'
+ ]
+
+ # Load model files (either from local files or Hugging Face Hub)
+ def load_model_files():
+     try:
+         print(f"Attempting to download model from Hugging Face Hub: {REPO_ID}")
+         model_path = hf_hub_download(repo_id=REPO_ID, filename=f"{MODEL_NAME.lower()}_model.joblib")
+         scaler_path = hf_hub_download(repo_id=REPO_ID, filename="scaler.joblib")
+         feature_names_path = hf_hub_download(repo_id=REPO_ID, filename="feature_names.json")
+
+         # Load the model and preprocessing components
+         model = joblib.load(model_path)
+         scaler = joblib.load(scaler_path)
+
+         # Load feature names
+         with open(feature_names_path, 'r') as f:
+             feature_names = json.load(f)
+
+         print("Successfully downloaded model from Hugging Face Hub.")
+         return model, scaler, feature_names
+     except Exception as hub_error:
+         print(f"Error downloading from Hugging Face Hub: {hub_error}")
+
+         # If downloading fails, try to load from local files
+         try:
+             print("Attempting to load model from local files...")
+             model = joblib.load(f"{MODEL_NAME.lower()}_model.joblib")
+             scaler = joblib.load("scaler.joblib")
+
+             with open("feature_names.json", 'r') as f:
+                 feature_names = json.load(f)
+
+             print("Successfully loaded model from local files.")
+             return model, scaler, feature_names
+         except Exception as local_error:
+             print(f"Error loading from local files: {local_error}")
+             raise RuntimeError("Failed to load model from both Hugging Face Hub and local files.")
+
+ # Extract features from URL
+ def extract_features(url):
+     """Extract features from a URL for model prediction."""
+     features = {}
+
+     # Basic URL properties
+     features['length_url'] = len(url)
+
+     # Parse URL
+     parsed_url = urlparse(url)
+     hostname = parsed_url.netloc
+     path = parsed_url.path
+
+     # Hostname features
+     features['length_hostname'] = len(hostname)
+     features['ip'] = 1 if re.match(r'\d+\.\d+\.\d+\.\d+', hostname) else 0
+
+     # Count special characters
+     features['nb_dots'] = url.count('.')
+     features['nb_hyphens'] = url.count('-')
+     features['nb_at'] = url.count('@')
+     features['nb_qm'] = url.count('?')
+     features['nb_and'] = url.count('&')
+     features['nb_or'] = url.count('|')
+     features['nb_eq'] = url.count('=')
+     features['nb_underscore'] = url.count('_')
+     features['nb_tilde'] = url.count('~')
+     features['nb_percent'] = url.count('%')
+     features['nb_slash'] = url.count('/')
+     features['nb_star'] = url.count('*')
+     features['nb_colon'] = url.count(':')
+     features['nb_comma'] = url.count(',')
+     features['nb_semicolumn'] = url.count(';')
+     features['nb_dollar'] = url.count('$')
+     features['nb_space'] = url.count(' ')
+
+     # Other URL features
+     features['nb_www'] = 1 if 'www' in hostname else 0
+     features['nb_com'] = 1 if '.com' in hostname else 0
+     features['nb_dslash'] = url.count('//')
+     features['http_in_path'] = 1 if 'http' in path else 0
+     features['https_token'] = 1 if 'https' in url and 'http://' not in url else 0
+
+     # Ratio features
+     digits_count = sum(c.isdigit() for c in url)
+     features['ratio_digits_url'] = digits_count / len(url) if len(url) > 0 else 0
+     features['ratio_digits_host'] = sum(c.isdigit() for c in hostname) / len(hostname) if len(hostname) > 0 else 0
+
+     # Punycode
+     features['punycode'] = 1 if 'xn--' in hostname else 0
+
+     # Port
+     features['port'] = 1 if ':' in hostname and any(c.isdigit() for c in hostname.split(':')[1]) else 0
+
+     # TLD features
+     tlds = ['.com', '.org', '.net', '.edu', '.gov', '.mil', '.int']
+     features['tld_in_path'] = 1 if any(tld in path for tld in tlds) else 0
+     features['tld_in_subdomain'] = 1 if hostname.count('.') > 1 and any(tld in hostname.split('.')[0] for tld in tlds) else 0
+
+     # Subdomain features
+     features['abnormal_subdomain'] = 1 if hostname.count('.') > 2 else 0
+     features['nb_subdomains'] = hostname.count('.')
+
+     # Other suspicious features
+     features['prefix_suffix'] = 1 if '-' in hostname else 0
+     features['random_domain'] = 1 if len(hostname) > 12 and sum(c.isdigit() for c in hostname) > 4 else 0
+
+     # Shortening services (duplicate entries removed from the original list)
+     shortening_services = [
+         'bit.ly', 'goo.gl', 'tinyurl.com', 't.co', 'tr.im', 'is.gd', 'cli.gs',
+         'ow.ly', 'yfrog.com', 'migre.me', 'ff.im', 'tiny.cc', 'url4.eu',
+         'twit.ac', 'su.pr', 'twurl.nl', 'snipurl.com', 'short.to', 'budurl.com',
+         'ping.fm', 'post.ly', 'just.as', 'bkite.com', 'snipr.com', 'fic.kr',
+         'loopt.us', 'doiop.com', 'twitthis.com', 'htxt.it', 'ak.im', 'shar.es',
+         'kl.am', 'wp.me', 'rubyurl.com', 'om.ly', 'to.ly', 'bit.do', 'lnkd.in',
+         'db.tt', 'qr.ae', 'adf.ly', 'bitly.com', 'cur.lv', 'ity.im', 'q.gs',
+         'po.st', 'bc.vc', 'u.to', 'j.mp', 'buzurl.com', 'cutt.us', 'u.bb',
+         'yourls.org', 'x.co', 'prettylinkpro.com', 'scrnch.me', 'filoops.info',
+         'vzturl.com', 'qr.net', '1url.com', 'tweez.me', 'v.gd', 'link.zip.net'
+     ]
+     features['shortening_service'] = 1 if any(service in hostname for service in shortening_services) else 0
+
+     # Path features
+     features['path_extension'] = 1 if '.' in path.split('/')[-1] else 0
+
+     # Fill in remaining features with default values
+     # These would normally be computed with more complex analysis
+     for feature in ['nb_redirection', 'nb_external_redirection', 'length_words_raw',
+                     'char_repeat', 'shortest_words_raw', 'shortest_word_host',
+                     'shortest_word_path', 'longest_words_raw', 'longest_word_host',
+                     'longest_word_path', 'avg_words_raw', 'avg_word_host',
+                     'avg_word_path', 'phish_hints', 'domain_in_brand',
+                     'brand_in_subdomain', 'brand_in_path', 'suspecious_tld',
+                     'statistical_report', 'nb_hyperlinks', 'ratio_intHyperlinks',
+                     'ratio_extHyperlinks', 'ratio_nullHyperlinks', 'nb_extCSS',
+                     'ratio_intRedirection', 'ratio_extRedirection', 'ratio_intErrors',
+                     'ratio_extErrors', 'login_form', 'external_favicon',
+                     'links_in_tags', 'submit_email', 'ratio_intMedia',
+                     'ratio_extMedia', 'sfh', 'iframe', 'popup_window',
+                     'safe_anchor', 'onmouseover', 'right_clic', 'empty_title',
+                     'domain_in_title', 'domain_with_copyright', 'whois_registered_domain',
+                     'domain_registration_length', 'domain_age', 'web_traffic',
+                     'dns_record', 'google_index', 'page_rank']:
+         if feature not in features:
+             features[feature] = 0
+
+     return features
+
+ # Load model and components
+ try:
+     model, scaler, feature_names = load_model_files()
+     print("Model loaded successfully!")
+ except Exception as e:
+     print(f"Error loading model: {e}")
+     # Create dummy model and components for demo purposes
+     # (numpy is already imported at module level)
+     print("Using dummy model for demonstration purposes.")
+     from types import SimpleNamespace
+     from sklearn.ensemble import RandomForestClassifier
+
+     # Create a dummy model that always reports a 50/50 probability
+     model = RandomForestClassifier(n_estimators=10)
+     model.fit(np.array([[0, 0]]), np.array([0]))
+     model.predict_proba = lambda x: np.array([[0.5, 0.5]])
+
+     # Create a dummy pass-through scaler and minimal feature names
+     scaler = SimpleNamespace(transform=lambda x: x)
+     feature_names = ['length_url', 'length_hostname']
+
+ def predict_url(url):
+     """Predict if a URL is phishing or legitimate."""
+     if not url or not url.strip():
+         return "Please enter a URL", 0.0, "N/A"
+
+     try:
+         # Check if the URL belongs to a trusted domain
+         parsed_url = urlparse(url)
+         domain = parsed_url.netloc
+
+         # Remove 'www.' prefix if present
+         if domain.startswith('www.'):
+             domain = domain[4:]
+
+         # Check if the domain or any parent domain is in the trusted list
+         is_trusted = False
+         domain_parts = domain.split('.')
+         for i in range(len(domain_parts) - 1):
+             check_domain = '.'.join(domain_parts[i:])
+             if check_domain in TRUSTED_DOMAINS:
+                 is_trusted = True
+                 break
+
+         if is_trusted:
+             return "Legitimate (Trusted Domain)", 1.0, "✅ SAFE"
+
+         # Extract features
+         url_features = extract_features(url)
+
+         # Ensure features are in the correct order
+         features_array = []
+         for feature in feature_names:
+             if feature in url_features:
+                 features_array.append(url_features[feature])
+             else:
+                 features_array.append(0)  # Default value if feature is missing
+
+         # Scale features
+         scaled_features = scaler.transform([features_array])
+
+         # Make prediction
+         prediction = model.predict(scaled_features)[0]
+         probability = model.predict_proba(scaled_features)[0][1]
+
+         # Prepare return values
+         prediction_text = "Phishing" if prediction == 1 else "Legitimate"
+         confidence = float(probability) if prediction == 1 else float(1 - probability)
+         status = "⚠️ UNSAFE" if prediction == 1 else "✅ SAFE"
+
+         # Return three separate values for the three output components
+         return prediction_text, confidence, status
+     except Exception as e:
+         error_msg = f"Error: {str(e)}"
+         return error_msg, 0.0, "Error"
+
+ # Create Gradio interface
+ def create_interface():
+     with gr.Blocks(title="NeoGuardianAI - URL Phishing Detection", theme=gr.themes.Soft()) as demo:
+         gr.Markdown(
+             """
+             # NeoGuardianAI - URL Phishing Detection
+
+             This app uses a machine learning model to detect whether a URL is legitimate or phishing.
+
+             Enter a URL below to check if it's safe or potentially malicious.
+             """
+         )
+
+         with gr.Row():
+             url_input = gr.Textbox(label="Enter URL", placeholder="https://example.com")
+             submit_btn = gr.Button("Check URL", variant="primary")
+
+         with gr.Row():
+             status_output = gr.Textbox(label="Status")
+             prediction_output = gr.Textbox(label="Prediction")
+             confidence_output = gr.Textbox(label="Confidence")
+
+         submit_btn.click(
+             fn=predict_url,
+             inputs=url_input,
+             outputs=[
+                 prediction_output,
+                 confidence_output,
+                 status_output
+             ]
+         )
+
+         gr.Markdown(
+             """
+             ## How it works
+
+             This model was trained on the [pirocheto/phishing-url](https://huggingface.co/datasets/pirocheto/phishing-url) dataset from Hugging Face.
+
+             The model extracts various features from the URL and uses a machine learning algorithm to classify it as legitimate or phishing.
+
+             **Note**: While this model is highly accurate, it's not perfect. Always exercise caution when visiting unfamiliar websites.
+
+             ## API Usage
+
+             You can also use this model via the Hugging Face Inference API:
+
+             ```python
+             import requests
+
+             API_URL = "https://api-inference.huggingface.co/models/Devishetty100/neoguardianai"
+             headers = {"Authorization": "Bearer YOUR_API_TOKEN"}
+
+             def query(url):
+                 payload = {"inputs": url}
+                 response = requests.post(API_URL, headers=headers, json=payload)
+                 return response.json()
+
+             # Example
+             result = query("https://example.com")
+             print(result)
+             ```
+             """
+         )
+
+     return demo
+
+ # Launch the app
+ if __name__ == "__main__":
+     demo = create_interface()
+     demo.launch()
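If you want to sanity-check the committed app locally, one option is to import `predict_url` directly; a minimal smoke test, assuming the file is saved as `app.py` and the model files (or the Hub repo) are reachable so the real model loads:

```python
# Hypothetical smoke test for the app above; the second URL is a made-up
# phishing-style example, not a real site.
from app import predict_url

for url in ["https://example.com", "http://secure-login-update.account-verify.xyz"]:
    label, confidence, status = predict_url(url)
    print(f"{url} -> {label} (confidence {confidence:.2f}) {status}")
```

Importing `app` runs the module-level model loading but not `demo.launch()`, which is guarded by the `__main__` check.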