Spaces:
Runtime error
Runtime error
import json | |
import logging | |
from langchain import LLMChain | |
from langchain.llms.base import BaseLLM | |
from langchain.prompts import load_prompt | |
from hugginggpt.exceptions import ResponseGenerationException, wrap_exceptions | |
from hugginggpt.model_inference import TaskSummary | |
from hugginggpt.resources import get_prompt_resource, prepend_resource_dir | |
logger = logging.getLogger(__name__) | |
def generate_response(
    user_input: str, task_summaries: list[TaskSummary], llm: BaseLLM
) -> str:
    """Use LLM agent to generate a response to the user's input, given task results.

    Args:
        user_input: The user's original request; passed to the prompt as
            ``user_input``.
        task_summaries: Summaries of the executed tasks. They are ordered by
            ``task.id`` and serialized to a JSON string, which is passed to
            the prompt as ``task_results``.
        llm: The language model that runs the response-generation chain.

    Returns:
        The text predicted by the LLM chain.
    """
    # NOTE(review): ResponseGenerationException and wrap_exceptions are imported
    # at module level but are not used by this function — confirm whether this
    # function was meant to be wrapped.
    logger.info("Starting response generation")

    # Order by task id so the prompt sees the results in a deterministic order.
    sorted_task_summaries = sorted(task_summaries, key=lambda ts: ts.task.id)
    task_results_str = task_summaries_to_json(sorted_task_summaries)

    prompt_template = load_prompt(
        get_prompt_resource("response-generation-prompt.json")
    )
    llm_chain = LLMChain(prompt=prompt_template, llm=llm)
    response = llm_chain.predict(
        user_input=user_input, task_results=task_results_str, stop=["<im_end>"]
    )

    # Lazy %-style arguments: the message is only formatted if INFO is enabled.
    logger.info("Generated response: %s", response)
    return response
def format_response(response: str) -> str:
    """Format the response to be more readable for user.

    Leading and trailing whitespace is stripped, then the result is passed
    through ``prepend_resource_dir``.

    Args:
        response: The raw response text.

    Returns:
        The value returned by ``prepend_resource_dir`` for the stripped text.
    """
    # NOTE(review): presumably prepend_resource_dir rewrites resource paths
    # mentioned in the text — confirm against hugginggpt.resources.
    return prepend_resource_dir(response.strip())
def task_summaries_to_json(task_summaries: list[TaskSummary]) -> str:
    """Serialize task summaries to a JSON array string.

    Each summary is converted with its ``dict()`` method, and the resulting
    list is encoded with ``json.dumps`` in the order given.

    Args:
        task_summaries: The summaries to serialize.

    Returns:
        A JSON string containing one object per summary.
    """
    dicts = [ts.dict() for ts in task_summaries]
    return json.dumps(dicts)