# Nexus_NLP_model / final.py
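"""News classifier that combines a fine-tuned DeBERTa sequence-classification model
with a Gemini-based content analysis. Run the script, paste a news article, and get
a REAL/FAKE prediction plus a structured JSON breakdown of the text."""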
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import spacy
import google.generativeai as genai
import json
import os
import dotenv
dotenv.load_dotenv()
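# Expects GEMINI_API (a Gemini API key) in the environment or a local .env file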
# Load spaCy for NER
nlp = spacy.load("en_core_web_sm")
# Load the trained ML model
model_path = "./results/checkpoint-753" # Replace with the actual path to your model
# use_fast=False: the slow SentencePiece tokenizer avoids fast-tokenizer conversion issues with DeBERTa-v3
tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-small', use_fast=False)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval()
def setup_gemini():
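    """Configure the Gemini client from the GEMINI_API env var and return a 'gemini-pro' model handle."""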
genai.configure(api_key=os.getenv("GEMINI_API"))
    gemini_model = genai.GenerativeModel('gemini-pro')
    return gemini_model
def predict_with_model(text):
"""Predict whether the news is real or fake using the ML model."""
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
with torch.no_grad():
outputs = model(**inputs)
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
predicted_label = torch.argmax(probabilities, dim=-1).item()
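    # Map the argmax class index to a label (training convention: 1 -> FAKE, 0 -> REAL)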
return "FAKE" if predicted_label == 1 else "REAL"
def extract_entities(text):
"""Extract named entities from text using spaCy."""
doc = nlp(text)
entities = [(ent.text, ent.label_) for ent in doc.ents]
return entities
def predict_news(text):
"""Predict whether the news is real or fake using the ML model."""
# Predict with the ML model
prediction = predict_with_model(text)
return prediction
def analyze_content_gemini(model, text):
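    """Ask Gemini to analyze the text and return the parsed JSON analysis as a dict."""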
prompt = f"""Analyze this news text and return a JSON object with the following structure:
{{
"gemini_analysis": {{
"predicted_classification": "Real or Fake",
"confidence_score": "0-100",
"reasoning": ["point1", "point2"]
}},
"text_classification": {{
"category": "",
"writing_style": "Formal/Informal/Clickbait",
"target_audience": "",
"content_type": "news/opinion/editorial"
}},
"sentiment_analysis": {{
"primary_emotion": "",
"emotional_intensity": "1-10",
"sensationalism_level": "High/Medium/Low",
"bias_indicators": ["bias1", "bias2"],
"tone": {{"formality": "formal/informal", "style": "Professional/Emotional/Neutral"}},
"emotional_triggers": ["trigger1", "trigger2"]
}},
"entity_recognition": {{
"source_credibility": "High/Medium/Low",
"people": ["person1", "person2"],
"organizations": ["org1", "org2"],
"locations": ["location1", "location2"],
"dates": ["date1", "date2"],
"statistics": ["stat1", "stat2"]
}},
"context": {{
"main_narrative": "",
"supporting_elements": ["element1", "element2"],
"key_claims": ["claim1", "claim2"],
"narrative_structure": ""
}},
"fact_checking": {{
"verifiable_claims": ["claim1", "claim2"],
"evidence_present": "Yes/No",
"fact_check_score": "0-100"
}}
}}
Analyze this text and return only the JSON response: {text}"""
response = model.generate_content(prompt)
try:
        cleaned_text = response.text.strip()
        # Gemini often wraps its JSON in Markdown code fences; strip them before parsing
        if cleaned_text.startswith('```'):
            cleaned_text = cleaned_text.strip('`')
            if cleaned_text.startswith('json'):
                cleaned_text = cleaned_text[len('json'):]
        return json.loads(cleaned_text)
except json.JSONDecodeError:
return {
"gemini_analysis": {
"predicted_classification": "UNCERTAIN",
"confidence_score": "50",
"reasoning": ["Analysis failed to generate valid JSON"]
}
}
def clean_gemini_output(text):
"""Remove markdown formatting from Gemini output"""
text = text.replace('##', '')
text = text.replace('**', '')
return text
def get_gemini_analysis(text):
"""Get detailed content analysis from Gemini."""
gemini_model = setup_gemini()
gemini_analysis = analyze_content_gemini(gemini_model, text)
return gemini_analysis
def main():
print("Welcome to the News Classifier!")
print("Enter your news text below. Type 'Exit' to quit.")
while True:
news_text = input("\nEnter news text: ")
if news_text.lower() == 'exit':
print("Thank you for using the News Classifier!")
return
# Get ML prediction
prediction = predict_news(news_text)
print(f"\nML Analysis: {prediction}")
# Get Gemini analysis
print("\n=== Detailed Gemini Analysis ===")
gemini_result = get_gemini_analysis(news_text)
        print(json.dumps(gemini_result, indent=2))
if __name__ == "__main__":
main()