Spaces:
Runtime error
Runtime error
File size: 5,658 Bytes
b3d3593 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import copy
import logging
from pydantic import BaseModel, Field
from hugginggpt.exceptions import TaskParsingException, wrap_exceptions
logger = logging.getLogger(__name__)
GENERATED_TOKEN = "<GENERATED>"
class Task(BaseModel):
    # One planned Machine Learning task: its name, id, dependency ids and arguments.
    # NOTE(review): documented with comments rather than a class docstring because
    # pydantic may surface a class docstring in the generated schema - confirm
    # before converting this to a docstring.

    # This field is called 'task' and not 'name' to help with prompt engineering
    task: str = Field(description="Name of the Machine Learning task")
    id: int = Field(description="ID of the task")
    dep: list[int] = Field(
        description="List of IDs of the tasks that this task depends on"
    )
    args: dict[str, str] = Field(description="Arguments for the task")

    def depends_on_generated_resources(self) -> bool:
        """Returns True if the task args contains <GENERATED> placeholder tokens, False otherwise

        A task whose ``dep`` is exactly ``[-1]`` is treated as having no
        dependencies and always yields False, whatever its args contain.
        """
        return self.dep != [-1] and any(
            GENERATED_TOKEN in v for v in self.args.values()
        )

    @wrap_exceptions(TaskParsingException, "Failed to replace generated resources")
    def replace_generated_resources(self, task_summaries: list) -> "Task":
        """Replaces <GENERATED> placeholder tokens in args with the generated resources from the task summaries

        Args:
            task_summaries: Sequence indexed by task id; each element must expose an
                ``inference_result`` mapping of result key -> generated value.

        Returns:
            This task, with every placeholder argument replaced in place.

        Raises:
            Exception: If an inference result does not contain exactly one entry
                matching a required resource type (presumably re-raised as
                TaskParsingException by the ``wrap_exceptions`` decorator).
        """
        logger.info("Replacing generated resources")
        # Snapshot the placeholder args first, so self.args can be mutated safely below
        generated_resources = {
            k: parse_task_id(v) for k, v in self.args.items() if GENERATED_TOKEN in v
        }
        logger.info(
            f"Resources to replace, resource type -> task id: {generated_resources}"
        )
        for resource_type, task_id in generated_resources.items():
            matches = [
                v
                for k, v in task_summaries[task_id].inference_result.items()
                if self.is_matching_generated_resource(k, resource_type)
            ]
            # Zero or several candidates are both ambiguous: refuse to guess
            if len(matches) != 1:
                raise Exception(
                    f"Cannot find unique required generated {resource_type} in inference result of task {task_id}"
                )
            logger.info(
                f"Match for generated {resource_type} in inference result of task {task_id}"
            )
            generated_resource = matches[0]
            logger.info(f"Replacing {resource_type} with {generated_resource}")
            self.args[resource_type] = generated_resource
        # Return only after ALL placeholders have been handled; returning from inside
        # the loop would leave every placeholder after the first one unreplaced.
        return self

    def is_matching_generated_resource(self, arg_key: str, resource_type: str) -> bool:
        """Returns True if arg_key contains generated resource of the correct type

        Args:
            arg_key: Key found in an inference result.
            resource_type: Argument name of this task that needs a generated value.
        """
        # If text, then match all arg keys that contain "text"
        if resource_type.startswith("text"):
            return "text" in arg_key
        # If not text, then arg key must start with "generated" and the correct resource type
        return arg_key.startswith("generated " + resource_type)
class Tasks(BaseModel):
    # Model whose root value is a list of Task objects. The dunder methods below
    # forward to that list so an instance can be iterated, indexed and measured
    # like the list it wraps.
    __root__: list[Task] = Field(description="List of Machine Learning tasks")

    def __iter__(self):
        """Return an iterator over the wrapped list of tasks."""
        wrapped = self.__root__
        return iter(wrapped)

    def __getitem__(self, item):
        """Return the element(s) of the wrapped list selected by ``item``."""
        wrapped = self.__root__
        return wrapped[item]

    def __len__(self):
        """Return the number of wrapped tasks."""
        wrapped = self.__root__
        return len(wrapped)
@wrap_exceptions(TaskParsingException, "Failed to parse tasks")
def parse_tasks(tasks_str: str) -> list[Task]:
    """Parses tasks from task planning json string

    Args:
        tasks_str: JSON string whose top level is a list of task objects.

    Returns:
        The parsed tasks, with folded tasks split up and dependencies re-inferred
        from the task arguments.

    Raises:
        ValueError: If the string is an empty JSON list (presumably re-raised as
            TaskParsingException by the ``wrap_exceptions`` decorator).
    """
    # Strip before the emptiness check, so padded input such as "[]\n" is still
    # rejected instead of silently producing an empty task list.
    tasks_str = tasks_str.strip()
    if tasks_str == "[]":
        raise ValueError("Task string empty, cannot parse")
    logger.info(f"Parsing tasks string: {tasks_str}")
    # Cannot use PydanticOutputParser because it fails when parsing top level list JSON string
    tasks = Tasks.parse_raw(tasks_str)
    # __root__ extracts list[Task] from Tasks object
    tasks = unfold(tasks.__root__)
    tasks = fix_dependencies(tasks)
    logger.info(f"Parsed tasks: {tasks}")
    return tasks
def parse_task_id(resource_str: str) -> int:
    """Extract the task id from a generated resource string.

    The id is the second hyphen-separated segment, e.g. ``<GENERATED>-4`` -> ``4``.
    Raises IndexError when the string has no hyphen and ValueError when that
    segment is not an integer.
    """
    # Only the second segment is needed, so stop splitting after two hyphens
    segments = resource_str.split("-", 2)
    id_segment = segments[1]
    return int(id_segment)
def fix_dependencies(tasks: list[Task]) -> list[Task]:
    """Overwrite each task's ``dep`` with the dependencies inferred from its args.

    The dependencies that came out of parsing are discarded. Tasks are updated
    in place and the same list object is returned.
    """
    for parsed_task in tasks:
        inferred_deps = infer_deps_from_args(parsed_task)
        parsed_task.dep = inferred_deps
    return tasks
def infer_deps_from_args(task: Task) -> list[int]:
    """Collect the unique task ids referenced by GENERATED arg values.

    Returns ``[-1]`` when no argument value contains the GENERATED token.
    """
    referenced_ids = [
        parse_task_id(arg_value)
        for arg_value in task.args.values()
        if GENERATED_TOKEN in arg_value
    ]
    if referenced_ids:
        # deduplicate
        return list(set(referenced_ids))
    return [-1]
def unfold(tasks: list[Task]) -> list[Task]:
    """Expand folded tasks into separate tasks, keeping the original order.

    A folded task has several generated resources folded into a single argument.
    Tasks that are not folded are passed through unchanged.
    """
    expanded = []
    for candidate in tasks:
        folded = find_folded_args(candidate)
        expanded += split(candidate, folded) if folded else [candidate]
    return expanded
def split(task: Task, folded_args: tuple[str, str]) -> list[Task]:
    """Split a folded task into one copy per generated resource.

    ``folded_args`` is the ``(argument name, folded value)`` pair; the value is
    a comma-separated list of generated resource strings. Each returned task is
    a deep copy of ``task`` that depends on, and takes as argument, exactly one
    of those resources.
    """
    arg_name, folded_value = folded_args

    def _task_for(piece: str) -> Task:
        # Deep copy so the clones do not share the args dict with each other
        clone = copy.deepcopy(task)
        clone.dep = [parse_task_id(piece)]
        clone.args[arg_name] = piece.strip()
        return clone

    return [_task_for(piece) for piece in folded_value.split(",")]
def find_folded_args(task: Task) -> tuple[str, str] | None:
    """Return the first folded ``(key, value)`` argument of the task, or None.

    An argument is folded when its value holds more than one GENERATED token,
    e.g: 'image': '<GENERATED>-1,<GENERATED>-2'
    """
    folded = next(
        (
            (arg_name, arg_value)
            for arg_name, arg_value in task.args.items()
            if arg_value.count(GENERATED_TOKEN) > 1
        ),
        None,
    )
    if folded is not None:
        logger.debug(f"Task {task.id} is folded")
    return folded
|