import nltk
import logging
import numpy as np
from typing import List, Any
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import SystemMessage, HumanMessage
from sentence_transformers import SentenceTransformer

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Download NLTK tokenizer and stopword data (best effort; non-fatal on failure)
try:
    nltk.download('punkt', quiet=True)
    nltk.download('stopwords', quiet=True)
except Exception as e:
    logger.warning(f"NLTK data download failed: {e}")

# Global embedder, loaded lazily so the model is only initialized once
_embedder = None


def get_embedder():
    """Return the cached SentenceTransformer, loading it on first use."""
    global _embedder
    if _embedder is None:
        try:
            _embedder = SentenceTransformer(
                "all-MiniLM-L6-v2",
                device="cpu",
                cache_folder="./cache"
            )
            logger.info("SentenceTransformer initialized")
        except Exception as e:
            logger.error(f"Failed to initialize SentenceTransformer: {e}")
            raise RuntimeError(f"Embedder initialization failed: {e}") from e
    return _embedder
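
# Note: get_embedder() is not lock-protected, so concurrent first calls from
# multiple threads can each trigger a model load; later calls simply reuse
# whichever instance was assigned last.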


def filter_results(search_results: List[str], question: str) -> List[str]:
    """Keep search results whose cosine similarity to the question exceeds 0.5."""
    try:
        if not search_results or not question:
            return search_results
        embedder = get_embedder()
        # Normalize embeddings so the dot product below equals cosine
        # similarity; encode() does not normalize by default.
        question_embedding = embedder.encode(
            [question], convert_to_numpy=True, normalize_embeddings=True
        )
        result_embeddings = embedder.encode(
            search_results, convert_to_numpy=True, normalize_embeddings=True
        )
        similarities = np.dot(result_embeddings, question_embedding.T).flatten()
        filtered_results = [
            search_results[i] for i in range(len(search_results))
            if similarities[i] > 0.5 and search_results[i].strip()
        ]
        # Fall back to the first few results rather than returning nothing
        return filtered_results if filtered_results else search_results[:3]
    except Exception as e:
        logger.warning(f"Result filtering failed: {e}")
        return search_results[:3]
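
# Example (made-up data): on-topic results tend to clear the 0.5 threshold, so
#   filter_results(["Paris is the capital of France.", "Bananas are yellow."],
#                  "What is the capital of France?")
# would typically return just the first string.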


async def preprocess_question(question: str) -> str:
    """Preprocess the question to clean and standardize it."""
    try:
        question = question.strip().lower()
        if not question.endswith("?"):
            question += "?"
        logger.debug(f"Preprocessed question: {question}")
        return question
    except Exception as e:
        logger.error(f"Error preprocessing question: {e}")
        return question


async def generate_answer(
    task_id: str,
    question: str,
    search_results: List[str],
    file_results: str,
    llm_client: Any
) -> str:
    """Generate an answer using LLM with search and file results."""
    try:
        if not search_results:
            search_results = ["No search results available."]
        if not file_results:
            file_results = "No file results available."
        context = "\n".join([str(r) for r in search_results]) + "\n" + file_results
        prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content="""You are an assistant answering questions using provided context.
            - Use ONLY the context to formulate a concise, accurate answer.
            - If the context is insufficient, state: 'Insufficient information to answer.'
            - Do NOT generate or assume information beyond the context.
            - Return a single, clear sentence or phrase as the answer."""),
            HumanMessage(content=f"Context: {context}\nQuestion: {question}")
        ])
        # Extract plain role/content dicts from the prompt's stored messages;
        # indexing .messages works across langchain-core versions.
        messages = [
            {"role": "system", "content": prompt.messages[0].content},
            {"role": "user", "content": prompt.messages[1].content}
        ]
        if isinstance(llm_client, tuple):  # hf_local: (model, tokenizer) pair
            model, tokenizer = llm_client
            inputs = tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, return_tensors="pt"
            ).to(model.device)
            outputs = model.generate(
                inputs, max_new_tokens=100, do_sample=True, temperature=0.7
            )
            # Decode only the newly generated tokens, not the echoed prompt
            response = tokenizer.decode(
                outputs[0][inputs.shape[-1]:], skip_special_tokens=True
            )
        elif hasattr(llm_client, "chat"):  # together
            response = llm_client.chat.completions.create(
                model="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
                messages=messages,
                max_tokens=100,
                temperature=0.7,
                top_p=0.9,
                frequency_penalty=0.5
            )
            response = response.choices[0].message.content.strip()
        else:  # hf_api
            response = llm_client.chat.completions.create(
                messages=messages,
                max_tokens=100,
                temperature=0.7
            )
            response = response.choices[0].message.content.strip()
        answer = response.strip()
        if not answer or answer.lower() == "none":
            answer = "Insufficient information to answer."
        logger.info(f"Task {task_id}: Generated answer: {answer}")
        return answer
    except Exception as e:
        logger.error(f"Task {task_id}: Answer generation failed: {e}")
        return "Error generating answer."