|
import datetime
import mimetypes
import os
import re
import tempfile
from typing import Dict, List, Optional, Union, Any

import pytz
import yaml
from newsapi import NewsApiClient
from smolagents import CodeAgent, load_tool, tool, LiteLLMModel
from smolagents import GradioUI

from tools.final_answer import FinalAnswerTool
from tools.visit_webpage import VisitWebpageTool
from tools.web_search import DuckDuckGoSearchTool
|
|
|
|
|
|
|
|
|
|
|
|
|
@tool |
|
def create_document(text: str, format: str = "docx") -> str:
    """Creates a document with the provided text and allows download.

    The file is saved as 'generated_document.<ext>' in a newly created temporary
    directory. Problems are reported through the returned message rather than by
    raising, so the caller always receives a string.

    Args:
        text: The text content to write to the document.
        format: The output format, either 'docx', 'pdf', or 'txt'. Default is 'docx'.

    Returns:
        The path of the created file, or a message beginning with 'Error' or 'ERROR' when the document could not be produced as requested.
    """
    try:
        # Normalize once so inputs such as 'TXT' or ' docx ' are accepted.
        fmt = format.strip().lower()

        # Reject unknown formats before creating anything on disk; otherwise
        # every bad request would leave an empty temporary directory behind.
        if fmt not in ("txt", "docx", "pdf"):
            return f"Error: Unsupported format '{format}'. Supported formats are 'docx', 'pdf', 'txt'."

        if fmt != "txt":
            # python-docx is an optional dependency; check for it before the
            # temporary directory is created, for the same reason as above.
            try:
                import docx
                from docx.shared import Pt
            except ImportError:
                return (f"ERROR: To create DOCX or PDF files, the 'python-docx' package is required. "
                        f"Please install it (e.g., 'pip install python-docx'). "
                        f"You can try creating a 'txt' file instead by specifying format='txt'.")

        temp_dir = tempfile.mkdtemp()
        file_name = "generated_document"

        if fmt == "txt":
            path = os.path.join(temp_dir, f"{file_name}.txt")
            with open(path, "w", encoding="utf-8") as f:
                f.write(text)
            print(f"Document created (txt): {path}")
            return path

        # Both 'docx' and 'pdf' start from a DOCX file.
        doc = docx.Document()
        doc.add_heading('Generated Document', 0)
        normal_font = doc.styles['Normal'].font
        normal_font.name = 'Calibri'
        normal_font.size = Pt(11)

        # One paragraph per non-blank input line.
        for paragraph in text.split('\n'):
            if paragraph.strip():
                doc.add_paragraph(paragraph)

        docx_path = os.path.join(temp_dir, f"{file_name}.docx")
        doc.save(docx_path)
        print(f"Document created (docx): {docx_path}")

        if fmt != "pdf":
            return docx_path

        # The PDF is produced by converting the DOCX. When the converter is
        # missing or fails, the DOCX is kept and its path is part of the message.
        try:
            from docx2pdf import convert
            pdf_path = os.path.join(temp_dir, f"{file_name}.pdf")
            convert(docx_path, pdf_path)
            print(f"Document converted to PDF: {pdf_path}")
            return pdf_path
        except ImportError:
            err_msg = (f"ERROR: PDF conversion requires the 'docx2pdf' package. "
                       f"Please install it (e.g., 'pip install docx2pdf'). "
                       f"Document saved as DOCX instead at: {docx_path}")
            print(err_msg)
            return err_msg
        except Exception as e_pdf:
            err_msg = f"Error converting DOCX to PDF: {str(e_pdf)}. Document saved as DOCX at: {docx_path}"
            print(err_msg)
            return err_msg
    except Exception as e:
        # Tool boundary: turn any unexpected failure into a readable message.
        print(f"General error in create_document: {str(e)}")
        return f"Error creating document: {str(e)}"
|
|
|
@tool |
|
def get_file_download_link(file_path: str) -> str:
    """Informs that a file is ready for download and its type. (Used by agent to tell user).

    Args:
        file_path: Path to the file that should be made available for download.

    Returns:
        A message naming the file and its MIME type, or a message beginning with 'Error' when the path does not point to an existing regular file.
    """
    # isfile rather than exists: a directory must not be announced as a
    # downloadable file. The truthiness check also covers None and ''.
    if not file_path or not os.path.isfile(file_path):
        print(f"get_file_download_link: File not found at {file_path}")
        return f"Error: File not found at {file_path}"

    # Types this app generates itself take precedence over any guess.
    known_types = {
        '.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
        '.pdf': 'application/pdf',
        '.txt': 'text/plain',
    }
    extension = os.path.splitext(file_path)[1].lower()
    mime_type = known_types.get(extension)

    if mime_type is None:
        # Fall back to the standard library's table for other extensions.
        # A guess that comes with a content encoding (e.g. gzip) describes the
        # decoded payload, not the file itself, so it is not used.
        guessed_type, content_encoding = mimetypes.guess_type(file_path)
        if guessed_type and content_encoding is None:
            mime_type = guessed_type
        else:
            mime_type = 'application/octet-stream'

    msg = f"File '{os.path.basename(file_path)}' is ready for download (type: {mime_type})."
    print(f"get_file_download_link: {msg}")
    return msg
|
|
|
@tool |
|
def get_latest_news(query: Optional[str] = None,
                    category: Optional[str] = None,
                    country: Optional[str] = "us",
                    language: Optional[str] = "en",
                    page_size: int = 5) -> str:
    """
    Fetches the latest news headlines.
    You can specify a query, category, country, and language.

    Args:
        query: Keywords or a phrase to search for in the news articles. (e.g., "Tesla stock")
        category: The category: business, entertainment, general, health, science, sports, technology.
        country: The 2-letter ISO 3166-1 code of the country (e.g., 'us', 'gb'). Default 'us'.
        language: The 2-letter ISO 639-1 code of the language (e.g., 'en', 'es'). Default 'en'.
        page_size: Number of results (max 100, default 5).

    Returns:
        A Markdown-formatted list of articles, a notice that nothing was found, or a message beginning with 'ERROR' when the request could not be completed.
    """
    api_key = os.getenv("NEWS_API_KEY")
    if not api_key:
        return "ERROR: News API key (NEWS_API_KEY) is not set in environment variables."

    newsapi = NewsApiClient(api_key=api_key)
    print(f"DEBUG NewsTool: query='{query}', category='{category}', country='{country}', lang='{language}', size={page_size}")

    try:
        # Keep the request within the documented maximum of 100 and ask for at
        # least one article instead of failing on an out-of-range value.
        page_size = max(1, min(int(page_size), 100))

        if query:
            # A free-text query searches all articles, newest first.
            print(f"Fetching news with query: '{query}'")
            response = newsapi.get_everything(q=query,
                                              language=language,
                                              sort_by='publishedAt',
                                              page_size=page_size)
        else:
            if category or country != "us":
                print(f"Fetching top headlines for category: '{category}', country: '{country}'")
            else:
                print(f"Fetching default top headlines for country: '{country}', language: '{language}'")
            # An empty category string is sent as "no category".
            response = newsapi.get_top_headlines(category=category or None,
                                                 language=language,
                                                 country=country,
                                                 page_size=page_size)

        if response.get('status') != 'ok':
            err_msg = f"ERROR: Could not fetch news. API Response: {response.get('code')} - {response.get('message')}"
            print(err_msg)
            return err_msg

        articles = response.get('articles') or []
        if not articles:
            return "No news articles found for the given criteria."

        entries = []
        for position, article in enumerate(articles[:page_size], start=1):
            # Fields can be present with a null value, which dict.get's default
            # does not cover; "or" applies the fallback in both cases.
            title = article.get('title') or 'N/A'
            source_name = (article.get('source') or {}).get('name') or 'N/A'
            url = article.get('url') or '#'

            description = article.get('description') or 'No description available.'
            if len(description) > 200:
                description = description[:200] + '...'

            published_raw = article.get('publishedAt') or 'N/A'
            published = published_raw
            if published_raw != 'N/A':
                try:
                    parsed = datetime.datetime.fromisoformat(published_raw.replace('Z', '+00:00'))
                except (ValueError, TypeError, AttributeError):
                    # Unparseable or non-string timestamp: show it unchanged.
                    pass
                else:
                    # Convert aware timestamps so the "UTC" label is true.
                    # NOTE(review): naive timestamps are labelled UTC as-is;
                    # this assumes the API reports UTC - confirm.
                    if parsed.tzinfo is not None:
                        parsed = parsed.astimezone(datetime.timezone.utc)
                    published = parsed.strftime('%Y-%m-%d %H:%M') + " UTC"

            entries.append(
                f"{position}. **{title}**\n"
                f" - Source: {source_name}\n"
                f" - Published: {published}\n"
                f" - Description: {description}\n"
                f" - URL: {url}\n\n"
            )

        return ("## Latest News:\n\n" + "".join(entries)).strip()

    except Exception as e:
        # Tool boundary: report any client or network failure as a message.
        print(f"ERROR: Exception in get_latest_news: {str(e)}")
        return f"ERROR: An exception occurred while fetching news: {str(e)}"
|
|
|
|
|
# Tools that are always available to the agent.
final_answer = FinalAnswerTool()
web_search = DuckDuckGoSearchTool()
visit_webpage = VisitWebpageTool()

# The image tool is optional: if loading fails for any reason the failure is
# printed and the name is bound to None so later code can test for it.
try:
    image_generation_tool = load_tool(
        "agents-course/text-to-image",
        trust_remote_code=True,
    )
    print(f"Successfully loaded image generation tool. Name: {getattr(image_generation_tool, 'name', 'N/A')}")
except Exception as load_error:
    print(f"Error loading image generation tool: {load_error}. Image generation will not be available.")
    image_generation_tool = None
|
|
|
|
|
# A missing key is reported up front but does not stop the module from loading.
if not os.environ.get("GEMINI_KEY"):
    print("CRITICAL: GEMINI_KEY environment variable not set. LLM will not work.")

# LLM backend used by the agent; the key is read from the environment.
model = LiteLLMModel(
    model_id="gemini/gemini-1.5-flash-latest",
    api_key=os.environ.get("GEMINI_KEY"),
    temperature=0.6,
    max_tokens=4096,
)
|
|
|
|
|
# Optional prompt overrides. When the file is absent, prompt_templates is None
# and that value is what gets handed to the agent constructor.
prompts_file = "prompts.yaml"
if os.path.isfile(prompts_file):
    # Explicit UTF-8 so non-ASCII prompt text is decoded the same way on every
    # platform instead of depending on the locale's default encoding.
    with open(prompts_file, 'r', encoding="utf-8") as stream:
        prompt_templates = yaml.safe_load(stream)
else:
    print(f"Warning: '{prompts_file}' not found. Using default agent prompts.")
    prompt_templates = None
|
|
|
|
|
# Tools handed to the agent; the image tool joins only when it was loaded.
agent_tools = [
    final_answer,
    web_search,
    visit_webpage,
    create_document,
    get_file_download_link,
    get_latest_news,
]

if not image_generation_tool:
    print("Note: Image generation tool was not loaded, so it's not added to the agent.")
else:
    agent_tools.append(image_generation_tool)

# The agent is built at import time from the model, tools and prompt templates
# defined at module level.
agent = CodeAgent(
    tools=agent_tools,
    model=model,
    prompt_templates=prompt_templates,
    max_steps=10,
    verbosity_level=1,
)
|
|
|
|
|
if __name__ == "__main__":
    # Only a warning: the news tool checks for the key itself when it is called.
    if not os.environ.get("NEWS_API_KEY"):
        print("Warning: NEWS_API_KEY environment variable not set. News tool will return an error.")

    print("Starting Gradio UI...")

    # Wrap the module-level agent in the Gradio front end and start it.
    ui = GradioUI(agent)
    ui.launch()