Spaces:
Running
Running
Krish Patel
commited on
Commit
·
5e58061
1
Parent(s):
c422cb8
Add ploty for knowledge graph
Browse files- app.py +75 -0
- final.py +35 -0
- requirements.txt +2 -0
app.py
CHANGED
@@ -50,6 +50,8 @@ st.set_page_config(
|
|
50 |
|
51 |
import pandas as pd
|
52 |
from final import *
|
|
|
|
|
53 |
|
54 |
# # Cache model loading
|
55 |
# @st.cache_resource
|
@@ -78,6 +80,75 @@ def initialize_models():
|
|
78 |
knowledge_graph = load_knowledge_graph()
|
79 |
return nlp, tokenizer, model, knowledge_graph
|
80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
# Streamlit UI
|
82 |
def main():
|
83 |
st.title("Nexus NLP News Classifier")
|
@@ -132,6 +203,10 @@ def main():
|
|
132 |
df = pd.DataFrame(entities, columns=["Entity", "Type"])
|
133 |
st.dataframe(df)
|
134 |
|
|
|
|
|
|
|
|
|
135 |
else:
|
136 |
st.warning("Please enter some text to analyze")
|
137 |
|
|
|
50 |
|
51 |
import pandas as pd
|
52 |
from final import *
|
53 |
+
from pydantic import BaseModel
|
54 |
+
import plotly.graph_objects as go
|
55 |
|
56 |
# # Cache model loading
|
57 |
# @st.cache_resource
|
|
|
80 |
knowledge_graph = load_knowledge_graph()
|
81 |
return nlp, tokenizer, model, knowledge_graph
|
82 |
|
83 |
+
|
84 |
+
class NewsInput(BaseModel):
|
85 |
+
text: str
|
86 |
+
|
87 |
+
def generate_knowledge_graph_viz(text, nlp, tokenizer, model):
|
88 |
+
kg_builder = KnowledgeGraphBuilder()
|
89 |
+
|
90 |
+
# Get prediction
|
91 |
+
prediction, _ = predict_with_model(text, tokenizer, model)
|
92 |
+
is_fake = prediction == "FAKE"
|
93 |
+
|
94 |
+
# Update knowledge graph
|
95 |
+
kg_builder.update_knowledge_graph(text, not is_fake, nlp)
|
96 |
+
|
97 |
+
# Create visualization
|
98 |
+
pos = nx.spring_layout(kg_builder.knowledge_graph)
|
99 |
+
|
100 |
+
edge_trace = go.Scatter(
|
101 |
+
x=[], y=[],
|
102 |
+
line=dict(
|
103 |
+
width=2,
|
104 |
+
color='rgba(255,0,0,0.7)' if is_fake else 'rgba(0,255,0,0.7)'
|
105 |
+
),
|
106 |
+
hoverinfo='none',
|
107 |
+
mode='lines'
|
108 |
+
)
|
109 |
+
|
110 |
+
node_trace = go.Scatter(
|
111 |
+
x=[], y=[],
|
112 |
+
mode='markers+text',
|
113 |
+
hoverinfo='text',
|
114 |
+
textposition='top center',
|
115 |
+
marker=dict(
|
116 |
+
size=15,
|
117 |
+
color='white',
|
118 |
+
line=dict(width=2, color='black')
|
119 |
+
),
|
120 |
+
text=[]
|
121 |
+
)
|
122 |
+
|
123 |
+
# Add edges
|
124 |
+
for edge in kg_builder.knowledge_graph.edges():
|
125 |
+
x0, y0 = pos[edge[0]]
|
126 |
+
x1, y1 = pos[edge[1]]
|
127 |
+
edge_trace['x'] += (x0, x1, None)
|
128 |
+
edge_trace['y'] += (y0, y1, None)
|
129 |
+
|
130 |
+
# Add nodes
|
131 |
+
for node in kg_builder.knowledge_graph.nodes():
|
132 |
+
x, y = pos[node]
|
133 |
+
node_trace['x'] += (x,)
|
134 |
+
node_trace['y'] += (y,)
|
135 |
+
node_trace['text'] += (node,)
|
136 |
+
|
137 |
+
fig = go.Figure(
|
138 |
+
data=[edge_trace, node_trace],
|
139 |
+
layout=go.Layout(
|
140 |
+
showlegend=False,
|
141 |
+
hovermode='closest',
|
142 |
+
margin=dict(b=0,l=0,r=0,t=0),
|
143 |
+
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
144 |
+
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
145 |
+
plot_bgcolor='rgba(0,0,0,0)',
|
146 |
+
paper_bgcolor='rgba(0,0,0,0)'
|
147 |
+
)
|
148 |
+
)
|
149 |
+
|
150 |
+
return fig
|
151 |
+
|
152 |
# Streamlit UI
|
153 |
def main():
|
154 |
st.title("Nexus NLP News Classifier")
|
|
|
203 |
df = pd.DataFrame(entities, columns=["Entity", "Type"])
|
204 |
st.dataframe(df)
|
205 |
|
206 |
+
# Generate and display knowledge graph
|
207 |
+
fig = generate_knowledge_graph_viz(news_text, nlp, tokenizer, model)
|
208 |
+
st.plotly_chart(fig, use_container_width=True)
|
209 |
+
|
210 |
else:
|
211 |
st.warning("Please enter some text to analyze")
|
212 |
|
final.py
CHANGED
@@ -278,6 +278,7 @@ import google.generativeai as genai
|
|
278 |
import json
|
279 |
import os
|
280 |
import dotenv
|
|
|
281 |
|
282 |
# Load environment variables
|
283 |
dotenv.load_dotenv()
|
@@ -303,6 +304,40 @@ def load_knowledge_graph():
|
|
303 |
knowledge_graph.add_edge(u, v, **data)
|
304 |
return knowledge_graph
|
305 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
306 |
def setup_gemini():
|
307 |
"""Initialize Gemini model"""
|
308 |
genai.configure(api_key=os.getenv("GEMINI_API"))
|
|
|
278 |
import json
|
279 |
import os
|
280 |
import dotenv
|
281 |
+
import plotly.graph_objects as go
|
282 |
|
283 |
# Load environment variables
|
284 |
dotenv.load_dotenv()
|
|
|
304 |
knowledge_graph.add_edge(u, v, **data)
|
305 |
return knowledge_graph
|
306 |
|
307 |
+
|
308 |
+
class KnowledgeGraphBuilder:
|
309 |
+
def __init__(self):
|
310 |
+
self.knowledge_graph = nx.DiGraph()
|
311 |
+
|
312 |
+
def update_knowledge_graph(self, text, is_real, nlp):
|
313 |
+
entities = extract_entities(text, nlp)
|
314 |
+
for entity, entity_type in entities:
|
315 |
+
if not self.knowledge_graph.has_node(entity):
|
316 |
+
self.knowledge_graph.add_node(
|
317 |
+
entity,
|
318 |
+
type=entity_type,
|
319 |
+
real_count=1 if is_real else 0,
|
320 |
+
fake_count=0 if is_real else 1
|
321 |
+
)
|
322 |
+
else:
|
323 |
+
if is_real:
|
324 |
+
self.knowledge_graph.nodes[entity]['real_count'] += 1
|
325 |
+
else:
|
326 |
+
self.knowledge_graph.nodes[entity]['fake_count'] += 1
|
327 |
+
|
328 |
+
for i, (entity1, _) in enumerate(entities):
|
329 |
+
for entity2, _ in entities[i+1:]:
|
330 |
+
if not self.knowledge_graph.has_edge(entity1, entity2):
|
331 |
+
self.knowledge_graph.add_edge(
|
332 |
+
entity1,
|
333 |
+
entity2,
|
334 |
+
weight=1,
|
335 |
+
is_real=is_real
|
336 |
+
)
|
337 |
+
else:
|
338 |
+
self.knowledge_graph[entity1][entity2]['weight'] += 1
|
339 |
+
|
340 |
+
|
341 |
def setup_gemini():
|
342 |
"""Initialize Gemini model"""
|
343 |
genai.configure(api_key=os.getenv("GEMINI_API"))
|
requirements.txt
CHANGED
@@ -11,4 +11,6 @@ uvicorn
|
|
11 |
tiktoken
|
12 |
sentencepiece
|
13 |
timm
|
|
|
|
|
14 |
https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.0/en_core_web_sm-3.7.0.tar.gz
|
|
|
11 |
tiktoken
|
12 |
sentencepiece
|
13 |
timm
|
14 |
+
plotly
|
15 |
+
networkx
|
16 |
https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.0/en_core_web_sm-3.7.0.tar.gz
|