MLDeveloper committed on
Commit
ed1b0c1
·
verified ·
1 Parent(s): 4a00c88

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -5
app.py CHANGED
@@ -2,9 +2,15 @@ import streamlit as st
2
  import pandas as pd
3
  import re
4
  import string
 
5
  from sklearn.model_selection import train_test_split
6
  from sklearn.feature_extraction.text import TfidfVectorizer
7
  from sklearn.naive_bayes import MultinomialNB
 
 
 
 
 
8
 
9
  # Title & Intro
10
  st.set_page_config(page_title="SMS Spam Detection", layout="centered")
@@ -27,7 +33,6 @@ df['label'] = df['label'].map({'ham': 0, 'spam': 1})
27
 
28
  # --- Train Model ---
29
  X_train, X_test, y_train, y_test = train_test_split(df['message'], df['label'], test_size=0.2, random_state=42)
30
-
31
  vectorizer = TfidfVectorizer()
32
  X_train_tfidf = vectorizer.fit_transform(X_train)
33
 
@@ -51,6 +56,18 @@ def predict_spam(text):
51
  prediction = model.predict(vector)
52
  return "Spam" if prediction[0] == 1 else "Not Spam (Ham)"
53
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  # --- Input ---
55
  user_input = st.text_area("✉️ Enter your SMS message here:")
56
 
@@ -58,11 +75,25 @@ if st.button("Check Message"):
58
  if user_input.strip() == "":
59
  st.warning("⚠️ Please enter a message.")
60
  else:
61
- result = predict_spam(user_input)
62
- if result == "Spam":
63
- st.error("🚫 This message is classified as **SPAM**.")
 
 
 
 
 
 
 
 
 
 
64
  else:
65
- st.success("✅ This message is classified as **NOT SPAM (HAM)**.")
 
 
 
 
66
 
67
  # --- Dataset preview ---
68
  with st.expander("📄 View sample dataset"):
 
2
  import pandas as pd
3
  import re
4
  import string
5
+ import google.generativeai as genai
6
  from sklearn.model_selection import train_test_split
7
  from sklearn.feature_extraction.text import TfidfVectorizer
8
  from sklearn.naive_bayes import MultinomialNB
9
+ from sklearn.metrics.pairwise import cosine_similarity
10
+
11
+ # --- Set Gemini API Key ---
12
+ genai.configure(api_key="AIzaSyCVRGVxIe1vESoAgykgHWOej-jZxiU-RKE") # <-- Replace this with your actual Gemini API key
13
+ gemini_model = genai.GenerativeModel("gemini-pro")
14
 
15
  # Title & Intro
16
  st.set_page_config(page_title="SMS Spam Detection", layout="centered")
 
33
 
34
  # --- Train Model ---
35
  X_train, X_test, y_train, y_test = train_test_split(df['message'], df['label'], test_size=0.2, random_state=42)
 
36
  vectorizer = TfidfVectorizer()
37
  X_train_tfidf = vectorizer.fit_transform(X_train)
38
 
 
56
  prediction = model.predict(vector)
57
  return "Spam" if prediction[0] == 1 else "Not Spam (Ham)"
58
 
59
+ # --- Gemini Fallback ---
60
+ def ask_gemini(text):
61
+ prompt = f"""You are an expert SMS spam detector.
62
+ Classify the following message as 'Spam' or 'Not Spam (Ham)'.
63
+ Message: "{text}"
64
+ Reply with only: Spam or Not Spam (Ham)."""
65
+ try:
66
+ response = gemini_model.generate_content(prompt)
67
+ return response.text.strip()
68
+ except Exception as e:
69
+ return f"Error using Gemini: {str(e)}"
70
+
71
  # --- Input ---
72
  user_input = st.text_area("✉️ Enter your SMS message here:")
73
 
 
75
  if user_input.strip() == "":
76
  st.warning("⚠️ Please enter a message.")
77
  else:
78
+ cleaned = clean_text(user_input)
79
+ input_vector = vectorizer.transform([cleaned])
80
+ similarities = cosine_similarity(input_vector, X_train_tfidf)
81
+ max_similarity = similarities.max()
82
+
83
+ # Check similarity threshold (e.g., < 0.3 = unknown message)
84
+ if max_similarity < 0.3:
85
+ st.info("🧠 Message not found in training data. Using Gemini for prediction...")
86
+ gemini_result = ask_gemini(user_input)
87
+ if "spam" in gemini_result.lower():
88
+ st.error("🚫 Gemini says: This message is **SPAM**.")
89
+ else:
90
+ st.success("✅ Gemini says: This message is **NOT SPAM (HAM)**.")
91
  else:
92
+ result = predict_spam(user_input)
93
+ if result == "Spam":
94
+ st.error("🚫 This message is classified as **SPAM**.")
95
+ else:
96
+ st.success("✅ This message is classified as **NOT SPAM (HAM)**.")
97
 
98
  # --- Dataset preview ---
99
  with st.expander("📄 View sample dataset"):