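"""Command-line entry point for HuggingGPT.

Plans tasks from a user prompt, selects a Hugging Face model for each task,
runs model inference, and generates a natural-language response.
"""
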
import asyncio
import json
import logging

import click
import requests
from dotenv import load_dotenv

from hugginggpt import generate_response, infer, plan_tasks
from hugginggpt.history import ConversationHistory
from hugginggpt.llm_factory import LLMs, create_llms
from hugginggpt.log import setup_logging
from hugginggpt.model_inference import TaskSummary
from hugginggpt.model_selection import select_hf_models
from hugginggpt.response_generation import format_response
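
# Load configuration (e.g. API keys) from a local .env file, then set up logging.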
load_dotenv()
setup_logging()
logger = logging.getLogger(__name__)


@click.command()
@click.option("-p", "--prompt", type=str, help="Prompt for huggingGPT")
def main(prompt):
    _print_banner()
    llms = create_llms()
    if prompt:
        standalone_mode(user_input=prompt, llms=llms)
    else:
        interactive_mode(llms=llms)


def standalone_mode(user_input: str, llms: LLMs) -> str | None:
    try:
        response, task_summaries = compute(
            user_input=user_input,
            history=ConversationHistory(),
            llms=llms,
        )
        pretty_response = format_response(response)
        print(pretty_response)
        return pretty_response
    except Exception as e:
        logger.exception("")
        print(
            f"Sorry, encountered error: {e}. Please try again. Check logs if problem persists."
        )


def interactive_mode(llms: LLMs):
    print("Please enter your request. End the conversation with 'exit'")
    history = ConversationHistory()
    while True:
        try:
            user_input = click.prompt("User")
            if user_input.lower() == "exit":
                break

            logger.info(f"User input: {user_input}")
            response, task_summaries = compute(
                user_input=user_input,
                history=history,
                llms=llms,
            )
            pretty_response = format_response(response)
            print(f"Assistant: {pretty_response}")
            history.add(role="user", content=user_input)
            history.add(role="assistant", content=response)
        except Exception as e:
            logger.exception("")
            print(
                f"Sorry, encountered error: {e}. Please try again. Check logs if problem persists."
            )


def compute(
    user_input: str,
    history: ConversationHistory,
    llms: LLMs,
) -> tuple[str, list[TaskSummary]]:
    tasks = plan_tasks(
        user_input=user_input, history=history, llm=llms.task_planning_llm
    )
    # sorted() returns a new list; the result must be assigned, otherwise the
    # tasks are never actually ordered by their dependencies.
    tasks = sorted(tasks, key=lambda t: max(t.dep))
    logger.info(f"Sorted tasks: {tasks}")
    # Select a Hugging Face model for each planned task.
    hf_models = asyncio.run(
        select_hf_models(
            user_input=user_input,
            tasks=tasks,
            model_selection_llm=llms.model_selection_llm,
            output_fixing_llm=llms.output_fixing_llm,
        )
    )
    # Execute tasks in dependency order, substituting resources generated by
    # earlier tasks where a task depends on them.
    task_summaries = []
    with requests.Session() as session:
        for task in tasks:
            logger.info(f"Starting task: {task}")
            if task.depends_on_generated_resources():
                task = task.replace_generated_resources(task_summaries=task_summaries)
            model = hf_models[task.id]
            inference_result = infer(
                task=task,
                model_id=model.id,
                llm=llms.model_inference_llm,
                session=session,
            )
            task_summaries.append(
                TaskSummary(
                    task=task,
                    model=model,
                    inference_result=json.dumps(inference_result),
                )
            )
            logger.info(f"Finished task: {task}")
    logger.info("Finished all tasks")
    logger.debug(f"Task summaries: {task_summaries}")
    response = generate_response(
        user_input=user_input,
        task_summaries=task_summaries,
        llm=llms.response_generation_llm,
    )
    return response, task_summaries


def _print_banner():
    with open("resources/banner.txt", "r") as f:
        banner = f.read()
    logger.info("\n" + banner)


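# Example usage (assuming this module is saved as main.py):
#   python main.py --prompt "What is the capital of France?"
# Without --prompt, an interactive session starts; type 'exit' to quit.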
if __name__ == "__main__":
main()