import asyncio
import json
import logging

import aiohttp
from langchain import LLMChain
from langchain.llms.base import BaseLLM
from langchain.output_parsers import OutputFixingParser, PydanticOutputParser
from langchain.prompts import load_prompt
from pydantic import BaseModel, Field

from hugginggpt.exceptions import ModelSelectionException, async_wrap_exceptions
from hugginggpt.model_scraper import get_top_k_models
from hugginggpt.resources import get_prompt_resource
from hugginggpt.task_parsing import Task

logger = logging.getLogger(__name__)

class Model(BaseModel):
    """A model selected for a task, with the reason it was chosen."""

    id: str = Field(description="ID of the model")
    reason: str = Field(description="Reason for selecting this model")
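
# Illustrative example only (the model ID below is not part of this module):
#     Model(id="facebook/detr-resnet-101", reason="Most downloads for object detection")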

async def select_hf_models(
    user_input: str,
    tasks: list[Task],
    model_selection_llm: BaseLLM,
    output_fixing_llm: BaseLLM,
) -> dict[int, Model]:
    """Use LLM agent to select the best available HuggingFace model for each task,
    given model metadata. Selections for all tasks run concurrently."""
    async with aiohttp.ClientSession() as session:
        async with asyncio.TaskGroup() as tg:
            # Spawn one selection coroutine per task; the TaskGroup awaits them
            # all on exit and cancels the rest if any one fails.
            aio_tasks = []
            for task in tasks:
                aio_tasks.append(
                    tg.create_task(
                        select_model(
                            user_input=user_input,
                            task=task,
                            model_selection_llm=model_selection_llm,
                            output_fixing_llm=output_fixing_llm,
                            session=session,
                        )
                    )
                )
        # All tasks are complete once the TaskGroup exits; gather just collects
        # their (task_id, model) results.
        results = await asyncio.gather(*aio_tasks)
    return {task_id: model for task_id, model in results}
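
# Minimal usage sketch (illustrative; `llm` and `parsed_tasks` are assumed to be
# created elsewhere, e.g. via langchain's OpenAI wrapper and hugginggpt.task_parsing):
#
#     models = asyncio.run(
#         select_hf_models(
#             user_input="Draw me a picture of a cat",
#             tasks=parsed_tasks,
#             model_selection_llm=llm,
#             output_fixing_llm=llm,
#         )
#     )
#     # models maps task IDs to selections, e.g.
#     # {0: Model(id="runwayml/stable-diffusion-v1-5", reason="...")}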

@async_wrap_exceptions(ModelSelectionException, "Failed to select model")
async def select_model(
    user_input: str,
    task: Task,
    model_selection_llm: BaseLLM,
    output_fixing_llm: BaseLLM,
    session: aiohttp.ClientSession,
) -> tuple[int, Model]:
    logger.info(f"Starting model selection for task: {task.task}")
    top_k_models = await get_top_k_models(
        task=task.task, top_k=5, max_description_length=100, session=session
    )
    if task.task in [
        "summarization",
        "translation",
        "conversational",
        "text-generation",
        "text2text-generation",
    ]:
        # Pure text tasks skip HuggingFace model selection and are routed
        # straight to the OpenAI LLM.
        model = Model(
            id="openai",
            reason="Text generation tasks are best handled by OpenAI models",
        )
    else:
        prompt_template = load_prompt(
            get_prompt_resource("model-selection-prompt.json")
        )
        llm_chain = LLMChain(prompt=prompt_template, llm=model_selection_llm)
        # Double quotes in the injected JSON must be replaced with single quotes
        # for correct response generation, e.g. {"id": "openai"} -> {'id': 'openai'}
        task_str = task.json().replace('"', "'")
        models_str = json.dumps(top_k_models).replace('"', "'")
        output = await llm_chain.apredict(
            user_input=user_input, task=task_str, models=models_str, stop=["<im_end>"]
        )
        logger.debug(f"Model selection raw output: {output}")
        parser = PydanticOutputParser(pydantic_object=Model)
        fixing_parser = OutputFixingParser.from_llm(
            parser=parser, llm=output_fixing_llm
        )
        model = fixing_parser.parse(output)
    logger.info(f"For task: {task.task}, selected model: {model}")
    return task.id, model
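
# PydanticOutputParser validates the raw completion against the Model schema;
# when parsing fails, OutputFixingParser re-prompts output_fixing_llm with the
# malformed output and the error, so occasional schema violations are repaired
# instead of raising.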