import copy
import logging

from pydantic import BaseModel, Field

from hugginggpt.exceptions import TaskParsingException, wrap_exceptions

logger = logging.getLogger(__name__)

GENERATED_TOKEN = "<GENERATED>"


class Task(BaseModel):
    # This field is called 'task' and not 'name' to help with prompt engineering
    task: str = Field(description="Name of the Machine Learning task")
    id: int = Field(description="ID of the task")
    dep: list[int] = Field(
        description="List of IDs of the tasks that this task depends on"
    )
    args: dict[str, str] = Field(description="Arguments for the task")

    def depends_on_generated_resources(self) -> bool:
        """Returns True if the task args contains <GENERATED> placeholder tokens, False otherwise"""
        return self.dep != [-1] and any(
            GENERATED_TOKEN in v for v in self.args.values()
        )

    @wrap_exceptions(TaskParsingException, "Failed to replace generated resources")
    def replace_generated_resources(self, task_summaries: list):
        """Replaces <GENERATED> placeholder tokens in args with the generated resources from the task summaries"""
        logger.info("Replacing generated resources")
        generated_resources = {
            k: parse_task_id(v) for k, v in self.args.items() if GENERATED_TOKEN in v
        }
        logger.info(
            f"Resources to replace, resource type -> task id: {generated_resources}"
        )
        for resource_type, task_id in generated_resources.items():
            matches = [
                v
                for k, v in task_summaries[task_id].inference_result.items()
                if self.is_matching_generated_resource(k, resource_type)
            ]
            if len(matches) == 1:
                logger.info(
                    f"Match for generated {resource_type} in inference result of task {task_id}"
                )
                generated_resource = matches[0]
                logger.info(f"Replacing {resource_type} with {generated_resource}")
                self.args[resource_type] = generated_resource
            else:
                raise Exception(
                    f"Cannot find unique required generated {resource_type} in inference result of task {task_id}"
                )
        return self

    def is_matching_generated_resource(self, arg_key: str, resource_type: str) -> bool:
        """Returns True if arg_key contains generated resource of the correct type"""
        # If text, then match all arg keys that contain "text"
        if resource_type.startswith("text"):
            return "text" in arg_key
        # If not text, then arg key must start with "generated" and the correct resource type
        else:
            return arg_key.startswith("generated " + resource_type)
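

# Illustrative sketch, not part of the original module: expected Task behaviour,
# assuming a task summary object exposing an `inference_result` dict as used in
# replace_generated_resources above (values are made up).
#
#   task = Task(task="image-to-text", id=1, dep=[0], args={"image": "<GENERATED>-0"})
#   task.depends_on_generated_resources()  # -> True
#   # With task_summaries[0].inference_result == {"generated image": "/outputs/cat.png"},
#   # task.replace_generated_resources(task_summaries) sets args["image"] to "/outputs/cat.png".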


class Tasks(BaseModel):
    __root__: list[Task] = Field(description="List of Machine Learning tasks")

    def __iter__(self):
        return iter(self.__root__)

    def __getitem__(self, item):
        return self.__root__[item]

    def __len__(self):
        return len(self.__root__)


@wrap_exceptions(TaskParsingException, "Failed to parse tasks")
def parse_tasks(tasks_str: str) -> list[Task]:
    """Parses tasks from task planning json string"""
    if tasks_str == "[]":
        raise ValueError("Task list empty, cannot parse")
    logger.info(f"Parsing tasks string: {tasks_str}")
    tasks_str = tasks_str.strip()
    # Cannot use PydanticOutputParser because it fails when parsing a top-level list JSON string
    tasks = Tasks.parse_raw(tasks_str)
    # __root__ extracts list[Task] from Tasks object
    tasks = unfold(tasks.__root__)
    tasks = fix_dependencies(tasks)
    logger.info(f"Parsed tasks: {tasks}")
    return tasks


def parse_task_id(resource_str: str) -> int:
    """Parse task id from generated resource string, e.g. <GENERATED>-4 -> 4"""
    return int(resource_str.split("-")[1])


def fix_dependencies(tasks: list[Task]) -> list[Task]:
    """Ignores parsed tasks dependencies, and instead infers from task arguments"""
    for task in tasks:
        task.dep = infer_deps_from_args(task)
    return tasks


def infer_deps_from_args(task: Task) -> list[int]:
    """If GENERATED arg value, add to list of unique deps. If none, deps = [-1]"""
    deps = [parse_task_id(v) for v in task.args.values() if GENERATED_TOKEN in v]
    if not deps:
        deps = [-1]
    # deduplicate
    return list(set(deps))
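

# For illustration (made-up values): args={"image": "<GENERATED>-0", "text": "describe the image"}
# yields deps=[0], while args={"text": "a cat"} yields deps=[-1].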


def unfold(tasks: list[Task]) -> list[Task]:
    """A folded task has several generated resources folded into a single argument"""
    unfolded_tasks = []
    for task in tasks:
        folded_args = find_folded_args(task)
        if folded_args:
            unfolded_tasks.extend(split(task, folded_args))
        else:
            unfolded_tasks.append(task)
    return unfolded_tasks


def split(task: Task, folded_args: tuple[str, str]) -> list[Task]:
    """Split folded task into two same tasks, but separated generated resource arguments"""
    key, value = folded_args
    generated_items = value.split(",")
    split_tasks = []
    for item in generated_items:
        new_task = copy.deepcopy(task)
        dep_task_id = parse_task_id(item)
        new_task.dep = [dep_task_id]
        new_task.args[key] = item.strip()
        split_tasks.append(new_task)
    return split_tasks
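

# For illustration (made-up values): a folded task with
# args={"image": "<GENERATED>-0,<GENERATED>-1"} is split into two copies of the
# task, one with dep=[0], args={"image": "<GENERATED>-0"} and one with dep=[1],
# args={"image": "<GENERATED>-1"}.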


def find_folded_args(task: Task) -> tuple[str, str] | None:
    """Finds folded args, e.g: 'image': '<GENERATED>-1,<GENERATED>-2'"""
    for key, value in task.args.items():
        if value.count(GENERATED_TOKEN) > 1:
            logger.debug(f"Task {task.id} is folded")
            return key, value
    return None
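

if __name__ == "__main__":
    # Minimal smoke test, not part of the original module; the task values below
    # are illustrative only. Parses a small task planning string whose last task
    # folds two generated image resources into one argument, so unfold() splits
    # it into two tasks with inferred dependencies [0] and [1].
    logging.basicConfig(level=logging.INFO)
    example = (
        '[{"task": "text-to-image", "id": 0, "dep": [-1], "args": {"text": "a cat"}},'
        ' {"task": "text-to-image", "id": 1, "dep": [-1], "args": {"text": "a dog"}},'
        ' {"task": "image-to-text", "id": 2, "dep": [0, 1],'
        ' "args": {"image": "<GENERATED>-0,<GENERATED>-1"}}]'
    )
    for parsed_task in parse_tasks(example):
        print(parsed_task)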