File size: 3,377 Bytes
b3d3593
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import asyncio
import json
import logging

import aiohttp
from langchain import LLMChain
from langchain.llms.base import BaseLLM
from langchain.output_parsers import OutputFixingParser, PydanticOutputParser
from langchain.prompts import load_prompt
from pydantic import BaseModel, Field

from hugginggpt.exceptions import ModelSelectionException, async_wrap_exceptions
from hugginggpt.model_scraper import get_top_k_models
from hugginggpt.resources import get_prompt_resource
from hugginggpt.task_parsing import Task

# Per-module logger used for model-selection progress (info) and raw LLM output (debug).
logger = logging.getLogger(__name__)


class Model(BaseModel):
    # Result of model selection for one task; the model-selection LLM output is
    # parsed into this schema by PydanticOutputParser in select_model.
    # NOTE: deliberately no class docstring and careful with the Field descriptions:
    # pydantic puts them into the JSON schema, which the parser's format
    # instructions (and therefore the output-fixing prompt) are built from.
    id: str = Field(description="ID of the model")  # presumably a HuggingFace model id, or "openai"
    reason: str = Field(description="Reason for selecting this model")  # free-text justification


async def select_hf_models(
    user_input: str,
    tasks: list[Task],
    model_selection_llm: BaseLLM,
    output_fixing_llm: BaseLLM,
) -> dict[int, Model]:
    """Use an LLM agent to select the best available HuggingFace model for each task.

    Selection for every task runs concurrently inside one ``asyncio.TaskGroup``,
    all sharing a single ``aiohttp.ClientSession``.

    Args:
        user_input: The original user request, forwarded to each selection.
        tasks: Tasks that need a model.
        model_selection_llm: LLM used to choose among candidate models.
        output_fixing_llm: LLM used to repair malformed selection output.

    Returns:
        Mapping from each task's ``id`` to the ``Model`` selected for it.
        An empty ``tasks`` list yields an empty dict.

    Raises:
        ExceptionGroup: if any selection fails, the TaskGroup cancels the
            remaining selections and raises the collected failures grouped.
    """
    async with aiohttp.ClientSession() as session:
        async with asyncio.TaskGroup() as tg:
            aio_tasks = [
                tg.create_task(
                    select_model(
                        user_input=user_input,
                        task=task,
                        model_selection_llm=model_selection_llm,
                        output_fixing_llm=output_fixing_llm,
                        session=session,
                    )
                )
                for task in tasks
            ]
    # Exiting the TaskGroup already awaited every task, so the results are ready;
    # a second asyncio.gather() over finished tasks would be redundant.
    # Each select_model call returns a (task_id, model) pair.
    return dict(aio_task.result() for aio_task in aio_tasks)


@async_wrap_exceptions(ModelSelectionException, "Failed to select model")
async def select_model(
    user_input: str,
    task: Task,
    model_selection_llm: BaseLLM,
    output_fixing_llm: BaseLLM,
    session: aiohttp.ClientSession,
) -> tuple[int, Model]:
    """Select a model for a single task.

    Text-generation-style tasks are routed straight to OpenAI without any
    HuggingFace lookup. For every other task, the top 5 HuggingFace models for
    the task are fetched, an LLM picks one of them, and its answer is parsed into
    a ``Model`` (with ``output_fixing_llm`` used to repair unparseable output).

    Args:
        user_input: The original user request, included in the selection prompt.
        task: The task to select a model for.
        model_selection_llm: LLM that chooses among the candidate models.
        output_fixing_llm: LLM used by ``OutputFixingParser`` to repair bad output.
        session: Shared HTTP session used to fetch candidate models.

    Returns:
        ``(task.id, model)``.

    Raises:
        ModelSelectionException: presumably wraps any failure, via the
            ``async_wrap_exceptions`` decorator (defined elsewhere) — confirm there.
    """
    logger.info(f"Starting model selection for task: {task.task}")

    if task.task in {
        "summarization",
        "translation",
        "conversational",
        "text-generation",
        "text2text-generation",
    }:
        model = Model(
            id="openai",
            reason="Text generation tasks are best handled by OpenAI models",
        )
    else:
        # Only fetch candidates when they are actually used: the OpenAI branch
        # above never looks at them, so fetching there was a wasted request.
        top_k_models = await get_top_k_models(
            task=task.task, top_k=5, max_description_length=100, session=session
        )
        prompt_template = load_prompt(
            get_prompt_resource("model-selection-prompt.json")
        )
        llm_chain = LLMChain(prompt=prompt_template, llm=model_selection_llm)
        # Need to replace double quotes with single quotes for correct response generation
        task_str = task.json().replace('"', "'")
        models_str = json.dumps(top_k_models).replace('"', "'")
        output = await llm_chain.apredict(
            user_input=user_input, task=task_str, models=models_str, stop=["<im_end>"]
        )
        logger.debug(f"Model selection raw output: {output}")

        parser = PydanticOutputParser(pydantic_object=Model)
        fixing_parser = OutputFixingParser.from_llm(
            parser=parser, llm=output_fixing_llm
        )
        # NOTE(review): parse() is synchronous; when the output is malformed the
        # fixing LLM call presumably runs synchronously too and blocks the event
        # loop. Switch to an async parse if the installed langchain offers one.
        model = fixing_parser.parse(output)

    logger.info(f"For task: {task.task}, selected model: {model}")
    return task.id, model