# Spaces: Running
import gradio as gr | |
import os | |
import json | |
import requests | |
from bs4 import BeautifulSoup | |
import networkx as nx | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import io | |
import base64 | |
from huggingface_hub import InferenceClient | |
import re | |
from urllib.parse import urlparse | |
import warnings | |
# Global plotting setup: render all figure text in DejaVu Sans (10 pt, normal
# weight) and silence warnings that would otherwise clutter the app log.
plt.rcParams.update({
    'font.family': ['DejaVu Sans'],
    'font.size': 10,
    'font.weight': 'normal',
})
# Ignore every UserWarning, plus any warning about a missing font family.
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', message='.*Font family.*not found.*')
def clean_text_for_display(text, max_length=50):
    """Reduce a value to a short, ASCII-only, single-line display string.

    Args:
        text: Value to clean. Non-string values are converted with ``str()``
            first and then cleaned exactly like a string.
        max_length: Maximum number of characters to keep (default 50).

    Returns:
        The text with non-ASCII characters removed, every whitespace run
        collapsed to a single space, leading/trailing whitespace stripped,
        and the result cut to at most ``max_length`` characters.
    """
    if not isinstance(text, str):
        # Coerce rather than return early, so non-string values get the same
        # ASCII filtering and length cap as strings do.
        text = str(text)
    ascii_only = re.sub(r'[^\x00-\x7F]+', '', text)
    normalized = re.sub(r'\s+', ' ', ascii_only).strip()
    # Slicing is a no-op for short strings, so no length check is needed.
    return normalized[:max_length]
def fetch_content(url_or_text):
    """Return the text to analyze, downloading it first when given a URL.

    Args:
        url_or_text: Either an ``http``/``https`` URL to fetch, or the text
            to analyze itself.

    Returns:
        For a URL: the page text with ``<script>``/``<style>`` elements
        removed and whitespace collapsed, limited to the first 5000
        characters. For any other input: ``url_or_text`` unchanged.
        Exceptions are caught and reported as a string starting with
        ``"Error fetching URL:"`` or ``"Error processing input:"``.
    """
    try:
        # Pasted URLs often carry stray spaces or a trailing newline; strip
        # them so both the scheme check and the request see the bare URL.
        candidate = url_or_text.strip() if isinstance(url_or_text, str) else url_or_text
        if urlparse(candidate).scheme not in ('http', 'https'):
            # Not a web URL: the input is the text to analyze.
            return url_or_text

        try:
            # NOTE(review): the URL is fetched as given, with no restriction
            # on the target host. Consider rejecting internal addresses if
            # this runs somewhere with access to private services.
            headers = {
                'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
            }
            response = requests.get(candidate, headers=headers, timeout=10)
            response.raise_for_status()

            soup = BeautifulSoup(response.content, 'html.parser')
            # Script and style bodies are not readable content.
            for element in soup(["script", "style"]):
                element.decompose()

            # split() with no argument breaks on any whitespace run (spaces,
            # tabs, newlines, non-breaking spaces) and drops empty pieces.
            text = ' '.join(soup.get_text().split())
            return text[:5000]  # Limit to first 5000 characters
        except Exception as e:
            return f"Error fetching URL: {str(e)}"
    except Exception as e:
        return f"Error processing input: {str(e)}"
def simple_entity_extraction(text):
    """Heuristic entity extraction used when the AI model is unavailable.

    Scans the first 30 whitespace-separated words and treats each distinct
    title-cased word longer than two characters (ignoring a few common
    sentence starters) as a ``CONCEPT`` entity. Consecutive entities are
    then linked with ``related_to`` relationships, at most five.

    Args:
        text: Text to scan.

    Returns:
        ``{"entities": [...], "relationships": [...]}`` with at most ten
        entities. If an exception occurs, a single ``ERROR`` entity whose
        description is the exception message, and no relationships.
    """
    ignored_words = {'The', 'This', 'That', 'When', 'Where', 'How'}
    try:
        entities = []
        seen = set()
        for word in text.split()[:30]:  # Limit to first 30 words
            clean_word = re.sub(r'[^\w]', '', word)
            key = clean_word.lower()
            if (not clean_word.istitle() or len(clean_word) <= 2
                    or key in seen or clean_word in ignored_words):
                continue
            seen.add(key)

            display_name = clean_text_for_display(clean_word)
            # Display cleaning strips non-ASCII characters, so a word can
            # come back empty or identical to an earlier entity's name.
            # Skip those so every entity has a unique, non-empty name.
            display_key = display_name.lower()
            if not display_name or (display_key != key and display_key in seen):
                continue
            seen.add(display_key)

            entities.append({
                "name": display_name,
                "type": "CONCEPT",
                "description": "Auto-detected entity"
            })

        # Chain neighbouring entities together: (0, 1), (1, 2), ... up to
        # five links.
        relationships = [
            {
                "source": first["name"],
                "target": second["name"],
                "relation": "related_to",
                "description": "Sequential relationship"
            }
            for first, second in zip(entities[:5], entities[1:])
        ]
        return {"entities": entities[:10], "relationships": relationships}
    except Exception as e:
        return {
            "entities": [{"name": "Error", "type": "ERROR", "description": str(e)}],
            "relationships": []
        }
def extract_entities(text):
    """Extract entities and relationships with Mistral, with a heuristic fallback.

    Args:
        text: Input text to analyze. Only the first 1500 characters are
            included in the prompt sent to the model.

    Returns:
        The dictionary parsed from the model's JSON reply (expected keys:
        ``"entities"`` and ``"relationships"``), with entity names and
        relationship endpoints passed through ``clean_text_for_display``.
        When ``HF_TOKEN`` is not set, the reply contains no JSON object, or
        any exception occurs, the result of ``simple_entity_extraction(text)``
        is returned instead.
    """
    try:
        hf_token = os.environ.get("HF_TOKEN")
        if not hf_token:
            print("No HF_TOKEN found, using simple extraction")
            return simple_entity_extraction(text)

        client = InferenceClient(
            provider="together",
            api_key=hf_token,
        )
        prompt = f"""
Analyze the following text and extract:
1. Named entities (people, organizations, locations, concepts)
2. Relationships between these entities
Return ONLY a valid JSON object with this structure:
{{
"entities": [
{{"name": "entity_name", "type": "PERSON", "description": "brief description"}}
],
"relationships": [
{{"source": "entity1", "target": "entity2", "relation": "relationship_type", "description": "brief description"}}
]
}}
Text to analyze: {text[:1500]}
"""
        completion = client.chat.completions.create(
            model="mistralai/Mistral-Small-24B-Instruct-2501",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=1000,
            temperature=0.1,
        )
        response_text = completion.choices[0].message.content

        # Take everything from the first "{" to the last "}" so any text the
        # model puts around the JSON object is ignored.
        json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
        if not json_match:
            print("No valid JSON found in AI response, using fallback")
            return simple_entity_extraction(text)

        # Remove control characters before parsing.
        json_str = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', json_match.group())
        parsed_data = json.loads(json_str)

        # Clean entity names for display.
        for entity in parsed_data.get("entities", []):
            if isinstance(entity, dict) and "name" in entity:
                entity["name"] = clean_text_for_display(entity["name"])

        # Clean relationship endpoints the same way, so they keep matching
        # the (cleaned) entity names they refer to.
        for relationship in parsed_data.get("relationships", []):
            if not isinstance(relationship, dict):
                continue
            for endpoint in ("source", "target"):
                if endpoint in relationship:
                    relationship[endpoint] = clean_text_for_display(relationship[endpoint])

        return parsed_data
    except Exception as e:
        print(f"AI extraction failed: {e}, using fallback")
        return simple_entity_extraction(text)
# Fill colour per entity type; unlisted types use the "UNKNOWN" colour.
_NODE_TYPE_COLORS = {
    "PERSON": "#FF6B6B",
    "ORG": "#4ECDC4",
    "LOCATION": "#45B7D1",
    "CONCEPT": "#96CEB4",
    "ERROR": "#FF0000",
    "UNKNOWN": "#DDA0DD"
}


def _figure_to_image(fig):
    """Render a matplotlib figure and return it as an RGB PIL image."""
    from PIL import Image

    fig.canvas.draw()
    width, height = fig.canvas.get_width_height()
    try:
        # Preferred path: read the canvas as RGBA and drop the alpha channel.
        pixels = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
        pixels = pixels.reshape((height, width, 4))[:, :, :3]
    except AttributeError:
        try:
            # Canvases without buffer_rgba(): read packed RGB bytes instead.
            pixels = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
            pixels = pixels.reshape((height, width, 3))
        except AttributeError:
            # Last resort: round-trip through an in-memory PNG.
            buf = io.BytesIO()
            fig.savefig(buf, format='png', bbox_inches='tight')
            buf.seek(0)
            return Image.open(buf).convert('RGB')
    return Image.fromarray(pixels)


def build_knowledge_graph(entities_data):
    """Build the entity graph and render it to an image.

    Args:
        entities_data: Dictionary with an ``"entities"`` list (dicts with
            ``name``/``type``/``description``) and a ``"relationships"`` list
            (dicts with ``source``/``target``/``relation``/``description``).
            At most the first 15 entities are drawn.

    Returns:
        PIL Image object of the knowledge graph. If building or drawing the
        graph raises, a placeholder image reading "Error creating graph" is
        returned instead.
    """
    fig = None
    try:
        graph = nx.Graph()

        # Nodes: one per entity with a non-empty cleaned name.
        for entity in entities_data.get("entities", [])[:15]:
            clean_name = clean_text_for_display(entity.get("name", "Unknown"))
            if clean_name.strip():
                graph.add_node(clean_name,
                               type=entity.get("type", "UNKNOWN"),
                               description=entity.get("description", ""))

        # Edges: only between entities that made it into the graph.
        relationships = entities_data.get("relationships", [])
        for rel in relationships:
            source = clean_text_for_display(rel.get("source", ""))
            target = clean_text_for_display(rel.get("target", ""))
            if source in graph.nodes and target in graph.nodes:
                graph.add_edge(source, target,
                               relation=rel.get("relation", "related"),
                               description=rel.get("description", ""))

        # With no relationships supplied at all, chain up to the first six
        # nodes together.
        if not relationships and len(graph.nodes) > 1:
            node_list = list(graph.nodes())
            for left, right in zip(node_list[:5], node_list[1:]):
                graph.add_edge(left, right, relation="related")

        fig, ax = plt.subplots(figsize=(10, 8))
        if len(graph.nodes) == 0:
            ax.text(0.5, 0.5, "No entities found to visualize",
                    ha='center', va='center', fontsize=14, transform=ax.transAxes)
            ax.set_title("Knowledge Graph")
            ax.axis('off')
        else:
            pos = nx.spring_layout(graph, k=1, iterations=50)
            node_colors = [
                _NODE_TYPE_COLORS.get(graph.nodes[node].get('type', 'UNKNOWN'), "#DDA0DD")
                for node in graph.nodes()
            ]
            nx.draw(graph, pos,
                    node_color=node_colors,
                    node_size=800,
                    font_size=8,
                    font_weight='bold',
                    with_labels=True,
                    edge_color='gray',
                    width=1.5,
                    alpha=0.8,
                    ax=ax)
            ax.set_title("Knowledge Graph", size=14, weight='bold')

        return _figure_to_image(fig)
    except Exception:
        # Anything above failed: return a placeholder image instead.
        error_fig, error_ax = plt.subplots(figsize=(8, 6))
        try:
            error_ax.text(0.5, 0.5, "Error creating graph",
                          ha='center', va='center', fontsize=12,
                          transform=error_ax.transAxes)
            error_ax.set_title("Knowledge Graph Error")
            error_ax.axis('off')
            return _figure_to_image(error_fig)
        finally:
            plt.close(error_fig)
    finally:
        # Close the main figure on every path, including failures after it
        # was created, so figures do not accumulate across requests.
        if fig is not None:
            plt.close(fig)
def knowledge_graph_builder(url_or_text):
    """Main function to build a knowledge graph from a URL or text.

    Args:
        url_or_text: URL to analyze or direct text input.

    Returns:
        Tuple of ``(entities_json, graph_image, summary)``: a JSON string of
        the extracted entities and relationships, the rendered graph image,
        and a Markdown summary. On empty input or failure the image is
        ``None`` and the summary carries the message.
    """
    try:
        if not url_or_text or not url_or_text.strip():
            return "{}", None, "Please provide some text or a URL to analyze."

        # Step 1: Fetch content
        content = fetch_content(url_or_text)
        # fetch_content reports failures as strings with one of these two
        # prefixes and returns direct text unchanged. Matching the exact
        # prefixes, and only when the content differs from the input, keeps
        # ordinary text that merely begins with "Error" from being rejected.
        fetch_failed = content != url_or_text and content.startswith(
            ("Error fetching URL:", "Error processing input:")
        )
        if fetch_failed:
            return json.dumps({"error": content}), None, content

        # Step 2: Extract entities
        entities_data = extract_entities(content)

        # Step 3: Build knowledge graph
        graph_image = build_knowledge_graph(entities_data)

        # Step 4: Create summary
        entities = entities_data.get("entities", [])
        num_entities = len(entities)
        num_relationships = len(entities_data.get("relationships", []))
        # NOTE(review): the "π" and "β’" prefixes in the strings below look
        # like mis-decoded emoji/bullet characters. Confirm against the
        # deployed file before changing them.
        summary_lines = [f"""## Knowledge Graph Analysis Complete!
π **Statistics:**
- Entities found: {num_entities}
- Relationships found: {num_relationships}
- Content length: {len(content)} characters
π **Extracted Entities:**"""]
        for entity in entities[:8]:  # Show first 8
            name = entity.get('name', 'Unknown')
            entity_type = entity.get('type', 'UNKNOWN')
            desc = entity.get('description', 'No description')
            summary_lines.append(f"β’ **{name}** ({entity_type}): {desc}")
        summary = "\n".join(summary_lines)
        if num_entities > 8:
            summary += f"\n\n... and {num_entities - 8} more entities"

        # Ensure valid JSON output
        try:
            json_output = json.dumps(entities_data, indent=2, ensure_ascii=True)
        except Exception as e:
            json_output = json.dumps({"error": f"JSON serialization failed: {str(e)}"})

        return json_output, graph_image, summary
    except Exception as e:
        error_msg = f"Analysis failed: {str(e)}"
        return json.dumps({"error": error_msg}), None, error_msg
# Create the Gradio interface. If construction fails, fall back to a minimal
# interface that only reports the startup error.
try:
    demo = gr.Interface(
        fn=knowledge_graph_builder,
        inputs=[
            gr.Textbox(
                label="URL or Text Input",
                placeholder="Enter a URL (https://example.com) or paste text directly...",
                lines=3,
                info="Enter a website URL to analyze, or paste text content directly"
            )
        ],
        outputs=[
            gr.JSON(label="Extracted Entities & Relationships"),
            gr.Image(label="Knowledge Graph Visualization", type="pil"),
            gr.Markdown(label="Analysis Summary")
        ],
        title="π§ AI Knowledge Graph Builder",
        description="""
**Transform any text or webpage into an interactive knowledge graph!**
This tool:
1. π Extracts content from URLs or analyzes your text
2. π€ Uses AI to identify entities and relationships
3. πΈοΈ Builds and visualizes knowledge graphs
4. π Provides detailed analysis summaries
**Try with:** news articles, Wikipedia pages, or any text content
""",
        theme=gr.themes.Soft(),
        allow_flagging="never",
        cache_examples=False  # Disable example caching to prevent startup errors
    )
except Exception as e:
    print(f"Failed to create Gradio interface: {e}")
    # Python unbinds ``e`` as soon as this ``except`` block ends, but the
    # fallback handler runs later, at request time. Keep the message in its
    # own variable so the handler can still report it.
    _startup_error = str(e)

    def simple_demo(text):
        """Fallback handler: report the startup failure for any input."""
        return json.dumps({"error": f"Startup failed: {_startup_error}"}), None, "Application failed to start properly."

    demo = gr.Interface(
        fn=simple_demo,
        inputs=[gr.Textbox(label="Input", placeholder="Enter text...")],
        outputs=[
            gr.JSON(label="Error Output"),
            gr.Image(label="No Image"),
            gr.Markdown(label="Error Message")
        ],
        title="β οΈ Knowledge Graph Builder - Startup Error",
        allow_flagging="never",
        cache_examples=False
    )
# Start the app when run as a script: first with the MCP server enabled,
# then, if that launch raises, once more with it disabled.
if __name__ == "__main__":
    primary_options = dict(mcp_server=True, share=False, show_error=True, quiet=False)
    fallback_options = dict(mcp_server=False, share=False, show_error=True)
    try:
        demo.launch(**primary_options)
    except Exception as launch_error:
        print(f"Launch failed: {launch_error}")
        # Try without MCP server as fallback
        try:
            demo.launch(**fallback_options)
        except Exception as fallback_error:
            print(f"Complete failure: {fallback_error}")