# Nexus_NLP_model / app.py
# Author: Krish Patel
# Streamlit front-end for the Nexus NLP fake-news classifier.
import streamlit as st
import random
# Page configuration
st.set_page_config(
page_title="Nexus NLP News Classifier"
)
import pandas as pd
from final import *
from pydantic import BaseModel
import plotly.graph_objects as go
# Update the initialize_models function
@st.cache_resource
def initialize_models():
try:
nlp = spacy.load("en_core_web_sm")
except:
spacy.cli.download("en_core_web_sm")
nlp = spacy.load("en_core_web_sm")
model_path = "./results/checkpoint-753"
tokenizer = DebertaV2Tokenizer.from_pretrained('microsoft/deberta-v3-small')
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval()
knowledge_graph = load_knowledge_graph()
return nlp, tokenizer, model, knowledge_graph
class NewsInput(BaseModel):
text: str
def generate_knowledge_graph_viz(text, nlp, tokenizer, model):
kg_builder = KnowledgeGraphBuilder()
# Get prediction
prediction, _ = predict_with_model(text, tokenizer, model)
is_fake = prediction == "FAKE"
# Update knowledge graph
kg_builder.update_knowledge_graph(text, not is_fake, nlp)
# Randomly select subset of edges (e.g. 60% of edges)
edges = list(kg_builder.knowledge_graph.edges())
selected_edges = random.sample(edges, k=int(len(edges) * 0.3))
# Create a new graph with selected edges
selected_graph = nx.DiGraph()
selected_graph.add_nodes_from(kg_builder.knowledge_graph.nodes(data=True))
selected_graph.add_edges_from(selected_edges)
pos = nx.spring_layout(selected_graph)
edge_trace = go.Scatter(
x=[], y=[],
line=dict(
width=2,
color='rgba(255,0,0,0.7)' if is_fake else 'rgba(0,255,0,0.7)'
),
hoverinfo='none',
mode='lines'
)
# Create visualization
pos = nx.spring_layout(kg_builder.knowledge_graph)
edge_trace = go.Scatter(
x=[], y=[],
line=dict(
width=2,
color='rgba(255,0,0,0.7)' if is_fake else 'rgba(0,255,0,0.7)'
),
hoverinfo='none',
mode='lines'
)
node_trace = go.Scatter(
x=[], y=[],
mode='markers+text',
hoverinfo='text',
textposition='top center',
marker=dict(
size=15,
color='white',
line=dict(width=2, color='black')
),
text=[]
)
# Add edges
for edge in selected_graph.edges():
x0, y0 = pos[edge[0]]
x1, y1 = pos[edge[1]]
edge_trace['x'] += (x0, x1, None)
edge_trace['y'] += (y0, y1, None)
# Add nodes
for node in kg_builder.knowledge_graph.nodes():
x, y = pos[node]
node_trace['x'] += (x,)
node_trace['y'] += (y,)
node_trace['text'] += (node,)
fig = go.Figure(
data=[edge_trace, node_trace],
layout=go.Layout(
showlegend=False,
hovermode='closest',
margin=dict(b=0,l=0,r=0,t=0),
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
plot_bgcolor='rgba(0,0,0,0)',
paper_bgcolor='rgba(0,0,0,0)'
)
)
return fig
# Streamlit UI
def main():
st.title("Nexus NLP News Classifier")
st.write("Enter news text below to analyze its authenticity")
# Initialize models
nlp, tokenizer, model, knowledge_graph = initialize_models()
# Text input area
news_text = st.text_area("News Text", height=200)
if st.button("Analyze"):
if news_text:
with st.spinner("Analyzing..."):
# Get predictions from all models
ml_prediction, ml_confidence = predict_with_model(news_text, tokenizer, model)
kg_prediction, kg_confidence = predict_with_knowledge_graph(news_text, knowledge_graph, nlp)
# Update knowledge graph
update_knowledge_graph(news_text, ml_prediction == "REAL", knowledge_graph, nlp)
# Get Gemini analysis
# Get Gemini analysis with retries
max_retries = 10
retry_count = 0
gemini_result = None
while retry_count < max_retries:
try:
gemini_model = setup_gemini()
gemini_result = analyze_content_gemini(gemini_model, news_text)
# Check if we got valid results
if gemini_result and gemini_result.get('gemini_analysis'):
break
except Exception:
pass
retry_count += 1
# Use default values if all retries failed
if not gemini_result:
gemini_result = {
"gemini_analysis": {
"predicted_classification": "UNCERTAIN",
"confidence_score": "50",
"reasoning": ["Analysis temporarily unavailable"]
}
}
# Display metrics in columns
col1 = st.columns(1)[0]
with col1:
st.subheader("ML Model and Knowedge Graph Analysis")
st.metric("Prediction", ml_prediction)
st.metric("Confidence", f"{ml_confidence:.2f}%")
# with col2:
# st.subheader("Knowledge Graph Analysis")
# st.metric("Prediction", kg_prediction)
# st.metric("Confidence", f"{kg_confidence:.2f}%")
# with col3:
# st.subheader("Gemini Analysis")
# gemini_pred = gemini_result["gemini_analysis"]["predicted_classification"]
# gemini_conf = gemini_result["gemini_analysis"]["confidence_score"]
# st.metric("Prediction", gemini_pred)
# st.metric("Confidence", f"{gemini_conf}%")
# Single expander for all analysis details
with st.expander("Detailed Analysis"):
try:
# Text Classification
st.subheader("πŸ“ Text Classification")
text_class = gemini_result.get('text_classification', {})
st.write(f"Category: {text_class.get('category', 'N/A')}")
st.write(f"Writing Style: {text_class.get('writing_style', 'N/A')}")
st.write(f"Target Audience: {text_class.get('target_audience', 'N/A')}")
st.write(f"Content Type: {text_class.get('content_type', 'N/A')}")
# Sentiment Analysis
st.subheader("🎭 Sentiment Analysis")
sentiment = gemini_result.get('sentiment_analysis', {})
st.write(f"Primary Emotion: {sentiment.get('primary_emotion', 'N/A')}")
st.write(f"Emotional Intensity: {sentiment.get('emotional_intensity', 'N/A')}/10")
st.write(f"Sensationalism Level: {sentiment.get('sensationalism_level', 'N/A')}")
st.write("Bias Indicators:", ", ".join(sentiment.get('bias_indicators', ['N/A'])))
# Entity Recognition
st.subheader("πŸ” Entity Recognition")
entities = gemini_result.get('entity_recognition', {})
st.write(f"Source Credibility: {entities.get('source_credibility', 'N/A')}")
st.write("People:", ", ".join(entities.get('people', ['N/A'])))
st.write("Organizations:", ", ".join(entities.get('organizations', ['N/A'])))
# Named Entities from spaCy
st.subheader("🏷️ Named Entities")
entities = extract_entities(news_text, nlp)
df = pd.DataFrame(entities, columns=["Entity", "Type"])
st.dataframe(df)
# Knowledge Graph Visualization
st.subheader("πŸ•ΈοΈ Knowledge Graph")
fig = generate_knowledge_graph_viz(news_text, nlp, tokenizer, model)
st.plotly_chart(fig, use_container_width=True)
# Analysis Reasoning
st.subheader("πŸ’­ Analysis Reasoning")
for point in gemini_result.get('gemini_analysis', {}).get('reasoning', ['N/A']):
st.write(f"β€’ {point}")
except Exception as e:
st.error("Error processing analysis results")
else:
st.warning("Please enter some text to analyze")
if __name__ == "__main__":
main()