Krish Patel committed · Commit f36a10a · 0 Parent(s)

Added model and streamlit file

Browse files:
- .gitattributes +35 -0
- README.md +13 -0
- app.py +64 -0
- final.py +139 -0
- results/checkpoint-753/config.json +35 -0
- results/checkpoint-753/model.safetensors +3 -0
- results/checkpoint-753/optimizer.pt +3 -0
- results/checkpoint-753/rng_state.pth +3 -0
- results/checkpoint-753/scheduler.pt +3 -0
- results/checkpoint-753/trainer_state.json +64 -0
- results/checkpoint-753/training_args.bin +3 -0
- st.py +45 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
+---
+title: Nexus NLP Model
+emoji: 🏆
+colorFrom: green
+colorTo: green
+sdk: streamlit
+sdk_version: 1.41.1
+app_file: app.py
+pinned: false
+short_description: contains nlp model used for truthtell hackathon
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,64 @@
+from fastapi import FastAPI, UploadFile, File, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
+from final import predict_news, get_gemini_analysis
+import os
+from tempfile import NamedTemporaryFile
+
+app = FastAPI()
+
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["http://localhost:5173"],  # Your React app's URL
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Rest of your code remains the same
+class NewsInput(BaseModel):
+    text: str
+
+@app.post("/analyze")
+async def analyze_news(news: NewsInput):
+    prediction = predict_news(news.text)
+    gemini_analysis = get_gemini_analysis(news.text)
+
+    return {
+        "prediction": prediction,
+        "detailed_analysis": gemini_analysis
+    }
+
+@app.post("/detect-deepfake")
+async def detect_deepfake(file: UploadFile = File(...)):
+    try:
+        # Save uploaded file temporarily
+        with NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp_file:
+            contents = await file.read()
+            temp_file.write(contents)
+            temp_file_path = temp_file.name
+
+        # Import functions from testing2.py
+        from deepfake2.testing2 import predict_image, predict_video
+
+        # Use appropriate function based on file type
+        if file.filename.lower().endswith('.mp4'):
+            result = predict_video(temp_file_path)
+            file_type = "video"
+        else:
+            result = predict_image(temp_file_path)
+            file_type = "image"
+
+        # Clean up temp file
+        os.remove(temp_file_path)
+
+        return {
+            "result": result,
+            "file_type": file_type
+        }
+
+    except Exception as e:
+        # Returning a (dict, status-code) tuple is a Flask idiom; in FastAPI,
+        # raise HTTPException so the 500 status actually reaches the client.
+        raise HTTPException(status_code=500, detail=str(e))
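For reference, a minimal client sketch for the /analyze endpoint above. The host, port, and sample text are illustrative assumptions; the service itself would be started with a standard ASGI runner such as "uvicorn app:app".

import requests  # third-party HTTP client, assumed installed

# Hypothetical local address; adjust to wherever the FastAPI app is served.
API_URL = "http://localhost:8000"

resp = requests.post(API_URL + "/analyze", json={"text": "Example headline to check."})
resp.raise_for_status()
payload = resp.json()
print(payload["prediction"])         # "REAL" or "FAKE" from the DeBERTa classifier
print(payload["detailed_analysis"])  # structured JSON produced by the Gemini prompt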
final.py ADDED
@@ -0,0 +1,139 @@
+import torch
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+import spacy
+import google.generativeai as genai
+import json
+import os
+import dotenv
+
+dotenv.load_dotenv()
+
+# Load spaCy for NER
+nlp = spacy.load("en_core_web_sm")
+
+# Load the trained ML model
+model_path = "./results/checkpoint-753"  # Replace with the actual path to your model
+tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small')
+model = AutoModelForSequenceClassification.from_pretrained(model_path)
+model.eval()
+
+def setup_gemini():
+    genai.configure(api_key=os.getenv("GEMINI_API"))
+    model = genai.GenerativeModel('gemini-pro')
+    return model
+
+def predict_with_model(text):
+    """Predict whether the news is real or fake using the ML model."""
+    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
+    with torch.no_grad():
+        outputs = model(**inputs)
+    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
+    predicted_label = torch.argmax(probabilities, dim=-1).item()
+    return "FAKE" if predicted_label == 1 else "REAL"
+
+def extract_entities(text):
+    """Extract named entities from text using spaCy."""
+    doc = nlp(text)
+    entities = [(ent.text, ent.label_) for ent in doc.ents]
+    return entities
+
+def predict_news(text):
+    """Predict whether the news is real or fake using the ML model."""
+    # Predict with the ML model
+    prediction = predict_with_model(text)
+    return prediction
+
+def analyze_content_gemini(model, text):
+    prompt = f"""Analyze this news text and return a JSON object with the following structure:
+    {{
+        "gemini_analysis": {{
+            "predicted_classification": "Real or Fake",
+            "confidence_score": "0-100",
+            "reasoning": ["point1", "point2"]
+        }},
+        "text_classification": {{
+            "category": "",
+            "writing_style": "Formal/Informal/Clickbait",
+            "target_audience": "",
+            "content_type": "news/opinion/editorial"
+        }},
+        "sentiment_analysis": {{
+            "primary_emotion": "",
+            "emotional_intensity": "1-10",
+            "sensationalism_level": "High/Medium/Low",
+            "bias_indicators": ["bias1", "bias2"],
+            "tone": {{"formality": "formal/informal", "style": "Professional/Emotional/Neutral"}},
+            "emotional_triggers": ["trigger1", "trigger2"]
+        }},
+        "entity_recognition": {{
+            "source_credibility": "High/Medium/Low",
+            "people": ["person1", "person2"],
+            "organizations": ["org1", "org2"],
+            "locations": ["location1", "location2"],
+            "dates": ["date1", "date2"],
+            "statistics": ["stat1", "stat2"]
+        }},
+        "context": {{
+            "main_narrative": "",
+            "supporting_elements": ["element1", "element2"],
+            "key_claims": ["claim1", "claim2"],
+            "narrative_structure": ""
+        }},
+        "fact_checking": {{
+            "verifiable_claims": ["claim1", "claim2"],
+            "evidence_present": "Yes/No",
+            "fact_check_score": "0-100"
+        }}
+    }}
+
+    Analyze this text and return only the JSON response: {text}"""
+
+    response = model.generate_content(prompt)
+    try:
+        cleaned_text = response.text.strip()
+        if cleaned_text.startswith('```json'):
+            cleaned_text = cleaned_text[7:-3]
+        return json.loads(cleaned_text)
+    except json.JSONDecodeError:
+        return {
+            "gemini_analysis": {
+                "predicted_classification": "UNCERTAIN",
+                "confidence_score": "50",
+                "reasoning": ["Analysis failed to generate valid JSON"]
+            }
+        }
+
+def clean_gemini_output(text):
+    """Remove markdown formatting from Gemini output"""
+    text = text.replace('##', '')
+    text = text.replace('**', '')
+    return text
+
+def get_gemini_analysis(text):
+    """Get detailed content analysis from Gemini."""
+    gemini_model = setup_gemini()
+    gemini_analysis = analyze_content_gemini(gemini_model, text)
+    return gemini_analysis
+
+def main():
+    print("Welcome to the News Classifier!")
+    print("Enter your news text below. Type 'Exit' to quit.")
+
+    while True:
+        news_text = input("\nEnter news text: ")
+
+        if news_text.lower() == 'exit':
+            print("Thank you for using the News Classifier!")
+            return
+
+        # Get ML prediction
+        prediction = predict_news(news_text)
+        print(f"\nML Analysis: {prediction}")
+
+        # Get Gemini analysis
+        print("\n=== Detailed Gemini Analysis ===")
+        gemini_result = get_gemini_analysis(news_text)
+        print(gemini_result)
+
+if __name__ == "__main__":
+    main()
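As a quick sanity check, final.py's helpers can also be driven from a plain Python shell. This is a minimal sketch, assuming a GEMINI_API key in .env and the checkpoint under ./results/checkpoint-753, exactly as the module expects at import time; the sample sentence is illustrative:

# Smoke test for final.py, run from the repository root.
from final import predict_news, extract_entities

sample = "The city council approved the new budget on Tuesday."
print(predict_news(sample))        # "REAL" or "FAKE"
print(extract_entities(sample))    # spaCy entities, e.g. [("Tuesday", "DATE")]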
results/checkpoint-753/config.json ADDED
@@ -0,0 +1,35 @@
+{
+  "_name_or_path": "microsoft/deberta-v3-xsmall",
+  "architectures": [
+    "DebertaV2ForSequenceClassification"
+  ],
+  "attention_probs_dropout_prob": 0.1,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 384,
+  "initializer_range": 0.02,
+  "intermediate_size": 1536,
+  "layer_norm_eps": 1e-07,
+  "max_position_embeddings": 512,
+  "max_relative_positions": -1,
+  "model_type": "deberta-v2",
+  "norm_rel_ebd": "layer_norm",
+  "num_attention_heads": 6,
+  "num_hidden_layers": 12,
+  "pad_token_id": 0,
+  "pooler_dropout": 0,
+  "pooler_hidden_act": "gelu",
+  "pooler_hidden_size": 384,
+  "pos_att_type": [
+    "p2c",
+    "c2p"
+  ],
+  "position_biased_input": false,
+  "position_buckets": 256,
+  "relative_attention": true,
+  "share_att_key": true,
+  "torch_dtype": "float32",
+  "transformers_version": "4.46.2",
+  "type_vocab_size": 0,
+  "vocab_size": 128100
+}
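Note that this config records microsoft/deberta-v3-xsmall as the base model (hidden_size 384, 12 layers), while final.py and st.py load their tokenizer from microsoft/deberta-v3-small. The DeBERTa-v3 variants share the same 128k SentencePiece vocabulary (vocab_size 128100 above), so tokenization stays compatible, but the naming mismatch is worth being aware of.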
results/checkpoint-753/model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5c8bc472032aa1625a83fa5a61358b394aa47e8936084fd5d5fc53d39b4819e7
+size 283347432
results/checkpoint-753/optimizer.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6d54a2486861a93c63c9d3f1ad129317a5ec061c153cc35f88750193eb19c8db
+size 566814714
results/checkpoint-753/rng_state.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bab711e45afdac9084a8d3228aa5d84f0234c10b8536782c428a3e5241e763c0
+size 14244
results/checkpoint-753/scheduler.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2254eb2782bb8f96d8221a7f05be58b9aa6b59a9ac623c10f2d2cc29c6abdd07
+size 1064
results/checkpoint-753/trainer_state.json ADDED
@@ -0,0 +1,64 @@
+{
+  "best_metric": 0.13373112678527832,
+  "best_model_checkpoint": "./results\\checkpoint-503",
+  "epoch": 2.99403578528827,
+  "eval_steps": 500,
+  "global_step": 753,
+  "is_hyper_param_search": false,
+  "is_local_process_zero": true,
+  "is_world_process_zero": true,
+  "log_history": [
+    {
+      "epoch": 0.9980119284294234,
+      "eval_loss": 0.16927649080753326,
+      "eval_runtime": 34.3209,
+      "eval_samples_per_second": 58.623,
+      "eval_steps_per_second": 3.671,
+      "step": 251
+    },
+    {
+      "epoch": 1.9880715705765408,
+      "grad_norm": 3.436805248260498,
+      "learning_rate": 2.53479125248509e-05,
+      "loss": 0.2895,
+      "step": 500
+    },
+    {
+      "epoch": 2.0,
+      "eval_loss": 0.13373112678527832,
+      "eval_runtime": 32.7048,
+      "eval_samples_per_second": 61.52,
+      "eval_steps_per_second": 3.853,
+      "step": 503
+    },
+    {
+      "epoch": 2.99403578528827,
+      "eval_loss": 0.1674525886774063,
+      "eval_runtime": 33.2196,
+      "eval_samples_per_second": 60.567,
+      "eval_steps_per_second": 3.793,
+      "step": 753
+    }
+  ],
+  "logging_steps": 500,
+  "max_steps": 753,
+  "num_input_tokens_seen": 0,
+  "num_train_epochs": 3,
+  "save_steps": 500,
+  "stateful_callbacks": {
+    "TrainerControl": {
+      "args": {
+        "should_epoch_stop": false,
+        "should_evaluate": false,
+        "should_log": false,
+        "should_save": true,
+        "should_training_stop": true
+      },
+      "attributes": {}
+    }
+  },
+  "total_flos": 198349894207488.0,
+  "train_batch_size": 16,
+  "trial_name": null,
+  "trial_params": null
+}
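For orientation, the numbers in this state file are mutually consistent: 753 global steps over 3 epochs gives 251 optimizer steps per epoch (evaluations log at steps 251, 503, and 753), and with train_batch_size 16 that implies roughly 251 × 16 ≈ 4,000 training examples per epoch, assuming no gradient accumulation. The best eval_loss (0.1337) was reached at step 503, which is why best_model_checkpoint points at checkpoint-503 rather than this checkpoint-753.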
results/checkpoint-753/training_args.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5d020540fc118248e604cd22f9ec20b7acb4023a8953f7fb309148a6a3c3deb8
+size 5240
st.py ADDED
@@ -0,0 +1,45 @@
+import streamlit as st
+import torch
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+# Load the model and tokenizer
+@st.cache_resource
+def load_model():
+    tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small')
+    model = AutoModelForSequenceClassification.from_pretrained("./results/checkpoint-753")
+    model.eval()
+    return tokenizer, model
+
+def predict_news(text, tokenizer, model):
+    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
+    with torch.no_grad():
+        outputs = model(**inputs)
+    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
+    predicted_label = torch.argmax(probabilities, dim=-1).item()
+    confidence = probabilities[0][predicted_label].item()
+    return "FAKE" if predicted_label == 1 else "REAL", confidence
+
+def main():
+    st.title("News Classifier")
+
+    # Load model
+    tokenizer, model = load_model()
+
+    # Text input
+    news_text = st.text_area("Enter news text to analyze:", height=200)
+
+    if st.button("Classify"):
+        if news_text:
+            with st.spinner('Analyzing...'):
+                prediction, confidence = predict_news(news_text, tokenizer, model)
+
+                # Display results
+                if prediction == "FAKE":
+                    st.error(f"⚠️ {prediction} NEWS")
+                else:
+                    st.success(f"✅ {prediction} NEWS")
+
+                st.info(f"Confidence: {confidence*100:.2f}%")
+
+if __name__ == "__main__":
+    main()
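The Streamlit front end is launched locally with "streamlit run st.py". Its helpers also work from a plain Python shell; a minimal sketch, assuming the checkpoint directory exists (Streamlit typically warns about a missing ScriptRunContext when cached functions run outside streamlit run), with an illustrative headline:

# Headless use of st.py's classifier helpers, run from the repository root.
from st import load_model, predict_news

tokenizer, model = load_model()
label, confidence = predict_news("Example headline to check.", tokenizer, model)
print(f"{label} ({confidence * 100:.2f}% confidence)")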