# jarvis_gaia_agent/tools/answer_generator.py
import nltk
import logging
import numpy as np
from typing import List, Any
from sentence_transformers import SentenceTransformer
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(name)s - %(message)s')
logger = logging.getLogger(__name__)
# Download NLTK data
try:
    nltk.download('punkt', quiet=True)
    nltk.download('stopwords', quiet=True)
except Exception as e:
    logger.warning(f"NLTK data download failed: {e}")
# Global embedder
_embedder = None
def get_embedder():
    """Lazily initialize and cache a shared SentenceTransformer instance."""
    global _embedder
    if _embedder is None:
        try:
            _embedder = SentenceTransformer(
                "all-MiniLM-L6-v2",
                device="cpu",
                cache_folder="./cache"
            )
            logger.info("SentenceTransformer initialized")
        except Exception as e:
            logger.error(f"Failed to initialize SentenceTransformer: {e}")
            raise RuntimeError(f"Embedder initialization failed: {e}")
    return _embedder
def filter_results(search_results: List[str], question: str) -> List[str]:
    """Keep search results whose cosine similarity to the question exceeds 0.5."""
    try:
        if not search_results or not question:
            return search_results
        embedder = get_embedder()
        # Normalize embeddings so the dot product below is a true cosine
        # similarity; without normalization the 0.5 threshold is on an
        # unbounded scale and filters arbitrarily.
        question_embedding = embedder.encode(
            [question], convert_to_numpy=True, normalize_embeddings=True
        )
        result_embeddings = embedder.encode(
            search_results, convert_to_numpy=True, normalize_embeddings=True
        )
        similarities = np.dot(result_embeddings, question_embedding.T).flatten()
        filtered_results = [
            search_results[i] for i in range(len(search_results))
            if similarities[i] > 0.5 and search_results[i].strip()
        ]
        # Fall back to the first few raw results rather than returning nothing
        return filtered_results if filtered_results else search_results[:3]
    except Exception as e:
        logger.warning(f"Result filtering failed: {e}")
        return search_results[:3]
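# Usage sketch (illustrative only): with normalized embeddings the 0.5
# threshold acts as a cosine-similarity cutoff, so for a question like
# "What is the capital of France?" a snippet such as "Paris is the capital
# of France." should pass while unrelated text is dropped. A runnable
# version appears in the __main__ smoke test at the bottom of this file.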
async def preprocess_question(question: str) -> str:
    """Preprocess the question to clean and standardize it."""
    try:
        question = question.strip().lower()
        if not question.endswith("?"):
            question += "?"
        logger.debug(f"Preprocessed question: {question}")
        return question
    except Exception as e:
        logger.error(f"Error preprocessing question: {e}")
        return question
async def generate_answer(
    task_id: str,
    question: str,
    search_results: List[str],
    file_results: str,
    llm_client: Any
) -> str:
    """Generate an answer using an LLM with search and file results as context."""
    try:
        if not search_results:
            search_results = ["No search results available."]
        if not file_results:
            file_results = "No file results available."
        context = "\n".join([str(r) for r in search_results]) + "\n" + file_results
        # Build plain role/content dicts directly: both apply_chat_template and
        # the OpenAI-style clients below consume this format, and the content
        # is already fully interpolated, so no prompt templating is needed.
        system_content = (
            "You are an assistant answering questions using provided context.\n"
            "- Use ONLY the context to formulate a concise, accurate answer.\n"
            "- If the context is insufficient, state: 'Insufficient information to answer.'\n"
            "- Do NOT generate or assume information beyond the context.\n"
            "- Return a single, clear sentence or phrase as the answer."
        )
        messages = [
            {"role": "system", "content": system_content},
            {"role": "user", "content": f"Context: {context}\nQuestion: {question}"}
        ]
        if isinstance(llm_client, tuple):  # hf_local: (model, tokenizer) pair
            model, tokenizer = llm_client
            inputs = tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, return_tensors="pt"
            ).to(model.device)
            outputs = model.generate(
                inputs, max_new_tokens=100, do_sample=True, temperature=0.7
            )
            # Decode only the newly generated tokens, not the echoed prompt
            response = tokenizer.decode(
                outputs[0][inputs.shape[-1]:], skip_special_tokens=True
            )
        elif hasattr(llm_client, "chat"):  # together
            response = llm_client.chat.completions.create(
                model="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
                messages=messages,
                max_tokens=100,
                temperature=0.7,
                top_p=0.9,
                frequency_penalty=0.5
            )
            response = response.choices[0].message.content.strip()
        else:  # hf_api
            response = llm_client.chat.completions.create(
                messages=messages,
                max_tokens=100,
                temperature=0.7
            )
            response = response.choices[0].message.content.strip()
        answer = response.strip()
        if not answer or answer.lower() == "none":
            answer = "Insufficient information to answer."
        logger.info(f"Task {task_id}: Generated answer: {answer}")
        return answer
    except Exception as e:
        logger.error(f"Task {task_id}: Answer generation failed: {e}")
        return "Error generating answer."