Krish Patel commited on
Commit
5e58061
·
1 Parent(s): c422cb8

Add ploty for knowledge graph

Browse files
Files changed (3) hide show
  1. app.py +75 -0
  2. final.py +35 -0
  3. requirements.txt +2 -0
app.py CHANGED
@@ -50,6 +50,8 @@ st.set_page_config(
50
 
51
  import pandas as pd
52
  from final import *
 
 
53
 
54
  # # Cache model loading
55
  # @st.cache_resource
@@ -78,6 +80,75 @@ def initialize_models():
78
  knowledge_graph = load_knowledge_graph()
79
  return nlp, tokenizer, model, knowledge_graph
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  # Streamlit UI
82
  def main():
83
  st.title("Nexus NLP News Classifier")
@@ -132,6 +203,10 @@ def main():
132
  df = pd.DataFrame(entities, columns=["Entity", "Type"])
133
  st.dataframe(df)
134
 
 
 
 
 
135
  else:
136
  st.warning("Please enter some text to analyze")
137
 
 
50
 
51
  import pandas as pd
52
  from final import *
53
+ from pydantic import BaseModel
54
+ import plotly.graph_objects as go
55
 
56
  # # Cache model loading
57
  # @st.cache_resource
 
80
  knowledge_graph = load_knowledge_graph()
81
  return nlp, tokenizer, model, knowledge_graph
82
 
83
+
84
+ class NewsInput(BaseModel):
85
+ text: str
86
+
87
+ def generate_knowledge_graph_viz(text, nlp, tokenizer, model):
88
+ kg_builder = KnowledgeGraphBuilder()
89
+
90
+ # Get prediction
91
+ prediction, _ = predict_with_model(text, tokenizer, model)
92
+ is_fake = prediction == "FAKE"
93
+
94
+ # Update knowledge graph
95
+ kg_builder.update_knowledge_graph(text, not is_fake, nlp)
96
+
97
+ # Create visualization
98
+ pos = nx.spring_layout(kg_builder.knowledge_graph)
99
+
100
+ edge_trace = go.Scatter(
101
+ x=[], y=[],
102
+ line=dict(
103
+ width=2,
104
+ color='rgba(255,0,0,0.7)' if is_fake else 'rgba(0,255,0,0.7)'
105
+ ),
106
+ hoverinfo='none',
107
+ mode='lines'
108
+ )
109
+
110
+ node_trace = go.Scatter(
111
+ x=[], y=[],
112
+ mode='markers+text',
113
+ hoverinfo='text',
114
+ textposition='top center',
115
+ marker=dict(
116
+ size=15,
117
+ color='white',
118
+ line=dict(width=2, color='black')
119
+ ),
120
+ text=[]
121
+ )
122
+
123
+ # Add edges
124
+ for edge in kg_builder.knowledge_graph.edges():
125
+ x0, y0 = pos[edge[0]]
126
+ x1, y1 = pos[edge[1]]
127
+ edge_trace['x'] += (x0, x1, None)
128
+ edge_trace['y'] += (y0, y1, None)
129
+
130
+ # Add nodes
131
+ for node in kg_builder.knowledge_graph.nodes():
132
+ x, y = pos[node]
133
+ node_trace['x'] += (x,)
134
+ node_trace['y'] += (y,)
135
+ node_trace['text'] += (node,)
136
+
137
+ fig = go.Figure(
138
+ data=[edge_trace, node_trace],
139
+ layout=go.Layout(
140
+ showlegend=False,
141
+ hovermode='closest',
142
+ margin=dict(b=0,l=0,r=0,t=0),
143
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
144
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
145
+ plot_bgcolor='rgba(0,0,0,0)',
146
+ paper_bgcolor='rgba(0,0,0,0)'
147
+ )
148
+ )
149
+
150
+ return fig
151
+
152
  # Streamlit UI
153
  def main():
154
  st.title("Nexus NLP News Classifier")
 
203
  df = pd.DataFrame(entities, columns=["Entity", "Type"])
204
  st.dataframe(df)
205
 
206
+ # Generate and display knowledge graph
207
+ fig = generate_knowledge_graph_viz(news_text, nlp, tokenizer, model)
208
+ st.plotly_chart(fig, use_container_width=True)
209
+
210
  else:
211
  st.warning("Please enter some text to analyze")
212
 
final.py CHANGED
@@ -278,6 +278,7 @@ import google.generativeai as genai
278
  import json
279
  import os
280
  import dotenv
 
281
 
282
  # Load environment variables
283
  dotenv.load_dotenv()
@@ -303,6 +304,40 @@ def load_knowledge_graph():
303
  knowledge_graph.add_edge(u, v, **data)
304
  return knowledge_graph
305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
  def setup_gemini():
307
  """Initialize Gemini model"""
308
  genai.configure(api_key=os.getenv("GEMINI_API"))
 
278
  import json
279
  import os
280
  import dotenv
281
+ import plotly.graph_objects as go
282
 
283
  # Load environment variables
284
  dotenv.load_dotenv()
 
304
  knowledge_graph.add_edge(u, v, **data)
305
  return knowledge_graph
306
 
307
+
308
+ class KnowledgeGraphBuilder:
309
+ def __init__(self):
310
+ self.knowledge_graph = nx.DiGraph()
311
+
312
+ def update_knowledge_graph(self, text, is_real, nlp):
313
+ entities = extract_entities(text, nlp)
314
+ for entity, entity_type in entities:
315
+ if not self.knowledge_graph.has_node(entity):
316
+ self.knowledge_graph.add_node(
317
+ entity,
318
+ type=entity_type,
319
+ real_count=1 if is_real else 0,
320
+ fake_count=0 if is_real else 1
321
+ )
322
+ else:
323
+ if is_real:
324
+ self.knowledge_graph.nodes[entity]['real_count'] += 1
325
+ else:
326
+ self.knowledge_graph.nodes[entity]['fake_count'] += 1
327
+
328
+ for i, (entity1, _) in enumerate(entities):
329
+ for entity2, _ in entities[i+1:]:
330
+ if not self.knowledge_graph.has_edge(entity1, entity2):
331
+ self.knowledge_graph.add_edge(
332
+ entity1,
333
+ entity2,
334
+ weight=1,
335
+ is_real=is_real
336
+ )
337
+ else:
338
+ self.knowledge_graph[entity1][entity2]['weight'] += 1
339
+
340
+
341
  def setup_gemini():
342
  """Initialize Gemini model"""
343
  genai.configure(api_key=os.getenv("GEMINI_API"))
requirements.txt CHANGED
@@ -11,4 +11,6 @@ uvicorn
11
  tiktoken
12
  sentencepiece
13
  timm
 
 
14
  https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.0/en_core_web_sm-3.7.0.tar.gz
 
11
  tiktoken
12
  sentencepiece
13
  timm
14
+ plotly
15
+ networkx
16
  https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.0/en_core_web_sm-3.7.0.tar.gz