Spaces:
Runtime error
Runtime error
import copy | |
import logging | |
from pydantic import BaseModel, Field | |
from hugginggpt.exceptions import TaskParsingException, wrap_exceptions | |
logger = logging.getLogger(__name__)

# Placeholder the task planner writes into an argument value to refer to the
# output of another task, in the form "<GENERATED>-<task id>" (e.g. "<GENERATED>-4").
GENERATED_TOKEN: str = "<GENERATED>"
# One Machine Learning task emitted by the task-planning step.
# NOTE: documented with comments instead of a class docstring, because pydantic
# would pick a class docstring up as the model's schema description.
class Task(BaseModel):
    # This field is called 'task' and not 'name' to help with prompt engineering
    task: str = Field(description="Name of the Machine Learning task")
    id: int = Field(description="ID of the task")
    # [-1] is used in this module as the "no dependencies" marker.
    dep: list[int] = Field(
        description="List of IDs of the tasks that this task depends on"
    )
    # Argument values may contain "<GENERATED>-<task id>" placeholders.
    args: dict[str, str] = Field(description="Arguments for the task")

    def depends_on_generated_resources(self) -> bool:
        """Return True if the task has dependencies and an arg holds a <GENERATED> token.

        A ``dep`` of ``[-1]`` means "no dependencies", in which case this
        returns False even if a token is present.
        """
        return self.dep != [-1] and any(
            GENERATED_TOKEN in v for v in self.args.values()
        )

    def replace_generated_resources(self, task_summaries: list):
        """Replace every <GENERATED> placeholder in ``args`` with the generated resource.

        For each arg whose value contains the token, the referenced task id is
        parsed and the single matching entry of that task's
        ``inference_result`` is written into ``args`` in place.

        Args:
            task_summaries: indexed by task id; each element must expose an
                ``inference_result`` dict (inferred from usage -- confirm with
                callers).

        Returns:
            ``self``, after *all* placeholders have been replaced (also when
            there were none).

        Raises:
            Exception: if a placeholder does not match exactly one entry of the
                referenced task's inference result.
        """
        logger.info("Replacing generated resources")
        generated_resources = {
            k: parse_task_id(v) for k, v in self.args.items() if GENERATED_TOKEN in v
        }
        logger.info(
            f"Resources to replace, resource type -> task id: {generated_resources}"
        )
        for resource_type, task_id in generated_resources.items():
            matches = [
                v
                for k, v in task_summaries[task_id].inference_result.items()
                if self.is_matching_generated_resource(k, resource_type)
            ]
            if len(matches) != 1:
                raise Exception(
                    f"Cannot find unique required generated {resource_type} in inference result of task {task_id}"
                )
            logger.info(
                f"Match for generated {resource_type} in inference result of task {task_id}"
            )
            generated_resource = matches[0]
            logger.info(f"Replacing {resource_type} with {generated_resource}")
            self.args[resource_type] = generated_resource
        # Return only once every placeholder is handled; returning inside the
        # loop left all but the first placeholder unreplaced, and returned
        # None when there were no placeholders at all.
        return self

    def is_matching_generated_resource(self, arg_key: str, resource_type: str) -> bool:
        """Return True if inference-result key ``arg_key`` holds a resource of ``resource_type``."""
        # If text, then match all arg keys that contain "text"
        if resource_type.startswith("text"):
            return "text" in arg_key
        # If not text, then arg key must start with "generated" and the correct resource type
        else:
            return arg_key.startswith("generated " + resource_type)
# Custom-root pydantic model wrapping a list of Task objects, so that a
# top-level JSON array can be parsed with ``Tasks.parse_raw``.
# NOTE(review): ``__root__`` custom roots are a pydantic v1 feature; pydantic
# v2 replaced them with RootModel -- confirm the pinned pydantic version.
class Tasks(BaseModel):
    __root__: list[Task] = Field(description="List of Machine Learning tasks")

    # Sequence-style access delegates directly to the wrapped list.
    def __len__(self):
        return len(self.__root__)

    def __getitem__(self, item):
        return self.__root__[item]

    def __iter__(self):
        yield from self.__root__
def parse_tasks(tasks_str: str) -> list[Task]:
    """Parse the task-planning JSON string into a list of Task objects.

    The string must be a top-level JSON array of task objects. Folded tasks
    (several generated resources packed into one argument) are unfolded into
    separate tasks, and every task's ``dep`` is re-inferred from its args.

    Raises:
        ValueError: if the JSON array is empty (surrounding whitespace and
            whitespace inside the brackets are ignored).
    """
    logger.info(f"Parsing tasks string: {tasks_str}")
    tasks_str = tasks_str.strip()
    # Cannot use PydanticOutputParser because it fails when parsing top level list JSON string
    parsed = Tasks.parse_raw(tasks_str)
    # Check emptiness on the parsed list rather than comparing the raw string
    # to "[]", so that "[]\n" or "[ ]" are rejected too instead of silently
    # returning no tasks.
    if not parsed.__root__:
        raise ValueError("Task string empty, cannot parse")
    # __root__ extracts list[Task] from Tasks object
    tasks = unfold(parsed.__root__)
    tasks = fix_dependencies(tasks)
    logger.info(f"Parsed tasks: {tasks}")
    return tasks
def parse_task_id(resource_str: str) -> int:
    """Parse the task id from a generated-resource reference, e.g. ``<GENERATED>-4 -> 4``.

    The reference may be embedded in other text (callers select values with
    ``GENERATED_TOKEN in value``), so the id is read from directly after the
    token instead of after the first hyphen, which broke on values such as
    ``"text-to-image of <GENERATED>-1"``. Whitespace after the hyphen is
    tolerated.

    Raises:
        ValueError: if the string does not contain exactly one
            ``<GENERATED>-<digits>`` reference (e.g. a still-folded
            ``"<GENERATED>-1,<GENERATED>-2"`` value).
    """
    import re

    ids = re.findall(re.escape(GENERATED_TOKEN) + r"-\s*(\d+)", resource_str)
    if len(ids) != 1:
        raise ValueError(
            f"Expected exactly one {GENERATED_TOKEN} reference in {resource_str!r}"
        )
    return int(ids[0])
def fix_dependencies(tasks: list[Task]) -> list[Task]:
    """Overwrite every task's ``dep`` with dependencies inferred from its args.

    Whatever dependencies the planner emitted are discarded. The Task objects
    are mutated in place and the same list object is returned.
    """
    for current in tasks:
        inferred = infer_deps_from_args(current)
        current.dep = inferred
    return tasks
def infer_deps_from_args(task: Task) -> list[int]:
    """Infer a task's dependencies from the <GENERATED> references in its args.

    Returns:
        The unique referenced task ids in ascending order, or ``[-1]`` (the
        "no dependencies" marker) when no arg references a generated resource.
    """
    deps = {parse_task_id(v) for v in task.args.values() if GENERATED_TOKEN in v}
    # Sorted so the result is stable and readable rather than in set order.
    return sorted(deps) if deps else [-1]
def unfold(tasks: list[Task]) -> list[Task]:
    """Expand each folded task into one task per generated resource.

    A folded task packs several generated resources into a single argument
    (see ``find_folded_args``). Each folded task is replaced, at its position
    in the output, by the tasks ``split`` produces; other tasks pass through
    unchanged.
    """
    result: list[Task] = []
    for candidate in tasks:
        folded = find_folded_args(candidate)
        result += split(candidate, folded) if folded else [candidate]
    return result
def split(task: Task, folded_args: tuple[str, str]) -> list[Task]:
    """Split a folded task into one copy per generated resource in the folded arg.

    e.g. ``args={"image": "<GENERATED>-1,<GENERATED>-2"}`` yields two deep
    copies of ``task``: one with ``image="<GENERATED>-1"`` and ``dep=[1]``, and
    one with ``image="<GENERATED>-2"`` and ``dep=[2]``. All other fields,
    including ``id``, are copied unchanged.

    Args:
        task: the folded task; it is not modified.
        folded_args: ``(key, value)`` pair as returned by ``find_folded_args``.
    """
    key, value = folded_args
    # Strip each comma-separated piece and drop blank ones, so a trailing or
    # doubled comma ("<GENERATED>-1,<GENERATED>-2,") does not create a bogus
    # task whose empty argument fails to parse.
    generated_items = [piece.strip() for piece in value.split(",") if piece.strip()]
    split_tasks = []
    for item in generated_items:
        new_task = copy.deepcopy(task)
        new_task.dep = [parse_task_id(item)]
        new_task.args[key] = item
        split_tasks.append(new_task)
    return split_tasks
def find_folded_args(task: Task) -> tuple[str, str] | None:
    """Return the first ``(key, value)`` arg holding more than one <GENERATED> token.

    e.g. ``'image': '<GENERATED>-1,<GENERATED>-2'``. Returns None when no
    argument of ``task`` is folded.
    """
    folded = next(
        (
            (name, text)
            for name, text in task.args.items()
            if text.count(GENERATED_TOKEN) > 1
        ),
        None,
    )
    if folded is not None:
        logger.debug(f"Task {task.id} is folded")
    return folded