diff --git a/notebook_dir/GeoAgent-20250427.ipynb b/notebook_dir/GeoAgent-20250427.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..6c477131e0d4ee3b46ac7e46101eb0e0cf9eef7a --- /dev/null +++ b/notebook_dir/GeoAgent-20250427.ipynb @@ -0,0 +1,124 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "24b811ea-7714-4c2d-947d-1fb1b2ec980d", + "metadata": { + "collapsed": false, + "deletable": false, + "editable": false, + "jupyter": { + "outputs_hidden": false + }}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-27 16:56:51.233 | INFO | metagpt_yusin.const:get_metagpt_yusin_package_root:29 - Package root set to /data\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "382ee0ae87fd4555959704f7f29aa79b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(VBox(children=(VBox(children=(Label(value='Select data sources and LLM models (or Submit defaul…" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from metagpt_yusin.geoagent import GeoAgent; GeoAgent().default()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bf9d8a3f-ee15-490b-8ecf-ba58836d2bc2", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62f50c92-c051-4650-988b-c84b9f22a130", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "161e81a7-680a-431b-af6c-772fe827afd3", + "metadata": {}, + "source": [ + "# Decomposing the overall task into tasks!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c3a17761-7ddd-47b8-91b5-40d4e382a816", + "metadata": {}, + "outputs": [], + "source": [ + "[\n", + " {\n", + " \"task_id\": \"1\",\n", + " \"dependent_task_ids\": [],\n", + " \"instruction\": \"Design ASCII rabbit art\",\n", + " \"task_type\": \"other\"\n", + " },\n", + " {\n", + " \"task_id\": \"2\",\n", + " \"dependent_task_ids\": [\"1\"],\n", + " \"instruction\": \"Create function to plot ASCII rabbit\",\n", + " \"task_type\": \"other\"\n", + " },\n", + " {\n", + " \"task_id\": \"3\",\n", + " \"dependent_task_ids\": [\"2\"],\n", + " \"instruction\": \"Test the plotting function\",\n", + " \"task_type\": \"other\"\n", + " }\n", + "]\n" + ] + }, + { + "cell_type": "markdown", + "id": "538d829b-61d1-42a2-8c9a-553c5ff45dfa", + "metadata": {}, + "source": [ + "# Here is the code part!" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebook_dir/metagpt_yusin/.ipynb_checkpoints/__init__-checkpoint.py b/notebook_dir/metagpt_yusin/.ipynb_checkpoints/__init__-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..5fab359249f276d27345b44354eb5740e4aac3cf --- /dev/null +++ b/notebook_dir/metagpt_yusin/.ipynb_checkpoints/__init__-checkpoint.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2023/4/24 22:26 +# @Author : alexanderwu +# @File : __init__.py + +from metagpt_yusin import _compat as _ # noqa: F401 diff --git a/notebook_dir/metagpt_yusin/.ipynb_checkpoints/_compat-checkpoint.py 
# Compatibility shims that only apply to CPython running on Windows.
if sys.implementation.name == "cpython":
    if platform.system() == "Windows":
        import asyncio

        _py_major_minor = (sys.version_info[0], sys.version_info[1])
        if _py_major_minor == (3, 9):
            from asyncio.proactor_events import _ProactorBasePipeTransport

            # Replacement finalizer for proactor pipe transports, see
            # https://github.com/python/cpython/pull/92842
            def pacth_del(self, _warn=warnings.warn):
                """Warn about, then close, a transport whose socket is still open."""
                if self._sock is None:
                    return
                _warn(f"unclosed transport {self!r}", ResourceWarning, source=self)
                self._sock.close()

            _ProactorBasePipeTransport.__del__ = pacth_del

        if sys.version_info >= (3, 9, 0):
            from semantic_kernel.orchestration import sk_function as _  # noqa: F401

            # Re-apply the proactor policy after importing semantic_kernel;
            # caused by https://github.com/microsoft/semantic-kernel/pull/1416
            asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
class CLIParams(BaseModel):
    """Options that can be supplied from the command line."""

    project_path: str = ""
    project_name: str = ""
    inc: bool = False
    reqa_file: str = ""
    max_auto_summarize_code: int = 0
    git_reinit: bool = False

    @model_validator(mode="after")
    def check_project_path(self):
        """Derive `inc` and `project_name` when a project path is given.

        With a non-empty `project_path`, `inc` is switched on and an empty
        `project_name` falls back to the last component of the path.
        """
        if not self.project_path:
            return self

        self.inc = True
        fallback_name = Path(self.project_path).name
        self.project_name = self.project_name if self.project_name else fallback_name
        return self
@classmethod
def default(cls):
    """Load the default configuration.

    Sources are merged with ``merge_dict`` (later entries overwrite earlier
    ones): the process environment first, then every file listed in
    ``default_config_paths`` (currently only ``CONFIG_ROOT / "config2.yaml"``).

    After the ``Config`` has been built, three optional environment variables
    override its ``llm`` section:

    - ``api_type``: one of ``openai``, ``groq`` or ``openrouter``
    - ``model``: stored in ``llm.model``
    - ``api_key``: stored in ``llm.api_key``

    Returns:
        Config: the merged configuration with the overrides applied.
    """
    default_config_paths: List[Path] = [
        CONFIG_ROOT / "config2.yaml",
    ]

    dicts = [dict(os.environ)]
    dicts += [Config.read_yaml(path) for path in default_config_paths]
    config_init = Config(**merge_dict(dicts))

    # api_type value -> (LLMType to select, base_url override or None to keep
    # the configured URL). groq is addressed through its OpenAI-compatible
    # endpoint, which is why it maps to LLMType.OPENAI.
    providers = {
        "openai": (LLMType.OPENAI, None),
        "groq": (LLMType.OPENAI, "https://api.groq.com/openai/v1"),
        "openrouter": (LLMType.OPENROUTER, "https://openrouter.ai/api/v1"),
    }
    api_type = os.environ.get("api_type")
    if api_type is None:
        logger.debug("Provide your api type!!")
    elif api_type in providers:
        config_init.llm.api_type, base_url = providers[api_type]
        if base_url is not None:
            config_init.llm.base_url = base_url
    else:
        # The user asked for something we cannot honour: report it at WARNING
        # so it is not hidden by an INFO-level console sink.
        logger.warning("The API Type is not supported!!")

    # Plain pass-through overrides; a missing variable is only hinted at.
    for env_name, hint in (
        ("model", "Provide your model!!"),
        ("api_key", "Provide your api key!!"),
    ):
        value = os.environ.get(env_name)
        if value is None:
            logger.debug(hint)
        else:
            setattr(config_init.llm, env_name, value)

    return config_init
def merge_dict(dicts: Iterable[Dict]) -> Dict:
    """Combine several dicts into one new dict.

    The inputs are applied in iteration order, so a key that appears more than
    once ends up with the value from the last dict that contains it.
    """
    merged: Dict = {}
    for overlay in dicts:
        merged.update(overlay)
    return merged
def get_metagpt_yusin_package_root():
    """Return the directory used as the root of the installed package.

    The candidate is the parent of the directory that holds
    ``metagpt_yusin/__init__.py``. It is kept when it contains at least one of
    the marker entries ``.git``, ``.project_root`` or ``.gitignore``; otherwise
    the current working directory is returned. The result is logged at INFO.
    """
    markers = (".git", ".project_root", ".gitignore")
    package_root = Path(metagpt_yusin.__file__).parent.parent
    if not any((package_root / marker).exists() for marker in markers):
        package_root = Path.cwd()

    logger.info(f"Package root set to {str(package_root)}")
    return package_root
SKILL_DIRECTORY = SOURCE_ROOT / "skills"
TOOL_SCHEMA_PATH = metagpt_yusin_ROOT / "metagpt_yusin/tools/schemas"
TOOL_LIBS_PATH = metagpt_yusin_ROOT / "metagpt_yusin/tools/libs"

# REAL CONSTS

MEM_TTL = 24 * 30 * 3600  # 30 * 24 * 3600; presumably seconds (30 days) -- TODO confirm unit

# Keys of the routing fields carried by a Message.
MESSAGE_ROUTE_FROM = "sent_from"
MESSAGE_ROUTE_TO = "send_to"
MESSAGE_ROUTE_CAUSE_BY = "cause_by"
MESSAGE_META_ROLE = "role"
# Routing sentinels. They must be distinct, non-empty markers: with both set to
# "" a message addressed to everybody could not be told apart from one
# addressed to nobody.
MESSAGE_ROUTE_TO_ALL = "<all>"
MESSAGE_ROUTE_TO_NONE = "<none>"

REQUIREMENT_FILENAME = "requirement.txt"
BUGFIX_FILENAME = "bugfix.txt"
PACKAGE_REQUIREMENTS_FILENAME = "requirements.txt"

# Repository-relative directories for generated artefacts.
DOCS_FILE_REPO = "docs"
PRDS_FILE_REPO = "docs/prd"
SYSTEM_DESIGN_FILE_REPO = "docs/system_design"
TASK_FILE_REPO = "docs/task"
CODE_PLAN_AND_CHANGE_FILE_REPO = "docs/code_plan_and_change"
COMPETITIVE_ANALYSIS_FILE_REPO = "resources/competitive_analysis"
DATA_API_DESIGN_FILE_REPO = "resources/data_api_design"
SEQ_FLOW_FILE_REPO = "resources/seq_flow"
SYSTEM_DESIGN_PDF_FILE_REPO = "resources/system_design"
PRD_PDF_FILE_REPO = "resources/prd"
TASK_PDF_FILE_REPO = "resources/api_spec_and_task"
CODE_PLAN_AND_CHANGE_PDF_FILE_REPO = "resources/code_plan_and_change"
TEST_CODES_FILE_REPO = "tests"
TEST_OUTPUTS_FILE_REPO = "test_outputs"
CODE_SUMMARIES_FILE_REPO = "docs/code_summary"
CODE_SUMMARIES_PDF_FILE_REPO = "resources/code_summary"
RESOURCES_FILE_REPO = "resources"
SD_OUTPUT_FILE_REPO = "resources/sd_output"
GRAPH_REPO_FILE_REPO = "docs/graph_repo"
VISUAL_GRAPH_REPO_FILE_REPO = "resources/graph_db"
CLASS_VIEW_FILE_REPO = "docs/class_view"

YAPI_URL = "http://yapi.deepwisdomai.com/"

DEFAULT_LANGUAGE = "English"
DEFAULT_MAX_TOKENS = 1500
COMMAND_TOKENS = 500
BRAIN_MEMORY = "BRAIN_MEMORY"
SKILL_PATH = "SKILL_PATH"
SERPER_API_KEY = "SERPER_API_KEY"
DEFAULT_TOKEN_SIZE = 500

# format
BASE64_FORMAT = "base64"

# REDIS
REDIS_KEY = "REDIS_KEY"

# Message id
IGNORED_MESSAGE_ID = "0"

# Class Relationship
GENERALIZATION = "Generalize"
COMPOSITION = "Composite"
AGGREGATION = "Aggregate"

# Timeout
class Context(BaseModel):
    """Env context for metagpt_yusin.

    Bundles the active configuration, free-form keyword storage, the project
    and git repositories, and the cost manager shared by the LLM instances
    created through this context.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    kwargs: AttrDict = AttrDict()
    config: Config = Config.default()

    repo: Optional[ProjectRepo] = None
    git_repo: Optional[GitRepository] = None
    src_workspace: Optional[Path] = None
    cost_manager: CostManager = CostManager()

    _llm: Optional[BaseLLM] = None

    def new_environ(self):
        """Return a copy of the current process environment as a new mapping."""
        return os.environ.copy()

    def _select_costmanager(self, llm_config: LLMConfig) -> CostManager:
        """Return the cost manager that matches ``llm_config.api_type``.

        Fireworks and open-LLM providers get a fresh, provider-specific
        manager; every other provider shares ``self.cost_manager``.
        """
        if llm_config.api_type == LLMType.FIREWORKS:
            return FireworksCostManager()
        if llm_config.api_type == LLMType.OPEN_LLM:
            return TokenCostManager()
        return self.cost_manager

    def llm(self) -> BaseLLM:
        """Create an LLM from ``self.config.llm`` and remember it in ``self._llm``.

        fixme: support cache -- a new instance is built on every call.
        """
        self._llm = create_llm_instance(self.config.llm)
        if self._llm.cost_manager is None:
            self._llm.cost_manager = self._select_costmanager(self.config.llm)
        return self._llm

    def llm_with_cost_manager_from_llm_config(self, llm_config: LLMConfig) -> BaseLLM:
        """Create an LLM from ``llm_config`` without touching ``self._llm``.

        fixme: support cache -- a new instance is built on every call.
        """
        llm = create_llm_instance(llm_config)
        if llm.cost_manager is None:
            llm.cost_manager = self._select_costmanager(llm_config)
        return llm

    def serialize(self) -> Dict[str, Any]:
        """Serialize the object's attributes into a dictionary.

        Returns:
            Dict[str, Any]: ``workdir`` (empty string without a repo), a shallow
            copy of the ``kwargs`` attributes, and the cost manager as JSON.
        """
        return {
            "workdir": str(self.repo.workdir) if self.repo else "",
            "kwargs": dict(self.kwargs.__dict__),
            "cost_manager": self.cost_manager.model_dump_json(),
        }

    def deserialize(self, serialized_data: Dict[str, Any]):
        """Restore state previously produced by ``serialize``.

        Args:
            serialized_data (Dict[str, Any]): A dictionary containing serialized
                data. Empty or missing entries leave the matching attribute as is.
        """
        if not serialized_data:
            return
        workdir = serialized_data.get("workdir")
        if workdir:
            self.git_repo = GitRepository(local_path=workdir, auto_init=True)
            self.repo = ProjectRepo(self.git_repo)
            # The source workspace is the sub-directory named after the repo itself.
            src_workspace = self.git_repo.workdir / self.git_repo.workdir.name
            if src_workspace.exists():
                self.src_workspace = src_workspace
        kwargs = serialized_data.get("kwargs")
        if kwargs:
            for k, v in kwargs.items():
                self.kwargs.set(k, v)
        cost_manager = serialized_data.get("cost_manager")
        if cost_manager:
            # model_validate_json builds and returns a new model; it does not
            # update the instance it is called on, so the result must be stored.
            self.cost_manager = self.cost_manager.model_validate_json(cost_manager)
def set(self, k, v, override=False):
    """Store `v` under attribute name `k` in the instance dict.

    An existing truthy value is kept unless `override` is true; a missing or
    falsy value is always replaced.
    """
    attrs = self.__dict__
    if not override and attrs.get(k):
        return
    attrs[k] = v
class GeoAgent(SetLlmEnv, RunLLM):
    """Notebook front end that combines `SetLlmEnv` and `RunLLM`.

    The widgets used below (`data_source`, `api_type`, `button_submit`, ...) are
    read as attributes of `GeoAgent`; they are not defined in this class, so
    they presumably come from the two base classes -- verify there.
    """

    def default(self):
        """Build the control panel, display it and wire up the button callbacks."""
        # Data-source input panel
        data_set = VBox([GeoAgent.data_source,
                         HBox([GeoAgent.gee_key,GeoAgent.button_set2])
                        ])

        # LLM settings panel
        llm_set = VBox([
            HBox([GeoAgent.api_type, GeoAgent.model]),
            HBox([GeoAgent.api_key, GeoAgent.button_set1])
        ])

        # Time and space settings
        st_set = VBox([
            HBox([GeoAgent.start_date, GeoAgent.end_date]),  # date pickers share one row
            HBox([GeoAgent.get_bounds_button, GeoAgent.bounds_label])  # bounds button and its label share one row
        ])

        # Tab layout
        tab_nest = Tab()
        tab_nest.children = [data_set, llm_set, st_set]
        tab_nest.titles = ('Data Sources', 'LLM Models', 'Space & Time')


        # Left-hand side: the controls
        left_panel = VBox([
            Label(value="Select data sources and LLM models (or Submit default) and Space Time:"),
            tab_nest,
            Label(value="Start with your task which can be refined and built upon previous tasks:"),
            VBox([GeoAgent.notice_text, GeoAgent.box_llm])
        ])

        # Right-hand side: self.m (presumably the map widget -- defined outside this class)
        right_panel = VBox([self.m])

        # An HBox sets the width ratio between the left panel and the right-hand map
        layout_all = HBox([
            left_panel,
            right_panel
        ], layout=Layout(
            width='100%',   # the HBox spans the full available width
            height='330px',  # fixed height; adjust as needed
            justify_content='space-between'
        ))

        # Left and right panels each take 50% of the width
        left_panel.layout = Layout(width='50%', height='100%')  # same height as the parent container
        right_panel.layout = Layout(width='50%', height='100%')  # same height as the parent container

        # Show the final layout
        display(layout_all)

        # Register the button callbacks
        GeoAgent.button_set1.on_click(self.set_env)
        GeoAgent.button_set2.on_click(self.set_env2)
        # NOTE(review): ipywidgets' Button.on_click(callback, remove=False) treats
        # its second argument as the `remove` flag. If button_submit is a plain
        # ipywidgets Button, passing self.clean here is a truthy `remove` and
        # _check_id would be unregistered rather than registered -- confirm intent.
        GeoAgent.button_submit.on_click(self._check_id, self.clean)
        GeoAgent.button_cl.on_click(self.clean)
        GeoAgent.button_abort.on_click(self.abort)
        GeoAgent.get_bounds_button.on_click(self.get_bounds)
def define_log_level(print_level="INFO", logfile_level="DEBUG", name: str = None):
    """Reset the logger sinks and return the configured logger.

    `print_level` is remembered in the module-level `_print_level` and used for
    the stderr sink; `logfile_level` is used for the file sink. The log file is
    `logs/<YYYYMMDD>.txt` under the project root, prefixed with `<name>_` when
    `name` is given.
    """
    global _print_level
    _print_level = print_level

    stamp = datetime.now().strftime("%Y%m%d")
    if name:
        log_name = f"{name}_{stamp}"  # prefix the dated file name
    else:
        log_name = stamp
    logfile = metagpt_yusin_ROOT / f"logs/{log_name}.txt"

    _logger.remove()
    _logger.add(sys.stderr, level=print_level)
    _logger.add(logfile, level=logfile_level)
    return _logger
class SerializationMixin(BaseModel, extra="forbid"):
    """
    PolyMorphic subclasses Serialization / Deserialization Mixin
    - First of all, we need to know that pydantic is not designed for polymorphism.
    - If Engineer is subclass of Role, it would be serialized as Role. If we want to serialize it as Engineer, we need
        to add `class name` to Engineer. So we need Engineer inherit SerializationMixin.

    More details:
    - https://docs.pydantic.dev/latest/concepts/serialization/
    - https://github.com/pydantic/pydantic/discussions/7008 discuss about avoid `__get_pydantic_core_schema__`
    """

    __is_polymorphic_base = False
    # Registry of every subclass, keyed by "<module>.<qualname>".
    __subclasses_map__ = {}

    @model_serializer(mode="wrap")
    def __serialize_with_class_type__(self, default_serializer) -> Any:
        """Serialize normally, then tag the result with the concrete class name."""
        ret = default_serializer(self)
        ret["__module_class_name"] = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
        return ret

    @model_validator(mode="wrap")
    @classmethod
    def __convert_to_real_type__(cls, value: Any, handler):
        """Validate `value`, building the class named by its `__module_class_name` tag.

        Raises:
            ValueError: a polymorphic base received a dict without the tag.
            TypeError: the tagged class is not registered in `__subclasses_map__`.
        """
        if not isinstance(value, dict):
            return handler(value)

        # it is a dict so make sure to remove the __module_class_name
        # because we don't allow extra keywords but want to ensure
        # e.g Cat.model_validate(cat.model_dump()) works
        # (note: this pops the key from the caller's dict in place)
        class_full_name = value.pop("__module_class_name", None)

        # if it's not the polymorphic base we construct via default handler
        # when there is no tag, or the tag names exactly this class
        if not cls.__is_polymorphic_base:
            if class_full_name is None or str(cls) == f"<class '{class_full_name}'>":
                return handler(value)
            # a tag naming a different class falls through to the registry lookup

        # otherwise we lookup the correct polymorphic type and construct that
        # instead
        if class_full_name is None:
            raise ValueError("Missing __module_class_name field")

        class_type = cls.__subclasses_map__.get(class_full_name, None)

        if class_type is None:
            # TODO could try dynamic import
            raise TypeError(f"Trying to instantiate {class_full_name}, which has not yet been defined!")

        return class_type(**value)

    def __init_subclass__(cls, is_polymorphic_base: bool = False, **kwargs):
        """Record the polymorphic-base flag and register the new subclass."""
        cls.__is_polymorphic_base = is_polymorphic_base
        cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls
        super().__init_subclass__(**kwargs)
content: str + role: str + + +class Document(BaseModel): + """ + Represents a document. + """ + + root_path: str = "" + filename: str = "" + content: str = "" + + def get_meta(self) -> Document: + """Get metadata of the document. + + :return: A new Document instance with the same root path and filename. + """ + + return Document(root_path=self.root_path, filename=self.filename) + + @property + def root_relative_path(self): + """Get relative path from root of git repository. + + :return: relative path from root of git repository. + """ + return os.path.join(self.root_path, self.filename) + + def __str__(self): + return self.content + + def __repr__(self): + return self.content + + +class Documents(BaseModel): + """A class representing a collection of documents. + + Attributes: + docs (Dict[str, Document]): A dictionary mapping document names to Document instances. + """ + + docs: Dict[str, Document] = Field(default_factory=dict) + + @classmethod + def from_iterable(cls, documents: Iterable[Document]) -> Documents: + """Create a Documents instance from a list of Document instances. + + :param documents: A list of Document instances. + :return: A Documents instance. + """ + + docs = {doc.filename: doc for doc in documents} + return Documents(docs=docs) + + def to_action_output(self) -> "ActionOutput": + """Convert to action output string. + + :return: A string representing action output. 
+ """ + from metagpt_yusin.actions.action_output import ActionOutput + + return ActionOutput(content=self.model_dump_json(), instruct_content=self) + + +class Message(BaseModel): + """list[: ]""" + + id: str = Field(default="", validate_default=True) # According to Section 2.2.3.1.1 of RFC 135 + content: str + instruct_content: Optional[BaseModel] = Field(default=None, validate_default=True) + role: str = "user" # system / user / assistant + cause_by: str = Field(default="", validate_default=True) + sent_from: str = Field(default="", validate_default=True) + send_to: set[str] = Field(default={MESSAGE_ROUTE_TO_ALL}, validate_default=True) + + @field_validator("id", mode="before") + @classmethod + def check_id(cls, id: str) -> str: + return id if id else uuid.uuid4().hex + + @field_validator("instruct_content", mode="before") + @classmethod + def check_instruct_content(cls, ic: Any) -> BaseModel: + if ic and isinstance(ic, dict) and "class" in ic: + if "mapping" in ic: + # compatible with custom-defined ActionOutput + mapping = actionoutput_str_to_mapping(ic["mapping"]) + actionnode_class = import_class("ActionNode", "metagpt_yusin.actions.action_node") # avoid circular import + ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping) + elif "module" in ic: + # subclasses of BaseModel + ic_obj = import_class(ic["class"], ic["module"]) + else: + raise KeyError("missing required key to init Message.instruct_content from dict") + ic = ic_obj(**ic["value"]) + return ic + + @field_validator("cause_by", mode="before") + @classmethod + def check_cause_by(cls, cause_by: Any) -> str: + return any_to_str(cause_by if cause_by else import_class("UserRequirement", "metagpt_yusin.actions.add_requirement")) + + @field_validator("sent_from", mode="before") + @classmethod + def check_sent_from(cls, sent_from: Any) -> str: + return any_to_str(sent_from if sent_from else "") + + @field_validator("send_to", mode="before") + @classmethod + def 
check_send_to(cls, send_to: Any) -> set: + return any_to_str_set(send_to if send_to else {MESSAGE_ROUTE_TO_ALL}) + + @field_serializer("send_to", mode="plain") + def ser_send_to(self, send_to: set) -> list: + return list(send_to) + + @field_serializer("instruct_content", mode="plain") + def ser_instruct_content(self, ic: BaseModel) -> Union[dict, None]: + ic_dict = None + if ic: + # compatible with custom-defined ActionOutput + schema = ic.model_json_schema() + ic_type = str(type(ic)) + if " str: + """For search""" + return self.content + + def to_dict(self) -> dict: + """Return a dict containing `role` and `content` for the LLM call.l""" + return {"role": self.role, "content": self.content} + + def dump(self) -> str: + """Convert the object to json string""" + return self.model_dump_json(exclude_none=True, warnings=False) + + @staticmethod + @handle_exception(exception_type=JSONDecodeError, default_return=None) + def load(val): + """Convert the json string to object.""" + + try: + m = json.loads(val) + id = m.get("id") + if "id" in m: + del m["id"] + msg = Message(**m) + if id: + msg.id = id + return msg + except JSONDecodeError as err: + logger.error(f"parse json failed: {val}, error:{err}") + return None + + +class UserMessage(Message): + """便于支持OpenAI的消息 + Facilitate support for OpenAI messages + """ + + def __init__(self, content: str): + super().__init__(content=content, role="user") + + +class SystemMessage(Message): + """便于支持OpenAI的消息 + Facilitate support for OpenAI messages + """ + + def __init__(self, content: str): + super().__init__(content=content, role="system") + + +class AIMessage(Message): + """便于支持OpenAI的消息 + Facilitate support for OpenAI messages + """ + + def __init__(self, content: str): + super().__init__(content=content, role="assistant") + + +class Task(BaseModel): + task_id: str = "" + dependent_task_ids: list[str] = [] # Tasks prerequisite to this Task + instruction: str = "" + task_type: str = "" + code: str = "" + result: str = "" + 
is_success: bool = False + is_finished: bool = False + + def reset(self): + self.code = "" + self.result = "" + self.is_success = False + self.is_finished = False + + def update_task_result(self, task_result: TaskResult): + self.code = task_result.code + self.result = task_result.result + self.is_success = task_result.is_success + + +class TaskResult(BaseModel): + """Result of taking a task, with result and is_success required to be filled""" + + code: str = "" + result: str + is_success: bool + + +class Plan(BaseModel): + goal: str + context: str = "" + tasks: list[Task] = [] + task_map: dict[str, Task] = {} + current_task_id: str = "" + + def _topological_sort(self, tasks: list[Task]): + task_map = {task.task_id: task for task in tasks} + dependencies = {task.task_id: set(task.dependent_task_ids) for task in tasks} + sorted_tasks = [] + visited = set() + + def visit(task_id): + if task_id in visited: + return + visited.add(task_id) + for dependent_id in dependencies.get(task_id, []): + visit(dependent_id) + sorted_tasks.append(task_map[task_id]) + + for task in tasks: + visit(task.task_id) + + return sorted_tasks + + def add_tasks(self, tasks: list[Task]): + """ + Integrates new tasks into the existing plan, ensuring dependency order is maintained. + + This method performs two primary functions based on the current state of the task list: + 1. If there are no existing tasks, it topologically sorts the provided tasks to ensure + correct execution order based on dependencies, and sets these as the current tasks. + 2. If there are existing tasks, it merges the new tasks with the existing ones. It maintains + any common prefix of tasks (based on task_id and instruction) and appends the remainder + of the new tasks. The current task is updated to the first unfinished task in this merged list. + + Args: + tasks (list[Task]): A list of tasks (may be unordered) to add to the plan. 
+ + Returns: + None: The method updates the internal state of the plan but does not return anything. + """ + if not tasks: + return + + # Topologically sort the new tasks to ensure correct dependency order + new_tasks = self._topological_sort(tasks) + + if not self.tasks: + # If there are no existing tasks, set the new tasks as the current tasks + self.tasks = new_tasks + + else: + # Find the length of the common prefix between existing and new tasks + prefix_length = 0 + for old_task, new_task in zip(self.tasks, new_tasks): + if old_task.task_id != new_task.task_id or old_task.instruction != new_task.instruction: + break + prefix_length += 1 + + # Combine the common prefix with the remainder of the new tasks + final_tasks = self.tasks[:prefix_length] + new_tasks[prefix_length:] + self.tasks = final_tasks + + # Update current_task_id to the first unfinished task in the merged list + self._update_current_task() + + # Update the task map for quick access to tasks by ID + self.task_map = {task.task_id: task for task in self.tasks} + + def reset_task(self, task_id: str): + """ + Clear code and result of the task based on task_id, and set the task as unfinished. + + Args: + task_id (str): The ID of the task to be reset. + + Returns: + None + """ + if task_id in self.task_map: + task = self.task_map[task_id] + task.reset() + + def replace_task(self, new_task: Task): + """ + Replace an existing task with the new input task based on task_id, and reset all tasks depending on it. + + Args: + new_task (Task): The new task that will replace an existing one. 
+ + Returns: + None + """ + assert new_task.task_id in self.task_map + # Replace the task in the task map and the task list + self.task_map[new_task.task_id] = new_task + for i, task in enumerate(self.tasks): + if task.task_id == new_task.task_id: + self.tasks[i] = new_task + break + + # Reset dependent tasks + for task in self.tasks: + if new_task.task_id in task.dependent_task_ids: + self.reset_task(task.task_id) + + def append_task(self, new_task: Task): + """ + Append a new task to the end of existing task sequences + + Args: + new_task (Task): The new task to be appended to the existing task sequence + + Returns: + None + """ + assert not self.has_task_id(new_task.task_id), "Task already in current plan, use replace_task instead" + + assert all( + [self.has_task_id(dep_id) for dep_id in new_task.dependent_task_ids] + ), "New task has unknown dependencies" + + # Existing tasks do not depend on the new task, it's fine to put it to the end of the sorted task sequence + self.tasks.append(new_task) + self.task_map[new_task.task_id] = new_task + self._update_current_task() + + def has_task_id(self, task_id: str) -> bool: + return task_id in self.task_map + + def _update_current_task(self): + current_task_id = "" + for task in self.tasks: + if not task.is_finished: + current_task_id = task.task_id + break + self.current_task_id = current_task_id # all tasks finished + + @property + def current_task(self) -> Task: + """Find current task to execute + + Returns: + Task: the current task to be executed + """ + return self.task_map.get(self.current_task_id, None) + + def finish_current_task(self): + """Finish current task, set Task.is_finished=True, set current task to next task""" + if self.current_task_id: + self.current_task.is_finished = True + self._update_current_task() # set to next task + + def get_finished_tasks(self) -> list[Task]: + """return all finished tasks in correct linearized order + + Returns: + list[Task]: list of finished tasks + """ + return [task 
for task in self.tasks if task.is_finished] + + +class MessageQueue(BaseModel): + """Message queue which supports asynchronous updates.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + _queue: Queue = PrivateAttr(default_factory=Queue) + + def pop(self) -> Message | None: + """Pop one message from the queue.""" + try: + item = self._queue.get_nowait() + if item: + self._queue.task_done() + return item + except QueueEmpty: + return None + + def pop_all(self) -> List[Message]: + """Pop all messages from the queue.""" + ret = [] + while True: + msg = self.pop() + if not msg: + break + ret.append(msg) + return ret + + def push(self, msg: Message): + """Push a message into the queue.""" + self._queue.put_nowait(msg) + + def empty(self): + """Return true if the queue is empty.""" + return self._queue.empty() + + async def dump(self) -> str: + """Convert the `MessageQueue` object to a json string.""" + if self.empty(): + return "[]" + + lst = [] + msgs = [] + try: + while True: + item = await wait_for(self._queue.get(), timeout=1.0) + if item is None: + break + msgs.append(item) + lst.append(item.dump()) + self._queue.task_done() + except asyncio.TimeoutError: + logger.debug("Queue is empty, exiting...") + finally: + for m in msgs: + self._queue.put_nowait(m) + return json.dumps(lst, ensure_ascii=False) + + @staticmethod + def load(data) -> "MessageQueue": + """Convert the json string to the `MessageQueue` object.""" + queue = MessageQueue() + try: + lst = json.loads(data) + for i in lst: + msg = Message.load(i) + queue.push(msg) + except JSONDecodeError as e: + logger.warning(f"JSON load failed: {data}, error:{e}") + + return queue + + +# 定义一个泛型类型变量 +T = TypeVar("T", bound="BaseModel") + + +class BaseContext(BaseModel, ABC): + @classmethod + @handle_exception + def loads(cls: Type[T], val: str) -> Optional[T]: + i = json.loads(val) + return cls(**i) + + +class CodingContext(BaseContext): + filename: str + design_doc: Optional[Document] = None + 
task_doc: Optional[Document] = None + code_doc: Optional[Document] = None + code_plan_and_change_doc: Optional[Document] = None + + +class TestingContext(BaseContext): + filename: str + code_doc: Document + test_doc: Optional[Document] = None + + +class RunCodeContext(BaseContext): + mode: str = "script" + code: Optional[str] = None + code_filename: str = "" + test_code: Optional[str] = None + test_filename: str = "" + command: List[str] = Field(default_factory=list) + working_directory: str = "" + additional_python_paths: List[str] = Field(default_factory=list) + output_filename: Optional[str] = None + output: Optional[str] = None + + +class RunCodeResult(BaseContext): + summary: str + stdout: str + stderr: str + + +class CodeSummarizeContext(BaseModel): + design_filename: str = "" + task_filename: str = "" + codes_filenames: List[str] = Field(default_factory=list) + reason: str = "" + + @staticmethod + def loads(filenames: List) -> CodeSummarizeContext: + ctx = CodeSummarizeContext() + for filename in filenames: + if Path(filename).is_relative_to(SYSTEM_DESIGN_FILE_REPO): + ctx.design_filename = str(filename) + continue + if Path(filename).is_relative_to(TASK_FILE_REPO): + ctx.task_filename = str(filename) + continue + return ctx + + def __hash__(self): + return hash((self.design_filename, self.task_filename)) + + +class BugFixContext(BaseContext): + filename: str = "" + + +class CodePlanAndChangeContext(BaseModel): + requirement: str = "" + issue: str = "" + prd_filename: str = "" + design_filename: str = "" + task_filename: str = "" + + @staticmethod + def loads(filenames: List, **kwargs) -> CodePlanAndChangeContext: + ctx = CodePlanAndChangeContext(requirement=kwargs.get("requirement", ""), issue=kwargs.get("issue", "")) + for filename in filenames: + filename = Path(filename) + if filename.is_relative_to(PRDS_FILE_REPO): + ctx.prd_filename = filename.name + continue + if filename.is_relative_to(SYSTEM_DESIGN_FILE_REPO): + ctx.design_filename = filename.name + 
continue + if filename.is_relative_to(TASK_FILE_REPO): + ctx.task_filename = filename.name + continue + return ctx + + +# mermaid class view +class UMLClassMeta(BaseModel): + name: str = "" + visibility: str = "" + + @staticmethod + def name_to_visibility(name: str) -> str: + if name == "__init__": + return "+" + if name.startswith("__"): + return "-" + elif name.startswith("_"): + return "#" + return "+" + + +class UMLClassAttribute(UMLClassMeta): + value_type: str = "" + default_value: str = "" + + def get_mermaid(self, align=1) -> str: + content = "".join(["\t" for i in range(align)]) + self.visibility + if self.value_type: + content += self.value_type.replace(" ", "") + " " + name = self.name.split(":", 1)[1] if ":" in self.name else self.name + content += name + if self.default_value: + content += "=" + if self.value_type not in ["str", "string", "String"]: + content += self.default_value + else: + content += '"' + self.default_value.replace('"', "") + '"' + # if self.abstraction: + # content += "*" + # if self.static: + # content += "$" + return content + + +class UMLClassMethod(UMLClassMeta): + args: List[UMLClassAttribute] = Field(default_factory=list) + return_type: str = "" + + def get_mermaid(self, align=1) -> str: + content = "".join(["\t" for i in range(align)]) + self.visibility + name = self.name.split(":", 1)[1] if ":" in self.name else self.name + content += name + "(" + ",".join([v.get_mermaid(align=0) for v in self.args]) + ")" + if self.return_type: + content += " " + self.return_type.replace(" ", "") + # if self.abstraction: + # content += "*" + # if self.static: + # content += "$" + return content + + +class UMLClassView(UMLClassMeta): + attributes: List[UMLClassAttribute] = Field(default_factory=list) + methods: List[UMLClassMethod] = Field(default_factory=list) + + def get_mermaid(self, align=1) -> str: + content = "".join(["\t" for i in range(align)]) + "class " + self.name + "{\n" + for v in self.attributes: + content += 
v.get_mermaid(align=align + 1) + "\n" + for v in self.methods: + content += v.get_mermaid(align=align + 1) + "\n" + content += "".join(["\t" for i in range(align)]) + "}\n" + return content + + @classmethod + def load_dot_class_info(cls, dot_class_info: DotClassInfo) -> UMLClassView: + visibility = UMLClassView.name_to_visibility(dot_class_info.name) + class_view = cls(name=dot_class_info.name, visibility=visibility) + for i in dot_class_info.attributes.values(): + visibility = UMLClassAttribute.name_to_visibility(i.name) + attr = UMLClassAttribute(name=i.name, visibility=visibility, value_type=i.type_, default_value=i.default_) + class_view.attributes.append(attr) + for i in dot_class_info.methods.values(): + visibility = UMLClassMethod.name_to_visibility(i.name) + method = UMLClassMethod(name=i.name, visibility=visibility, return_type=i.return_args.type_) + for j in i.args: + arg = UMLClassAttribute(name=j.name, value_type=j.type_, default_value=j.default_) + method.args.append(arg) + method.return_type = i.return_args.type_ + class_view.methods.append(method) + return class_view diff --git a/notebook_dir/metagpt_yusin/.ipynb_checkpoints/set_envs-checkpoint.py b/notebook_dir/metagpt_yusin/.ipynb_checkpoints/set_envs-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..9867ab379327f41663e33d79b85cf3bee35876b7 --- /dev/null +++ b/notebook_dir/metagpt_yusin/.ipynb_checkpoints/set_envs-checkpoint.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/1/4 01:25 +@Author : alexanderwu +@File : config2.py +""" + +import os +from metagpt_yusin.config2 import Config +from ipywidgets import Tab, Label, Button, Textarea, VBox, HBox, Layout, Password, Dropdown +from IPython.display import display +from ipylab import JupyterFrontEnd +app = JupyterFrontEnd() + +def restart(): + # restart kernel + app.commands.execute('kernelmenu:restart-and-clear') + +class SetLlmEnv: + + api_type = Dropdown( + options=['openai', 
'gemini', 'llama3', 'groq', 'openrouter'], + value='openrouter', #'groq', #'openai', + description='API Type:', + disabled=False, + layout=Layout(width='500px', height='30px')) + + model = Dropdown( + options=['gpt-3.5-turbo-1106','deepseek-r1-distill-llama-70b', 'tngtech/deepseek-r1t-chimera:free'], + value='tngtech/deepseek-r1t-chimera:free', #'deepseek-r1-distill-llama-70b', #'gpt-3.5-turbo-1106', + description='Model:', + disabled=False, + layout=Layout(width='500px', height='30px')) + + api_key = Password( + value='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',#'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', #'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', + placeholder='Enter API Key', + description='API Key :', + disabled=False, + layout=Layout(width='800px', height='30px')) + + + gee_key = Password( + value='sk-.....................................................................', + placeholder='Provide your GEE key here', + description='Data Key:', + disabled=False, + layout=Layout(width='800px', height='30px')) + + pc_key = Password( + value='sk-', + placeholder='Provide your Planetary Computer key here', + description='PC Key :', + disabled=False, + layout=Layout(width='800px', height='30px')) + + button_set1 = Button(description='Submit', layout=Layout(height='30px', width='200px')) + button_set2 = Button(description='Submit', layout=Layout(height='30px', width='200px')) + + + @staticmethod + def set_env(event): + os.environ['api_type'] = SetLlmEnv.api_type.value + os.environ['model'] = SetLlmEnv.model.value + os.environ['api_key'] = SetLlmEnv.api_key.value + app.commands.execute('notebook:move-cursor-down') + + @staticmethod + def set_env2(event): + os.environ['EARTHENGINE_TOKEN'] = SetLlmEnv.gee_key.value + os.environ['PC_SDK_SUBSCRIPTION_KEY'] = SetLlmEnv.pc_key.value + #os.environ['MLHub API Key'] = SetLlmEnv.mkhub_key.value + + @classmethod + def default(cls): + data_set = 
HBox([VBox([Label(value="Provide GEE key here:"), SetLlmEnv.gee_key]), + VBox([Label(value="Provide Planetary Computer key here:"), SetLlmEnv.pc_key]), + VBox([Label(value=" "), SetLlmEnv.button_set2])]) + llm_set = VBox([HBox([SetLlmEnv.api_type, SetLlmEnv.model]), HBox([SetLlmEnv.api_key, SetLlmEnv.button_set1])]) + tab_nest = Tab() + tab_nest.children = [data_set, llm_set] + tab_nest.titles = ('Data Sources', 'LLM Models') + + display(VBox([Label(value="Specify your preferred:"), tab_nest])) + SetLlmEnv.button_set1.on_click(cls.set_env) + SetLlmEnv.button_set2.on_click(cls.set_env2) \ No newline at end of file diff --git a/notebook_dir/metagpt_yusin/.ipynb_checkpoints/tasks-checkpoint.py b/notebook_dir/metagpt_yusin/.ipynb_checkpoints/tasks-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..a2be51268c0063a7b52cb2d97e2540caf7d857cd --- /dev/null +++ b/notebook_dir/metagpt_yusin/.ipynb_checkpoints/tasks-checkpoint.py @@ -0,0 +1,246 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/1/4 01:25 +@Author : alexanderwu +@File : config2.py +""" +import leafmap +import asyncio +import nest_asyncio +import nbformat +from nbclient import NotebookClient +from ipylab import JupyterFrontEnd +from metagpt_yusin.logs import log_llm_stream, logger +from ipywidgets import DatePicker, Text, Button, Textarea, Tab, Box, VBox, HBox, Layout, Label, Password, Dropdown, FloatText +from traitlets import observe, link, Unicode, Bool, Any +from IPython.display import display +nest_asyncio.apply() +app = JupyterFrontEnd() + + +class ConfirmationButton(VBox): + button_style = Any(default_value='') + description = Unicode() + disabled = Bool() + icon = Unicode() + layout = Any() + style = Any() + tooltip = Unicode() + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._button = Button(**kwargs) + self._confirm_btn = Button(description='confirm', icon='check', button_style='success', layout=dict(height='45%', width='90%')) + 
self._cancel_btn = Button(description='cancel', icon='times', button_style='warning', layout=dict(height='45%', width='90%'), disabled=True) + self._button.on_click(self._on_btn_click) + self._cancel_btn.on_click(self._on_btn_click) + self._confirm_btn.on_click(self._on_btn_click) + self.children = [self._button] + + def on_click(self, *args, **kwargs): + self._confirm_btn.on_click(args[0], **kwargs) + self._cancel_btn.on_click(args[1], **kwargs) + + def _on_btn_click(self, b): + if b==self._button: + self.children = [self._confirm_btn, self._cancel_btn] + else: + self.children = [self._button] + + +class RunLLM: + # context + context = None + # overall tasks + tasks = [] + # widget + input_text = Textarea( + #description='Give your task description:', + placeholder='Give your task description, such as: run data analysis on sklearn Iris dataset, include a plot', + disabled=False, + layout=Layout(height='90px', width='400px') + ) + #button_submit = Button(description='Submit', layout=Layout(height='100px', width='100px')) + button_submit = ConfirmationButton(description='Submit', layout=Layout(height='30px', width='80px')) + button_cl = Button(description='Clean', disabled=False, layout=Layout(height='30px', width='80px')) + button_abort = Button(description='Abort', disabled=True, layout=Layout(height='30px', width='80px')) + box_llm = HBox([ + input_text, + VBox([button_submit, button_cl, button_abort]) + ]) + #Box(layout=Layout(width='500px', height='150px')) + #box_llm.children += (input_text, VBox([button_submit, button_cl, button_abort], layout=Layout(height='110px', width='80px'))) + # popup button + notice_text = Text(placeholder='Make sure whether you want to include the following cells as a context for your text, such as the refined task description and code of previous steps', + disabled=False, + layout=Layout(height='30px', width='500px')) + context = None + + # 创建地图,设置中心和缩放级别 + m = leafmap.Map(center=[37.6412, -122.1353], zoom=15, height="800px") + 
m.add_basemap("SATELLITE") + + # 时间数据选择框 + start_date = DatePicker(description='Start Date', layout=Layout(height='30px', width='50%')) + end_date = DatePicker(description='End Date', layout=Layout(height='30px', width='50%')) + data_source = Dropdown( + options=['Online Tile', 'GEE', 'NASA Earth Data', 'OSM', 'AWS', 'CDSE', 'MAXAR', 'Planetary Computer'], + value='Online Tile', #'groq', #'openai', + description='Data Source:', + disabled=False, + layout=Layout(height='30px', width='78%')) + + # 用户区域选择事件 + bounds_label = Label(value="", layout=Layout(height='30px', width='65%', border='1px solid gray', padding='5px', + display='flex', align_items='center', justify_content='center')) + get_bounds_button = Button(description="Get Selected Bounds", layout=Layout(height='30px', width='35%')) + + @staticmethod + def get_bounds(event): + # 每次点击按钮时都重新获取地图选区 + if RunLLM.m.user_roi is not None: + bbox = RunLLM.m.user_roi_bounds() # 获取当前选择的区域的边界框 + RunLLM.bounds_label.value = f"{bbox}" + else: + RunLLM.bounds_label.value = None + + @staticmethod + def clean(event): + RunLLM.notice_text.value = 'excluding and cleaning the following cells.' 
+ #print('clean cells!') + app.commands.execute('notebook:move-cursor-down') + app.commands.execute('notebook:insert-cell-above') + for i in range(100): + #if app.commands.execute('notebook:merge-cell-below'): + # app.commands.execute('notebook:delete-cell') + app.commands.execute('notebook:merge-cell-below') + app.commands.execute('notebook:delete-cell') + # clean last cell + app.commands.execute('notebook:delete-cell') + # trun off clean after clean + RunLLM.button_cl.disabled = True + RunLLM.button_submit._cancel_btn.disabled = True + + @staticmethod + def abort(event): + RunLLM.button_cl.disabled = False + RunLLM.button_submit._cancel_btn.disabled = False + #print('abort!') + for task in RunLLM.tasks: + task.cancel() + RunLLM.tasks.clear() + # after abort trun off it + RunLLM.button_abort.disabled = True + + def _check_id(self, b): + RunLLM.notice_text.value = 'keep the following cells as the context of current step.' + RunLLM.button_abort.disabled = True + ''' + # get context ------------------------------- + app.commands.execute('docmanager:save') + app.commands.execute('notebook:insert-cell-below') + app.commands.execute('notebook:replace-selection', {'text': "from ipylab import JupyterFrontEnd\napp = JupyterFrontEnd()\nnb_path = app.sessions.sessions[0]['path']"}) + app.commands.execute('notebook:run-cell') + #app.commands.execute('notebook:delete-cell') + print(app) + # ------------------------------------------- + ''' + nb_path = app.sessions.sessions[0]['path'] + # Load the notebook + with open(nb_path, 'r') as notebook_file: + notebook_content = notebook_file.read() + # Parse the notebook + notebook = nbformat.reads(notebook_content, as_version=4) + content = [] + for i_cell in notebook['cells']: + #if i_cell['cell_type'] == 'code': + content.append(i_cell['source']) + #print(content) + # get the first place of the given strings + try: + index_code = content.index('# Here is the code part!') + code = content[index_code+1:] + except: + code = None + try: 
+ index_task = content.index('# Decomposing the overall task into tasks!') + task = content[index_task+1:index_code] + except: + task = None + # give context + if code is None or task is None: + index_context = content.index('from metagpt_yusin.geoagent import GeoAgent\nGeoAgent().default()') + RunLLM.context = '\nThis is the initial subtasks and codes:\n' + '\n'.join(content[index_context+1:]) + '\nNow please further improve these initial subtasks and codes.' + else: + RunLLM.context = '\nThis is the initial subtasks:\n' + '\n'.join(task) + '\nAnd the initial generated codes:\n' + '\n'.join(code) + '\nNow please further improve these initial subtasks and codes.' + #print(context) + + # 检查时间范围 + if not RunLLM.start_date.value or not RunLLM.end_date.value: + RunLLM.notice_text.value = "!!Please select a valid start and end date before submitting." + raise ValueError("Please select a valid start and end date before submitting.") + RunLLM.context += f"\nTime Range: {RunLLM.start_date.value} to {RunLLM.end_date.value}" + + # 检查区域范围 + if not RunLLM.bounds_label.value: + RunLLM.notice_text.value = "!!Please select a valid bounding box (area) before submitting." 
+ raise ValueError("Please select a valid bounding box (area) before submitting.") + RunLLM.context += f"\nBounds: {RunLLM.bounds_label.value}" + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(self.start_job()) + RunLLM.tasks.clear() + except KeyboardInterrupt: + print("User termination detected") + finally: + loop.close() + + async def async_run_task(self, ): + from metagpt_yusin.logs import logger + from metagpt_yusin.roles.di.data_interpreter import DataInterpreter + from metagpt_yusin.utils.recovery_util import save_history + RunLLM.button_submit._confirm_btn.on_click(self._check_id, remove=True) + #print('running job') + #goal = 'This is the overall task: ' + goal + context + #print(input_text.value) + output = await DataInterpreter.jynb_run(RunLLM.input_text.value, RunLLM.context) + tasks, tasks_results, di = output[0], output[1], output[2] + logger.info('-----------------------------------') + logger.info(tasks) + logger.info(tasks_results) + save_history(role=di) + #print('job complete!') + + async def run_task(self, ): + # trun off clean during runing + RunLLM.button_cl.disabled = True + RunLLM.button_submit._cancel_btn.disabled = True + # trun on abort during runing + RunLLM.button_abort.disabled = False + await self.async_run_task() + # trun on clean afetr one iter runing + RunLLM.button_cl.disabled = False + RunLLM.button_submit._cancel_btn.disabled = False + # trun off abort during runing + RunLLM.button_abort.disabled = True + + def run_job(self, event): + task = asyncio.create_task(self.run_task()) + RunLLM.tasks.append(task) + + async def start_job(self,): + try: + RunLLM.button_submit.description='Submit' + RunLLM.button_submit._confirm_btn.on_click(self.run_job) + except KeyboardInterrupt: + # Handle keyboard interrupt (user termination) + print("User termination detected") + + def default(self,): + display(VBox([RunLLM.box_llm, RunLLM.notice_text])) + RunLLM.button_submit.on_click(self._check_id, 
self.clean) + RunLLM.button_cl.on_click(self.clean) + RunLLM.button_abort.on_click(self.abort) diff --git a/notebook_dir/metagpt_yusin/__init__.py b/notebook_dir/metagpt_yusin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5fab359249f276d27345b44354eb5740e4aac3cf --- /dev/null +++ b/notebook_dir/metagpt_yusin/__init__.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2023/4/24 22:26 +# @Author : alexanderwu +# @File : __init__.py + +from metagpt_yusin import _compat as _ # noqa: F401 diff --git a/notebook_dir/metagpt_yusin/__pycache__/__init__.cpython-39.pyc b/notebook_dir/metagpt_yusin/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b5352d15cc84c995ddcd94406a1388a26e6a136 Binary files /dev/null and b/notebook_dir/metagpt_yusin/__pycache__/__init__.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/__pycache__/_compat.cpython-39.pyc b/notebook_dir/metagpt_yusin/__pycache__/_compat.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98c5d5227c368b68034ee158e049bf87b3128e8f Binary files /dev/null and b/notebook_dir/metagpt_yusin/__pycache__/_compat.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/__pycache__/config2.cpython-39.pyc b/notebook_dir/metagpt_yusin/__pycache__/config2.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69145439f0cb7df61adb419ca9a234d287e925fa Binary files /dev/null and b/notebook_dir/metagpt_yusin/__pycache__/config2.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/__pycache__/const.cpython-39.pyc b/notebook_dir/metagpt_yusin/__pycache__/const.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3336dd12fa451071d57ce0ecc1790de44b6abcea Binary files /dev/null and b/notebook_dir/metagpt_yusin/__pycache__/const.cpython-39.pyc differ diff --git 
a/notebook_dir/metagpt_yusin/__pycache__/context.cpython-39.pyc b/notebook_dir/metagpt_yusin/__pycache__/context.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1ae58d9caed09c56dcc05f7df8a7839b296a284 Binary files /dev/null and b/notebook_dir/metagpt_yusin/__pycache__/context.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/__pycache__/context_mixin.cpython-39.pyc b/notebook_dir/metagpt_yusin/__pycache__/context_mixin.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03c0e57ecb60dc64d4b12b8383bbcbd5d0d654c7 Binary files /dev/null and b/notebook_dir/metagpt_yusin/__pycache__/context_mixin.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/__pycache__/geoagent.cpython-39.pyc b/notebook_dir/metagpt_yusin/__pycache__/geoagent.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca0b7274f52ba0c2378a52db8d3f583d51fd79ad Binary files /dev/null and b/notebook_dir/metagpt_yusin/__pycache__/geoagent.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/__pycache__/llm.cpython-39.pyc b/notebook_dir/metagpt_yusin/__pycache__/llm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eccf95ec062b1d087037faa57518bfe936425404 Binary files /dev/null and b/notebook_dir/metagpt_yusin/__pycache__/llm.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/__pycache__/logs.cpython-39.pyc b/notebook_dir/metagpt_yusin/__pycache__/logs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6153868993d22ccaa95914cd57d86109f989d098 Binary files /dev/null and b/notebook_dir/metagpt_yusin/__pycache__/logs.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/__pycache__/repo_parser.cpython-39.pyc b/notebook_dir/metagpt_yusin/__pycache__/repo_parser.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64ae4c3b6c8a6b59d34b675854cdb8b65851480b Binary files 
/dev/null and b/notebook_dir/metagpt_yusin/__pycache__/repo_parser.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/__pycache__/schema.cpython-39.pyc b/notebook_dir/metagpt_yusin/__pycache__/schema.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bd4d2b1d556f42204c42ec6fb828e20551873f4 Binary files /dev/null and b/notebook_dir/metagpt_yusin/__pycache__/schema.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/__pycache__/tasks.cpython-39.pyc b/notebook_dir/metagpt_yusin/__pycache__/tasks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7bb55e1d6d3041fb92e61b23fada189e47ecc0f Binary files /dev/null and b/notebook_dir/metagpt_yusin/__pycache__/tasks.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/_compat.py b/notebook_dir/metagpt_yusin/_compat.py new file mode 100644 index 0000000000000000000000000000000000000000..c442bd7ded67f56c5b76d27e0828702d5c6ced5b --- /dev/null +++ b/notebook_dir/metagpt_yusin/_compat.py @@ -0,0 +1,23 @@ +import platform +import sys +import warnings + +if sys.implementation.name == "cpython" and platform.system() == "Windows": + import asyncio + + if sys.version_info[:2] == (3, 9): + from asyncio.proactor_events import _ProactorBasePipeTransport + + # https://github.com/python/cpython/pull/92842 + def pacth_del(self, _warn=warnings.warn): + if self._sock is not None: + _warn(f"unclosed transport {self!r}", ResourceWarning, source=self) + self._sock.close() + + _ProactorBasePipeTransport.__del__ = pacth_del + + if sys.version_info >= (3, 9, 0): + from semantic_kernel.orchestration import sk_function as _ # noqa: F401 + + # caused by https://github.com/microsoft/semantic-kernel/pull/1416 + asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) diff --git a/notebook_dir/metagpt_yusin/actions/.ipynb_checkpoints/__init__-checkpoint.py b/notebook_dir/metagpt_yusin/actions/.ipynb_checkpoints/__init__-checkpoint.py new file 
mode 100644 index 0000000000000000000000000000000000000000..a15460b2d42de6a94d77a4d09adf4b7ed9404b0b --- /dev/null +++ b/notebook_dir/metagpt_yusin/actions/.ipynb_checkpoints/__init__-checkpoint.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/11 17:44 +@Author : alexanderwu +@File : __init__.py +""" +from enum import Enum + +from metagpt_yusin.actions.action import Action +from metagpt_yusin.actions.action_output import ActionOutput +from metagpt_yusin.actions.add_requirement import UserRequirement +from metagpt_yusin.actions.debug_error import DebugError +from metagpt_yusin.actions.design_api import WriteDesign +from metagpt_yusin.actions.design_api_review import DesignReview +from metagpt_yusin.actions.project_management import WriteTasks +from metagpt_yusin.actions.research import CollectLinks, WebBrowseAndSummarize, ConductResearch +from metagpt_yusin.actions.run_code import RunCode +from metagpt_yusin.actions.search_and_summarize import SearchAndSummarize +from metagpt_yusin.actions.write_code import WriteCode +from metagpt_yusin.actions.write_code_review import WriteCodeReview +from metagpt_yusin.actions.write_prd import WritePRD +from metagpt_yusin.actions.write_prd_review import WritePRDReview +from metagpt_yusin.actions.write_test import WriteTest +from metagpt_yusin.actions.di.execute_nb_code import ExecuteNbCode +from metagpt_yusin.actions.di.write_analysis_code import WriteAnalysisCode +from metagpt_yusin.actions.di.write_plan import WritePlan + + +class ActionType(Enum): + """All types of Actions, used for indexing.""" + + ADD_REQUIREMENT = UserRequirement + WRITE_PRD = WritePRD + WRITE_PRD_REVIEW = WritePRDReview + WRITE_DESIGN = WriteDesign + DESIGN_REVIEW = DesignReview + WRTIE_CODE = WriteCode + WRITE_CODE_REVIEW = WriteCodeReview + WRITE_TEST = WriteTest + RUN_CODE = RunCode + DEBUG_ERROR = DebugError + WRITE_TASKS = WriteTasks + SEARCH_AND_SUMMARIZE = SearchAndSummarize + COLLECT_LINKS = CollectLinks + 
WEB_BROWSE_AND_SUMMARIZE = WebBrowseAndSummarize + CONDUCT_RESEARCH = ConductResearch + EXECUTE_NB_CODE = ExecuteNbCode + WRITE_ANALYSIS_CODE = WriteAnalysisCode + WRITE_PLAN = WritePlan + + +__all__ = [ + "ActionType", + "Action", + "ActionOutput", +] diff --git a/notebook_dir/metagpt_yusin/actions/.ipynb_checkpoints/action-checkpoint.py b/notebook_dir/metagpt_yusin/actions/.ipynb_checkpoints/action-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..2fbaaa84cdf79f23845777be25644b580caf399f --- /dev/null +++ b/notebook_dir/metagpt_yusin/actions/.ipynb_checkpoints/action-checkpoint.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/11 14:43 +@Author : alexanderwu +@File : action.py +""" + +from __future__ import annotations + +from typing import Optional, Union + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from metagpt_yusin.actions.action_node import ActionNode +from metagpt_yusin.context_mixin import ContextMixin +from metagpt_yusin.schema import ( + CodePlanAndChangeContext, + CodeSummarizeContext, + CodingContext, + RunCodeContext, + SerializationMixin, + TestingContext, +) +from metagpt_yusin.utils.project_repo import ProjectRepo + + +class Action(SerializationMixin, ContextMixin, BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + name: str = "" + i_context: Union[ + dict, CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext, CodePlanAndChangeContext, str, None + ] = "" + prefix: str = "" # added as the system message when calling aask* + desc: str = "" # for skill manager + node: ActionNode = Field(default=None, exclude=True) + + @property + def repo(self) -> ProjectRepo: + if not self.context.repo: + self.context.repo = ProjectRepo(self.context.git_repo) + return self.context.repo + + @property + def prompt_schema(self): + return self.config.prompt_schema + + @property + def project_name(self): + return self.config.project_name + + 
@project_name.setter + def project_name(self, value): + self.config.project_name = value + + @property + def project_path(self): + return self.config.project_path + + @model_validator(mode="before") + @classmethod + def set_name_if_empty(cls, values): + if "name" not in values or not values["name"]: + values["name"] = cls.__name__ + return values + + @model_validator(mode="before") + @classmethod + def _init_with_instruction(cls, values): + if "instruction" in values: + name = values["name"] + i = values.pop("instruction") + values["node"] = ActionNode(key=name, expected_type=str, instruction=i, example="", schema="raw") + return values + + def set_prefix(self, prefix): + """Set prefix for later usage""" + self.prefix = prefix + self.llm.system_prompt = prefix + if self.node: + self.node.llm = self.llm + return self + + def __str__(self): + return self.__class__.__name__ + + def __repr__(self): + return self.__str__() + + async def _aask(self, prompt: str, system_msgs: Optional[list[str]] = None) -> str: + """Append default prefix""" + return await self.llm.aask(prompt, system_msgs) + + async def _run_action_node(self, *args, **kwargs): + """Run action node""" + msgs = args[0] + context = "## History Messages\n" + context += "\n".join([f"{idx}: {i}" for idx, i in enumerate(reversed(msgs))]) + return await self.node.fill(context=context, llm=self.llm) + + async def run(self, *args, **kwargs): + """Run action""" + if self.node: + return await self._run_action_node(*args, **kwargs) + raise NotImplementedError("The run method should be implemented in a subclass.") diff --git a/notebook_dir/metagpt_yusin/actions/.ipynb_checkpoints/execute_task-checkpoint.py b/notebook_dir/metagpt_yusin/actions/.ipynb_checkpoints/execute_task-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..51be4d144fbcc11cb794268d0048cc3072184dd5 --- /dev/null +++ b/notebook_dir/metagpt_yusin/actions/.ipynb_checkpoints/execute_task-checkpoint.py @@ -0,0 +1,19 @@ 
+#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/9/13 12:26 +@Author : femto Zheng +@File : execute_task.py +""" + + +from metagpt_yusin.actions import Action +from metagpt_yusin.schema import Message + + +class ExecuteTask(Action): + name: str = "ExecuteTask" + i_context: list[Message] = [] + + async def run(self, *args, **kwargs): + pass diff --git a/notebook_dir/metagpt_yusin/actions/.ipynb_checkpoints/generate_questions-checkpoint.py b/notebook_dir/metagpt_yusin/actions/.ipynb_checkpoints/generate_questions-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..fc25c1360f9cc48ef2f505a2cf33d67d9930bfa4 --- /dev/null +++ b/notebook_dir/metagpt_yusin/actions/.ipynb_checkpoints/generate_questions-checkpoint.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@File : generate_questions.py +""" +from metagpt_yusin.actions import Action +from metagpt_yusin.actions.action_node import ActionNode + +QUESTIONS = ActionNode( + key="Questions", + expected_type=list[str], + instruction="Task: Refer to the context to further inquire about the details that interest you, within a word limit" + " of 150 words. Please provide the specific details you would like to inquire about here", + example=["1. What ...", "2. How ...", "3. 
..."], +) + + +class GenerateQuestions(Action): + """This class allows LLM to further mine noteworthy details based on specific "##TOPIC"(discussion topic) and + "##RECORD" (discussion records), thereby deepening the discussion.""" + + name: str = "GenerateQuestions" + + async def run(self, context) -> ActionNode: + return await QUESTIONS.fill(context=context, llm=self.llm) diff --git a/notebook_dir/metagpt_yusin/actions/.ipynb_checkpoints/talk_action-checkpoint.py b/notebook_dir/metagpt_yusin/actions/.ipynb_checkpoints/talk_action-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..37044660caaaac0306a072af4048fa992becfe66 --- /dev/null +++ b/notebook_dir/metagpt_yusin/actions/.ipynb_checkpoints/talk_action-checkpoint.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/28 +@Author : mashenquan +@File : talk_action.py +@Desc : Act as it’s a talk +""" +from typing import Optional + +from metagpt_yusin.actions import Action +from metagpt_yusin.config2 import config +from metagpt_yusin.logs import logger +from metagpt_yusin.schema import Message + + +class TalkAction(Action): + i_context: str + history_summary: str = "" + knowledge: str = "" + rsp: Optional[Message] = None + + @property + def agent_description(self): + return self.context.kwargs.agent_description + + @property + def language(self): + return self.context.kwargs.language or config.language + + @property + def prompt(self): + prompt = "" + if self.agent_description: + prompt = ( + f"You are {self.agent_description}. Your responses should align with the role-play agreement, " + f"maintaining the character's persona and habits. 
When faced with unrelated questions, playfully " + f"decline to answer without revealing your AI nature to preserve the character's image.\n\n" + ) + prompt += f"Knowledge:\n{self.knowledge}\n\n" if self.knowledge else "" + prompt += f"{self.history_summary}\n\n" + prompt += ( + "If the information is insufficient, you can search in the historical conversation or knowledge above.\n" + ) + language = self.language + prompt += ( + f"Answer the following questions strictly in {language}, and the answers must follow the Markdown format.\n " + f"{self.i_context}" + ) + logger.debug(f"PROMPT: {prompt}") + return prompt + + @property + def prompt_gpt4(self): + kvs = { + "{role}": self.agent_description or "", + "{history}": self.history_summary or "", + "{knowledge}": self.knowledge or "", + "{language}": self.language, + "{ask}": self.i_context, + } + prompt = TalkActionPrompt.FORMATION_LOOSE + for k, v in kvs.items(): + prompt = prompt.replace(k, v) + logger.info(f"PROMPT: {prompt}") + return prompt + + # async def run_old(self, *args, **kwargs) -> ActionOutput: + # prompt = self.prompt + # rsp = await self.llm.aask(msg=prompt, system_msgs=[]) + # logger.debug(f"PROMPT:{prompt}\nRESULT:{rsp}\n") + # self._rsp = ActionOutput(content=rsp) + # return self._rsp + + @property + def aask_args(self): + language = self.language + system_msgs = [ + f"You are {self.agent_description}.", + "Your responses should align with the role-play agreement, " + "maintaining the character's persona and habits. 
When faced with unrelated questions, playfully " + "decline to answer without revealing your AI nature to preserve the character's image.", + "If the information is insufficient, you can search in the context or knowledge.", + f"Answer the following questions strictly in {language}, and the answers must follow the Markdown format.", + ] + format_msgs = [] + if self.knowledge: + format_msgs.append({"role": "assistant", "content": self.knowledge}) + if self.history_summary: + format_msgs.append({"role": "assistant", "content": self.history_summary}) + return self.i_context, format_msgs, system_msgs + + async def run(self, with_message=None, **kwargs) -> Message: + msg, format_msgs, system_msgs = self.aask_args + rsp = await self.llm.aask(msg=msg, format_msgs=format_msgs, system_msgs=system_msgs, stream=False) + self.rsp = Message(content=rsp, role="assistant", cause_by=self) + return self.rsp + + +class TalkActionPrompt: + FORMATION = """Formation: "Capacity and role" defines the role you are currently playing; + "[HISTORY_BEGIN]" and "[HISTORY_END]" tags enclose the historical conversation; + "[KNOWLEDGE_BEGIN]" and "[KNOWLEDGE_END]" tags enclose the knowledge may help for your responses; + "Statement" defines the work detail you need to complete at this stage; + "[ASK_BEGIN]" and [ASK_END] tags enclose the questions; + "Constraint" defines the conditions that your responses must comply with. + "Personality" defines your language style。 + "Insight" provides a deeper understanding of the characters' inner traits. + "Initial" defines the initial setup of a character. + +Capacity and role: {role} +Statement: Your responses should align with the role-play agreement, maintaining the + character's persona and habits. When faced with unrelated questions, playfully decline to answer without revealing + your AI nature to preserve the character's image. 
+ +[HISTORY_BEGIN] + +{history} + +[HISTORY_END] + +[KNOWLEDGE_BEGIN] + +{knowledge} + +[KNOWLEDGE_END] + +Statement: If the information is insufficient, you can search in the historical conversation or knowledge. +Statement: Unless you are a language professional, answer the following questions strictly in {language} +, and the answers must follow the Markdown format. Strictly excluding any tag likes "[HISTORY_BEGIN]" +, "[HISTORY_END]", "[KNOWLEDGE_BEGIN]", "[KNOWLEDGE_END]" in responses. + + +{ask} +""" + + FORMATION_LOOSE = """Formation: "Capacity and role" defines the role you are currently playing; + "[HISTORY_BEGIN]" and "[HISTORY_END]" tags enclose the historical conversation; + "[KNOWLEDGE_BEGIN]" and "[KNOWLEDGE_END]" tags enclose the knowledge may help for your responses; + "Statement" defines the work detail you need to complete at this stage; + "Constraint" defines the conditions that your responses must comply with. + "Personality" defines your language style。 + "Insight" provides a deeper understanding of the characters' inner traits. + "Initial" defines the initial setup of a character. + +Capacity and role: {role} +Statement: Your responses should maintaining the character's persona and habits. When faced with unrelated questions +, playfully decline to answer without revealing your AI nature to preserve the character's image. + +[HISTORY_BEGIN] + +{history} + +[HISTORY_END] + +[KNOWLEDGE_BEGIN] + +{knowledge} + +[KNOWLEDGE_END] + +Statement: If the information is insufficient, you can search in the historical conversation or knowledge. +Statement: Unless you are a language professional, answer the following questions strictly in {language} +, and the answers must follow the Markdown format. Strictly excluding any tag likes "[HISTORY_BEGIN]" +, "[HISTORY_END]", "[KNOWLEDGE_BEGIN]", "[KNOWLEDGE_END]" in responses. 
+ + +{ask} +""" diff --git a/notebook_dir/metagpt_yusin/actions/__init__.py b/notebook_dir/metagpt_yusin/actions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a15460b2d42de6a94d77a4d09adf4b7ed9404b0b --- /dev/null +++ b/notebook_dir/metagpt_yusin/actions/__init__.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/11 17:44 +@Author : alexanderwu +@File : __init__.py +""" +from enum import Enum + +from metagpt_yusin.actions.action import Action +from metagpt_yusin.actions.action_output import ActionOutput +from metagpt_yusin.actions.add_requirement import UserRequirement +from metagpt_yusin.actions.debug_error import DebugError +from metagpt_yusin.actions.design_api import WriteDesign +from metagpt_yusin.actions.design_api_review import DesignReview +from metagpt_yusin.actions.project_management import WriteTasks +from metagpt_yusin.actions.research import CollectLinks, WebBrowseAndSummarize, ConductResearch +from metagpt_yusin.actions.run_code import RunCode +from metagpt_yusin.actions.search_and_summarize import SearchAndSummarize +from metagpt_yusin.actions.write_code import WriteCode +from metagpt_yusin.actions.write_code_review import WriteCodeReview +from metagpt_yusin.actions.write_prd import WritePRD +from metagpt_yusin.actions.write_prd_review import WritePRDReview +from metagpt_yusin.actions.write_test import WriteTest +from metagpt_yusin.actions.di.execute_nb_code import ExecuteNbCode +from metagpt_yusin.actions.di.write_analysis_code import WriteAnalysisCode +from metagpt_yusin.actions.di.write_plan import WritePlan + + +class ActionType(Enum): + """All types of Actions, used for indexing.""" + + ADD_REQUIREMENT = UserRequirement + WRITE_PRD = WritePRD + WRITE_PRD_REVIEW = WritePRDReview + WRITE_DESIGN = WriteDesign + DESIGN_REVIEW = DesignReview + WRTIE_CODE = WriteCode + WRITE_CODE_REVIEW = WriteCodeReview + WRITE_TEST = WriteTest + RUN_CODE = RunCode + DEBUG_ERROR = DebugError + 
WRITE_TASKS = WriteTasks + SEARCH_AND_SUMMARIZE = SearchAndSummarize + COLLECT_LINKS = CollectLinks + WEB_BROWSE_AND_SUMMARIZE = WebBrowseAndSummarize + CONDUCT_RESEARCH = ConductResearch + EXECUTE_NB_CODE = ExecuteNbCode + WRITE_ANALYSIS_CODE = WriteAnalysisCode + WRITE_PLAN = WritePlan + + +__all__ = [ + "ActionType", + "Action", + "ActionOutput", +] diff --git a/notebook_dir/metagpt_yusin/actions/__pycache__/__init__.cpython-39.pyc b/notebook_dir/metagpt_yusin/actions/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80a8b6291b2474e12fa88bdaa6f0df5cf39b0ec9 Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/__pycache__/__init__.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/__pycache__/action.cpython-39.pyc b/notebook_dir/metagpt_yusin/actions/__pycache__/action.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a48b38b00fd7f321108324a1cd8236aeecf7084 Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/__pycache__/action.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/__pycache__/action_node.cpython-39.pyc b/notebook_dir/metagpt_yusin/actions/__pycache__/action_node.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9ef89ab21d10e660b7c4730f6ff9f7bda9e32e9 Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/__pycache__/action_node.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/__pycache__/action_outcls_registry.cpython-39.pyc b/notebook_dir/metagpt_yusin/actions/__pycache__/action_outcls_registry.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce486a3f2aed8cad9c11e2ca6787342e49a69094 Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/__pycache__/action_outcls_registry.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/__pycache__/action_output.cpython-39.pyc 
b/notebook_dir/metagpt_yusin/actions/__pycache__/action_output.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb61233a67e3a3b006886ec23efb1f77e1867087 Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/__pycache__/action_output.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/__pycache__/add_requirement.cpython-39.pyc b/notebook_dir/metagpt_yusin/actions/__pycache__/add_requirement.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f21b3757bcbb887af194043759e3c2eb53c775f Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/__pycache__/add_requirement.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/__pycache__/debug_error.cpython-39.pyc b/notebook_dir/metagpt_yusin/actions/__pycache__/debug_error.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9bfdedcc5c9bd87657fb3781941917a5fb03173 Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/__pycache__/debug_error.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/__pycache__/design_api.cpython-39.pyc b/notebook_dir/metagpt_yusin/actions/__pycache__/design_api.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..332ae247a0e4033bcbafa81712297e374f89b7cf Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/__pycache__/design_api.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/__pycache__/design_api_an.cpython-39.pyc b/notebook_dir/metagpt_yusin/actions/__pycache__/design_api_an.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c7d481058e5d7a34611d9bc51f69715fc913115 Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/__pycache__/design_api_an.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/__pycache__/design_api_review.cpython-39.pyc 
b/notebook_dir/metagpt_yusin/actions/__pycache__/design_api_review.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b245199e09fe0689ff72db6b996f0451156fbd4 Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/__pycache__/design_api_review.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/__pycache__/fix_bug.cpython-39.pyc b/notebook_dir/metagpt_yusin/actions/__pycache__/fix_bug.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc94372cae07d962b430ef714ed1517d748b502b Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/__pycache__/fix_bug.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/__pycache__/prepare_documents.cpython-39.pyc b/notebook_dir/metagpt_yusin/actions/__pycache__/prepare_documents.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f5525269a6544c0c5056ea2cc668975362bb81e Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/__pycache__/prepare_documents.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/__pycache__/project_management.cpython-39.pyc b/notebook_dir/metagpt_yusin/actions/__pycache__/project_management.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bdd0499f1a186852fd344cc2fb8adaba073d0028 Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/__pycache__/project_management.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/__pycache__/project_management_an.cpython-39.pyc b/notebook_dir/metagpt_yusin/actions/__pycache__/project_management_an.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..560b43bedba7bec3c511dce3dec82c5df6b55dc6 Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/__pycache__/project_management_an.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/__pycache__/research.cpython-39.pyc 
b/notebook_dir/metagpt_yusin/actions/__pycache__/research.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc927801829add29d70b1b5e6b2ab8b930bcd97d Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/__pycache__/research.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/__pycache__/run_code.cpython-39.pyc b/notebook_dir/metagpt_yusin/actions/__pycache__/run_code.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94aa52e6a750abdfc6212df884de757165aa5f2f Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/__pycache__/run_code.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/__pycache__/search_and_summarize.cpython-39.pyc b/notebook_dir/metagpt_yusin/actions/__pycache__/search_and_summarize.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5ca4fb78a7063fecf217d75a7bcedc9700aea72 Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/__pycache__/search_and_summarize.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/__pycache__/summarize_code.cpython-39.pyc b/notebook_dir/metagpt_yusin/actions/__pycache__/summarize_code.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8333971f25bf32e669d6356a9632bd20c61d7a00 Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/__pycache__/summarize_code.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/__pycache__/write_code.cpython-39.pyc b/notebook_dir/metagpt_yusin/actions/__pycache__/write_code.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d31d2b63cf42942224f35d9e83973c7d78effafc Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/__pycache__/write_code.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/__pycache__/write_code_plan_and_change_an.cpython-39.pyc 
b/notebook_dir/metagpt_yusin/actions/__pycache__/write_code_plan_and_change_an.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c467d66e581221debf399a5df1f11418a46b0c1 Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/__pycache__/write_code_plan_and_change_an.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/__pycache__/write_code_review.cpython-39.pyc b/notebook_dir/metagpt_yusin/actions/__pycache__/write_code_review.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..379361950778951860715391762f3563be701a2e Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/__pycache__/write_code_review.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/__pycache__/write_prd.cpython-39.pyc b/notebook_dir/metagpt_yusin/actions/__pycache__/write_prd.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e77895f42a926eb4b727b9885808cb51886bd8f6 Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/__pycache__/write_prd.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/__pycache__/write_prd_an.cpython-39.pyc b/notebook_dir/metagpt_yusin/actions/__pycache__/write_prd_an.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95159b49d175cddfcf990d05629de6fe026ae423 Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/__pycache__/write_prd_an.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/__pycache__/write_prd_review.cpython-39.pyc b/notebook_dir/metagpt_yusin/actions/__pycache__/write_prd_review.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88e0625751c6dd7af7e0edf9227c32471b6a07d1 Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/__pycache__/write_prd_review.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/__pycache__/write_test.cpython-39.pyc 
b/notebook_dir/metagpt_yusin/actions/__pycache__/write_test.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff1e0dca38cbb33e89bf10d855895d001deeea62 Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/__pycache__/write_test.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/action.py b/notebook_dir/metagpt_yusin/actions/action.py new file mode 100644 index 0000000000000000000000000000000000000000..2fbaaa84cdf79f23845777be25644b580caf399f --- /dev/null +++ b/notebook_dir/metagpt_yusin/actions/action.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/11 14:43 +@Author : alexanderwu +@File : action.py +""" + +from __future__ import annotations + +from typing import Optional, Union + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from metagpt_yusin.actions.action_node import ActionNode +from metagpt_yusin.context_mixin import ContextMixin +from metagpt_yusin.schema import ( + CodePlanAndChangeContext, + CodeSummarizeContext, + CodingContext, + RunCodeContext, + SerializationMixin, + TestingContext, +) +from metagpt_yusin.utils.project_repo import ProjectRepo + + +class Action(SerializationMixin, ContextMixin, BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + name: str = "" + i_context: Union[ + dict, CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext, CodePlanAndChangeContext, str, None + ] = "" + prefix: str = "" # added as the system message when calling aask* + desc: str = "" # for skill manager + node: ActionNode = Field(default=None, exclude=True) + + @property + def repo(self) -> ProjectRepo: + if not self.context.repo: + self.context.repo = ProjectRepo(self.context.git_repo) + return self.context.repo + + @property + def prompt_schema(self): + return self.config.prompt_schema + + @property + def project_name(self): + return self.config.project_name + + @project_name.setter + def project_name(self, value): + 
self.config.project_name = value + + @property + def project_path(self): + return self.config.project_path + + @model_validator(mode="before") + @classmethod + def set_name_if_empty(cls, values): + if "name" not in values or not values["name"]: + values["name"] = cls.__name__ + return values + + @model_validator(mode="before") + @classmethod + def _init_with_instruction(cls, values): + if "instruction" in values: + name = values["name"] + i = values.pop("instruction") + values["node"] = ActionNode(key=name, expected_type=str, instruction=i, example="", schema="raw") + return values + + def set_prefix(self, prefix): + """Set prefix for later usage""" + self.prefix = prefix + self.llm.system_prompt = prefix + if self.node: + self.node.llm = self.llm + return self + + def __str__(self): + return self.__class__.__name__ + + def __repr__(self): + return self.__str__() + + async def _aask(self, prompt: str, system_msgs: Optional[list[str]] = None) -> str: + """Append default prefix""" + return await self.llm.aask(prompt, system_msgs) + + async def _run_action_node(self, *args, **kwargs): + """Run action node""" + msgs = args[0] + context = "## History Messages\n" + context += "\n".join([f"{idx}: {i}" for idx, i in enumerate(reversed(msgs))]) + return await self.node.fill(context=context, llm=self.llm) + + async def run(self, *args, **kwargs): + """Run action""" + if self.node: + return await self._run_action_node(*args, **kwargs) + raise NotImplementedError("The run method should be implemented in a subclass.") diff --git a/notebook_dir/metagpt_yusin/actions/action_graph.py b/notebook_dir/metagpt_yusin/actions/action_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..893bc6d4c27c5b619b8a86797bba8bc4927879ea --- /dev/null +++ b/notebook_dir/metagpt_yusin/actions/action_graph.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/1/30 13:52 +@Author : alexanderwu +@File : action_graph.py +""" +from __future__ 
class ActionGraph:
    """ActionGraph: a directed graph to represent the dependency between actions.

    Attributes:
        nodes: maps a node key to the node object.
        edges: maps a node key to the list of keys it points to.
        execution_order: keys in the order produced by the last `topological_sort` call.
    """

    def __init__(self):
        self.nodes = {}
        self.edges = {}
        self.execution_order = []

    def add_node(self, node):
        """Add a node to the graph, indexed by ``node.key``."""
        self.nodes[node.key] = node

    def add_edge(self, from_node: "ActionNode", to_node: "ActionNode"):
        """Add an edge ``from_node -> to_node`` and link the two nodes to each other."""
        self.edges.setdefault(from_node.key, []).append(to_node.key)
        from_node.add_next(to_node)
        to_node.add_prev(from_node)

    def topological_sort(self):
        """Topologically sort the graph and store the result in ``execution_order``.

        A depth-first walk collects keys in post-order; reversing that list once
        yields the same order as inserting every key at the front, without the
        per-insert list shift.
        """
        visited = set()
        post_order = []

        def visit(k):
            if k in visited:
                return
            visited.add(k)
            for next_key in self.edges.get(k, ()):
                visit(next_key)
            post_order.append(k)

        for key in self.nodes:
            visit(key)

        post_order.reverse()
        self.execution_order = post_order
class ReviewMode(Enum):
    """Selects who produces the review comments for a node."""

    HUMAN = "human"
    AUTO = "auto"


class ReviseMode(Enum):
    """Selects how a node's content is reviewed and revised."""

    HUMAN = "human"  # human revise
    HUMAN_REVIEW = "human_review"  # human-review and auto-revise
    AUTO = "auto"  # auto-review and auto-revise
def dict_to_markdown(d, prefix="- ", kv_sep="\n", postfix="\n"):
    """Render a mapping as markdown-style text.

    Each item becomes ``{prefix}{key}{kv_sep}{value}{postfix}``; items are
    concatenated in the mapping's iteration order. An empty mapping gives ``""``.
    """
    # Join once instead of growing a string with `+=` for every item.
    return "".join(f"{prefix}{key}{kv_sep}{value}{postfix}" for key, value in d.items())
+ + # Action Output + content: str + instruct_content: BaseModel + + # For ActionGraph + prevs: List["ActionNode"] # previous nodes + nexts: List["ActionNode"] # next nodes + + def __init__( + self, + key: str, + expected_type: Type, + instruction: str, + example: Any, + content: str = "", + children: dict[str, "ActionNode"] = None, + schema: str = "", + ): + self.key = key + self.expected_type = expected_type + self.instruction = instruction + self.example = example + self.content = content + self.children = children if children is not None else {} + self.schema = schema + self.prevs = [] + self.nexts = [] + + def __str__(self): + return ( + f"{self.key}, {repr(self.expected_type)}, {self.instruction}, {self.example}" + f", {self.content}, {self.children}" + ) + + def __repr__(self): + return self.__str__() + + def add_prev(self, node: "ActionNode"): + """增加前置ActionNode""" + self.prevs.append(node) + + def add_next(self, node: "ActionNode"): + """增加后置ActionNode""" + self.nexts.append(node) + + def add_child(self, node: "ActionNode"): + """增加子ActionNode""" + self.children[node.key] = node + + def get_child(self, key: str) -> Union["ActionNode", None]: + return self.children.get(key, None) + + def add_children(self, nodes: List["ActionNode"]): + """批量增加子ActionNode""" + for node in nodes: + self.add_child(node) + + @classmethod + def from_children(cls, key, nodes: List["ActionNode"]): + """直接从一系列的子nodes初始化""" + obj = cls(key, str, "", "") + obj.add_children(nodes) + return obj + + def _get_children_mapping(self, exclude=None) -> Dict[str, Any]: + """获得子ActionNode的字典,以key索引,支持多级结构。""" + exclude = exclude or [] + + def _get_mapping(node: "ActionNode") -> Dict[str, Any]: + mapping = {} + for key, child in node.children.items(): + if key in exclude: + continue + # 对于嵌套的子节点,递归调用 _get_mapping + if child.children: + mapping[key] = _get_mapping(child) + else: + mapping[key] = (child.expected_type, Field(default=child.example, description=child.instruction)) + return mapping 
def get_mapping(self, mode="children", exclude=None) -> Dict[str, Tuple[Type, Any]]:
    """Return the key -> type mapping selected by ``mode``.

    ``"children"`` always uses the children mapping; ``"auto"`` uses it only when
    the node has children. Otherwise the node's own mapping is returned, or an
    empty dict when the node's key is listed in ``exclude``.
    """
    wants_children = mode == "children" or (mode == "auto" and bool(self.children))
    if wants_children:
        return self._get_children_mapping(exclude=exclude)
    if exclude and self.key in exclude:
        return {}
    return self._get_self_mapping()
+ else: + new_fields[field_name] = field_value + + new_class = create_model(class_name, __validators__=validators, **new_fields) + return new_class + + def create_class(self, mode: str = "auto", class_name: str = None, exclude=None): + class_name = class_name if class_name else f"{self.key}_AN" + mapping = self.get_mapping(mode=mode, exclude=exclude) + return self.create_model_class(class_name, mapping) + + def _create_children_class(self, exclude=None): + """使用object内有的字段直接生成model_class""" + class_name = f"{self.key}_AN" + mapping = self._get_children_mapping(exclude=exclude) + return self.create_model_class(class_name, mapping) + + def to_dict(self, format_func=None, mode="auto", exclude=None) -> Dict: + """将当前节点与子节点都按照node: format的格式组织成字典""" + nodes = self._to_dict(format_func=format_func, mode=mode, exclude=exclude) + if not isinstance(nodes, dict): + nodes = {self.key: nodes} + return nodes + + def _to_dict(self, format_func=None, mode="auto", exclude=None) -> Dict: + """将当前节点与子节点都按照node: format的格式组织成字典""" + + # 如果没有提供格式化函数,则使用默认的格式化函数 + if format_func is None: + format_func = lambda node: node.instruction + + # 使用提供的格式化函数来格式化当前节点的值 + formatted_value = format_func(self) + + # 创建当前节点的键值对 + if (mode == "children" or mode == "auto") and self.children: + node_value = {} + else: + node_value = formatted_value + + if mode == "root": + return {self.key: node_value} + + # 递归处理子节点 + exclude = exclude or [] + for child_key, child_node in self.children.items(): + if child_key in exclude: + continue + # 递归调用 to_dict 方法并更新节点字典 + child_dict = child_node._to_dict(format_func, mode, exclude) + node_value[child_key] = child_dict + + return node_value + + def update_instruct_content(self, incre_data: dict[str, Any]): + assert self.instruct_content + origin_sc_dict = self.instruct_content.model_dump() + origin_sc_dict.update(incre_data) + output_class = self.create_class() + self.instruct_content = output_class(**origin_sc_dict) + + def keys(self, mode: str = "auto") -> list: + 
def compile(self, context, schema="json", mode="children", template=SIMPLE_TEMPLATE, exclude=None) -> str:
    """Build the prompt for this node.

    :param context: text placed in the ``context`` slot of the template.
    :param schema: raw/json/markdown
        - raw: no template; returns the context, the language constraint and this node's instruction
        - json: compiles context, example (json), instruction (markdown), constraint, action
        - markdown: compiles context, example (markdown), instruction (markdown), constraint, action
    :param mode: all/root/children
        - children: compile all child nodes into one unified template, including instruction and example
        - all: NotImplemented
        - root: NotImplemented
    :param template: format string with ``context``, ``example``, ``instruction`` and ``constraint`` fields.
    :param exclude: keys to leave out; ``None`` (the default) excludes nothing. A ``None`` default
        replaces the former shared ``[]`` default so no list object is reused across calls.
    :return: the formatted prompt.
    """
    if schema == "raw":
        return f"{context}\n\n## Actions\n{LANGUAGE_CONSTRAINT}\n{self.instruction}"

    # FIXME: a json instruction causes formatting problems, e.g.
    # "Project name": "web_2048 # use underscores in the project name",
    # so the instruction is always compiled as markdown.
    # Original author note: compile_example does not support markdown for now.
    instruction = self.compile_instruction(schema="markdown", mode=mode, exclude=exclude)
    example = self.compile_example(schema=schema, tag=TAG, mode=mode, exclude=exclude)
    constraint = "\n".join([LANGUAGE_CONSTRAINT, FORMAT_CONSTRAINT])

    return template.format(
        context=context,
        example=example,
        instruction=instruction,
        constraint=constraint,
    )
async def fill(
    self,
    context,
    llm,
    schema="json",
    mode="auto",
    strgy="simple",
    images: Optional[Union[str, list[str]]] = None,
    timeout=USE_CONFIG_TIMEOUT,
    exclude=None,
):
    """Fill the node(s) with mode.

    :param context: Everything we should know when filling node.
    :param llm: Large Language Model with pre-defined system message.
    :param schema: json/markdown, determine example and output format.
     - raw: free form text
     - json: it's easy to open source LLM with json format
     - markdown: when generating code, markdown is always better
     A non-empty ``self.schema`` overrides this argument.
    :param mode: auto/children/root
     - auto: automated fill children's nodes and gather outputs, if no children, fill itself
     - children: fill children's nodes and gather outputs
     - root: fill root's node and gather output
    :param strgy: simple/complex
     - simple: run only once
     - complex: run each node
    :param images: the list of image url or base64 for gpt4-v
    :param timeout: Timeout for llm invocation.
    :param exclude: The keys of ActionNode to exclude; ``None`` (the default) excludes nothing.
    :return: self
    :raises ValueError: if ``strgy`` is neither ``"simple"`` nor ``"complex"``.
    """
    self.set_llm(llm)
    self.set_context(context)
    if self.schema:
        schema = self.schema

    if strgy == "simple":
        return await self.simple_fill(schema=schema, mode=mode, images=images, timeout=timeout, exclude=exclude)

    if strgy == "complex":
        # This branch implicitly assumes the node has children.
        tmp = {}
        for child_node in self.children.values():
            if exclude and child_node.key in exclude:
                continue
            child = await child_node.simple_fill(
                schema=schema, mode=mode, images=images, timeout=timeout, exclude=exclude
            )
            tmp.update(child.instruct_content.model_dump())
        # Build the output class without the excluded children; otherwise its
        # missing-field check rejects the result for every key skipped above.
        cls = self._create_children_class(exclude=exclude)
        self.instruct_content = cls(**tmp)
        return self

    # Previously an unknown strategy fell through and returned None.
    raise ValueError(f"Unknown fill strategy {strgy!r}; expected 'simple' or 'complex'")
async def review(self, strgy: str = "simple", review_mode: ReviewMode = ReviewMode.AUTO):
    """only give the review comment of each exist and mismatch key

    :param strgy: simple/complex
     - simple: run only once
     - complex: run each node
    :param review_mode: who produces the comments, see `ReviewMode`.
    :return: dict of key -> review comment.
    :raises RuntimeError: if the node has no ``llm`` attribute yet (``fill`` was not called).
    :raises ValueError: if ``strgy`` is neither ``"simple"`` nor ``"complex"``.
    """
    if not hasattr(self, "llm"):
        raise RuntimeError("use `review` after `fill`")
    assert review_mode in ReviewMode
    assert self.instruct_content, 'review only support with `schema != "raw"`'

    if strgy == "simple":
        review_comments = await self.simple_review(review_mode)
    elif strgy == "complex":
        # review each child node one-by-one
        review_comments = {}
        for child in self.children.values():
            child_review_comment = await child.simple_review(review_mode)
            review_comments.update(child_review_comment)
    else:
        # Previously this reached `return` with the variable unbound (UnboundLocalError).
        raise ValueError(f"Unknown review strategy {strgy!r}; expected 'simple' or 'complex'")

    return review_comments
async def revise(self, strgy: str = "simple", revise_mode: ReviseMode = ReviseMode.AUTO) -> dict[str, str]:
    """revise the content of ActionNode and update the instruct_content

    :param strgy: simple/complex
     - simple: run only once
     - complex: run each node
    :param revise_mode: how the content is reviewed and revised, see `ReviseMode`.
    :return: dict of key -> revised value.
    :raises RuntimeError: if the node has no ``llm`` attribute yet (``fill`` was not called).
    :raises ValueError: if ``strgy`` is neither ``"simple"`` nor ``"complex"``.
    """
    if not hasattr(self, "llm"):
        raise RuntimeError("use `revise` after `fill`")
    assert revise_mode in ReviseMode
    assert self.instruct_content, 'revise only support with `schema != "raw"`'

    if strgy == "simple":
        revise_contents = await self.simple_revise(revise_mode)
    elif strgy == "complex":
        # revise each child node one-by-one
        revise_contents = {}
        for child in self.children.values():
            child_revise_content = await child.simple_revise(revise_mode)
            revise_contents.update(child_revise_content)
        # Fold the children's revised values back into this node's own instruct_content.
        self.update_instruct_content(revise_contents)
    else:
        # Previously this reached `return` with the variable unbound (UnboundLocalError).
        raise ValueError(f"Unknown revise strategy {strgy!r}; expected 'simple' or 'complex'")

    return revise_contents
action_outcls_registry = dict()


def register_action_outcls(func):
    """
    Due to `create_model` return different Class even they have same class name and mapping.
    In order to do a comparison, use outcls_id to identify same Class with same class name and field definition

    The id is the "_"-joined string form of the call arguments, taken in the order of
    ``func``'s signature, so positional and keyword spellings of the same call share one
    cache entry. Dict arguments are key-sorted (recursively) before being stringified.
    """
    import inspect  # local import so the block does not depend on the file's import list

    signature = inspect.signature(func)

    def _sorted_copy(item):
        """Return dicts rebuilt in key order (recursively); other values unchanged."""
        if isinstance(item, dict):
            return {k: _sorted_copy(v) for k, v in sorted(item.items())}
        return item

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            bound = signature.bind(*args, **kwargs)
        except TypeError:
            # Malformed call: let the wrapped function raise its own error.
            return func(*args, **kwargs)

        # arr example: [<class ...>, 'test', {'field': (str, Ellipsis)}]
        arr = [_sorted_copy(item) for item in bound.arguments.values()]
        # outcls_id example: "<class ...>_test_{'field': (str, Ellipsis)}"
        outcls_id = "_".join(str(i) for i in arr)
        # eliminate typing influence
        outcls_id = outcls_id.replace("typing.List", "list").replace("typing.Dict", "dict")

        if outcls_id in action_outcls_registry:
            return action_outcls_registry[outcls_id]

        out_cls = func(*args, **kwargs)
        action_outcls_registry[outcls_id] = out_cls
        return out_cls

    return wrapper
+ +class UserRequirement(Action): + """User Requirement without any implementation details""" diff --git a/notebook_dir/metagpt_yusin/actions/debug_error.py b/notebook_dir/metagpt_yusin/actions/debug_error.py new file mode 100644 index 0000000000000000000000000000000000000000..faf628264dae13ae4ac3b9b846822cd66cab0340 --- /dev/null +++ b/notebook_dir/metagpt_yusin/actions/debug_error.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/11 17:46 +@Author : alexanderwu +@File : debug_error.py +@Modified By: mashenquan, 2023/11/27. + 1. Divide the context into three components: legacy code, unit test code, and console log. + 2. According to Section 2.2.3.1 of RFC 135, replace file data in the message with the file name. +""" +import re + +from pydantic import Field + +from metagpt_yusin.actions.action import Action +from metagpt_yusin.logs import logger +from metagpt_yusin.schema import RunCodeContext, RunCodeResult +from metagpt_yusin.utils.common import CodeParser + +PROMPT_TEMPLATE = """ +NOTICE +1. Role: You are a Development Engineer or QA engineer; +2. Task: You received this message from another Development Engineer or QA engineer who ran or tested your code. +Based on the message, first, figure out your own role, i.e. Engineer or QaEngineer, +then rewrite the development code or the test code based on your role, the error, and the summary, such that all bugs are fixed and the code performs well. +Attention: Use '##' to split sections, not '#', and '## ' SHOULD WRITE BEFORE the test case or script and triple quotes. +The message is as follows: +# Legacy Code +```python +{code} +``` +--- +# Unit Test Code +```python +{test_code} +``` +--- +# Console logs +```text +{logs} +``` +--- +Now you should start rewriting the code: +## file name of the code to rewrite: Write code with triple quote. Do your best to implement THIS IN ONLY ONE FILE. 
class DebugError(Action):
    """Ask the LLM to rewrite code from the console log of a test run."""

    i_context: RunCodeContext = Field(default_factory=RunCodeContext)

    async def run(self, *args, **kwargs) -> str:
        """Return rewritten code parsed from the LLM reply, or ``""`` when there is nothing to do.

        ``""`` is returned when the test output, the source file or the test file
        is missing, or when the recorded stderr already reports a passing run.
        """
        recorded_output = await self.repo.test_outputs.get(filename=self.i_context.output_filename)
        if not recorded_output:
            return ""

        run_result = RunCodeResult.loads(recorded_output.content)
        # A unittest-style "Ran N tests in Xs ... OK" summary means nothing needs fixing.
        if re.search(r"Ran (\d+) tests in ([\d.]+)s\n\nOK", run_result.stderr):
            return ""

        logger.info(f"Debug and rewrite {self.i_context.test_filename}")
        source_repo = self.repo.with_src_path(self.context.src_workspace).srcs
        source_doc = await source_repo.get(filename=self.i_context.code_filename)
        if not source_doc:
            return ""

        unit_test_doc = await self.repo.tests.get(filename=self.i_context.test_filename)
        if not unit_test_doc:
            return ""

        prompt = PROMPT_TEMPLATE.format(
            code=source_doc.content, test_code=unit_test_doc.content, logs=run_result.stderr
        )
        reply = await self._aask(prompt)
        return CodeParser.parse_code(block="", text=reply)
class WriteDesign(Action):
    """Create or update system design documents for changed PRDs and designs."""

    name: str = ""
    i_context: Optional[str] = None
    desc: str = (
        "Based on the PRD, think about the system design, and design the corresponding APIs, "
        "data structures, library tables, processes, and paths. Please provide your design, feedback "
        "clearly and in detail."
    )

    async def run(self, with_messages: Message, schema: str = None):
        """Regenerate the design for every changed PRD or design file.

        ``with_messages`` and ``schema`` are not read in this method.

        :return: an `ActionOutput` whose content is the JSON dump of the regenerated documents.
        """
        # Use `git status` to identify which PRD documents have been modified in the `docs/prd` directory.
        changed_prds = self.repo.docs.prd.changed_files
        # Use `git status` to identify which design documents in the `docs/system_designs` directory have undergone
        # changes.
        changed_system_designs = self.repo.docs.system_design.changed_files

        # For those PRDs and design documents that have undergone changes, regenerate the design content.
        changed_files = Documents()
        for filename in changed_prds.keys():
            doc = await self._update_system_design(filename=filename)
            changed_files.docs[filename] = doc

        for filename in changed_system_designs.keys():
            # Already regenerated through the PRD loop above.
            if filename in changed_files.docs:
                continue
            doc = await self._update_system_design(filename=filename)
            changed_files.docs[filename] = doc
        if not changed_files.docs:
            logger.info("Nothing has changed.")
        # Wait until all files under `docs/system_designs/` are processed before sending the publish message,
        # leaving room for global optimization in subsequent steps.
        return ActionOutput(content=changed_files.model_dump_json(), instruct_content=changed_files)

    async def _new_system_design(self, context):
        """Fill `DESIGN_API_NODE` from ``context`` and return the filled node."""
        node = await DESIGN_API_NODE.fill(context=context, llm=self.llm)
        return node

    async def _merge(self, prd_doc, system_design_doc):
        """Refine an existing design with the PRD content; updates and returns ``system_design_doc``."""
        context = NEW_REQ_TEMPLATE.format(old_design=system_design_doc.content, context=prd_doc.content)
        node = await REFINED_DESIGN_NODE.fill(context=context, llm=self.llm)
        system_design_doc.content = node.instruct_content.model_dump_json()
        return system_design_doc

    async def _update_system_design(self, filename) -> Document:
        """Create or merge the design for ``filename``, then save its derived artifacts.

        NOTE(review): ``prd`` is dereferenced without a check; presumably a PRD with the same
        filename always exists - confirm.
        """
        prd = await self.repo.docs.prd.get(filename)
        old_system_design_doc = await self.repo.docs.system_design.get(filename)
        if not old_system_design_doc:
            # No previous design: generate one from the PRD alone.
            system_design = await self._new_system_design(context=prd.content)
            doc = await self.repo.docs.system_design.save(
                filename=filename,
                content=system_design.instruct_content.model_dump_json(),
                dependencies={prd.root_relative_path},
            )
        else:
            doc = await self._merge(prd_doc=prd, system_design_doc=old_system_design_doc)
            await self.repo.docs.system_design.save_doc(doc=doc, dependencies={prd.root_relative_path})
        await self._save_data_api_design(doc)
        await self._save_seq_flow(doc)
        await self.repo.resources.system_design.save_pdf(doc=doc)
        return doc

    async def _save_data_api_design(self, design_doc):
        """Save the data-structure/interface diagram of the design, if the design has one."""
        m = json.loads(design_doc.content)
        # The original key is tried first, then the refined one.
        data_api_design = m.get(DATA_STRUCTURES_AND_INTERFACES.key) or m.get(REFINED_DATA_STRUCTURES_AND_INTERFACES.key)
        if not data_api_design:
            return
        pathname = self.repo.workdir / DATA_API_DESIGN_FILE_REPO / Path(design_doc.filename).with_suffix("")
        await self._save_mermaid_file(data_api_design, pathname)
        logger.info(f"Save class view to {str(pathname)}")

    async def _save_seq_flow(self, design_doc):
        """Save the program call flow diagram of the design, if the design has one."""
        m = json.loads(design_doc.content)
        # The original key is tried first, then the refined one.
        seq_flow = m.get(PROGRAM_CALL_FLOW.key) or m.get(REFINED_PROGRAM_CALL_FLOW.key)
        if not seq_flow:
            return
        pathname = self.repo.workdir / Path(SEQ_FLOW_FILE_REPO) / Path(design_doc.filename).with_suffix("")
        await self._save_mermaid_file(seq_flow, pathname)
        logger.info(f"Saving sequence flow to {str(pathname)}")

    async def _save_mermaid_file(self, data: str, pathname: Path):
        """Create the parent directory of ``pathname`` and pass ``data`` to `mermaid_to_file`."""
        pathname.parent.mkdir(parents=True, exist_ok=True)
        await mermaid_to_file(self.config.mermaid.engine, data, pathname)
reflect the evolving challenges and " + "requirements due to incremental development. Outline the steps involved in the implementation process with the " + "detailed strategies.", + example="We will refine ...", +) + +PROJECT_NAME = ActionNode( + key="Project name", expected_type=str, instruction="The project name with underline", example="game_2048" +) + +FILE_LIST = ActionNode( + key="File list", + expected_type=List[str], + instruction="Only need relative paths. ALWAYS write a main.py or app.py here", + example=["main.py", "game.py"], +) + +REFINED_FILE_LIST = ActionNode( + key="Refined File list", + expected_type=List[str], + instruction="Update and expand the original file list including only relative paths. Up to 2 files can be added." + "Ensure that the refined file list reflects the evolving structure of the project.", + example=["main.py", "game.py", "new_feature.py"], +) + +DATA_STRUCTURES_AND_INTERFACES = ActionNode( + key="Data structures and interfaces", + expected_type=str, + instruction="Use mermaid classDiagram code syntax, including classes, method(__init__ etc.) and functions with type" + " annotations, CLEARLY MARK the RELATIONSHIPS between classes, and comply with PEP8 standards. " + "The data structures SHOULD BE VERY DETAILED and the API should be comprehensive with a complete design.", + example=MMC1, +) + +REFINED_DATA_STRUCTURES_AND_INTERFACES = ActionNode( + key="Refined Data structures and interfaces", + expected_type=str, + instruction="Update and extend the existing mermaid classDiagram code syntax to incorporate new classes, " + "methods (including __init__), and functions with precise type annotations. Delineate additional " + "relationships between classes, ensuring clarity and adherence to PEP8 standards." 
+ "Retain content that is not related to incremental development but important for consistency and clarity.", + example=MMC1, +) + +PROGRAM_CALL_FLOW = ActionNode( + key="Program call flow", + expected_type=str, + instruction="Use sequenceDiagram code syntax, COMPLETE and VERY DETAILED, using CLASSES AND API DEFINED ABOVE " + "accurately, covering the CRUD AND INIT of each object, SYNTAX MUST BE CORRECT.", + example=MMC2, +) + +REFINED_PROGRAM_CALL_FLOW = ActionNode( + key="Refined Program call flow", + expected_type=str, + instruction="Extend the existing sequenceDiagram code syntax with detailed information, accurately covering the" + "CRUD and initialization of each object. Ensure correct syntax usage and reflect the incremental changes introduced" + "in the classes and API defined above. " + "Retain content that is not related to incremental development but important for consistency and clarity.", + example=MMC2, +) + +ANYTHING_UNCLEAR = ActionNode( + key="Anything UNCLEAR", + expected_type=str, + instruction="Mention unclear project aspects, then try to clarify it.", + example="Clarification needed on third-party API integration, ...", +) + +NODES = [ + IMPLEMENTATION_APPROACH, + # PROJECT_NAME, + FILE_LIST, + DATA_STRUCTURES_AND_INTERFACES, + PROGRAM_CALL_FLOW, + ANYTHING_UNCLEAR, +] + +REFINED_NODES = [ + REFINED_IMPLEMENTATION_APPROACH, + REFINED_FILE_LIST, + REFINED_DATA_STRUCTURES_AND_INTERFACES, + REFINED_PROGRAM_CALL_FLOW, + ANYTHING_UNCLEAR, +] + +DESIGN_API_NODE = ActionNode.from_children("DesignAPI", NODES) +REFINED_DESIGN_NODE = ActionNode.from_children("RefinedDesignAPI", REFINED_NODES) diff --git a/notebook_dir/metagpt_yusin/actions/design_api_review.py b/notebook_dir/metagpt_yusin/actions/design_api_review.py new file mode 100644 index 0000000000000000000000000000000000000000..fcad4dba1df6913e6660789b8b5c1fb0723feb4a --- /dev/null +++ b/notebook_dir/metagpt_yusin/actions/design_api_review.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python +# -*- coding: 
utf-8 -*- +""" +@Time : 2023/5/11 19:31 +@Author : alexanderwu +@File : design_api_review.py +""" + +from typing import Optional + +from metagpt_yusin.actions.action import Action + + +class DesignReview(Action): + name: str = "DesignReview" + i_context: Optional[str] = None + + async def run(self, prd, api_design): + prompt = ( + f"Here is the Product Requirement Document (PRD):\n\n{prd}\n\nHere is the list of APIs designed " + f"based on this PRD:\n\n{api_design}\n\nPlease review whether this API design meets the requirements" + f" of the PRD, and whether it complies with good design practices." + ) + + api_review = await self._aask(prompt) + return api_review diff --git a/notebook_dir/metagpt_yusin/actions/di/.ipynb_checkpoints/__init__-checkpoint.py b/notebook_dir/metagpt_yusin/actions/di/.ipynb_checkpoints/__init__-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/notebook_dir/metagpt_yusin/actions/di/.ipynb_checkpoints/ask_review-checkpoint.py b/notebook_dir/metagpt_yusin/actions/di/.ipynb_checkpoints/ask_review-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..850fc811795cac25fedd13f4b47b4901e875813e --- /dev/null +++ b/notebook_dir/metagpt_yusin/actions/di/.ipynb_checkpoints/ask_review-checkpoint.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from typing import Tuple + +from metagpt_yusin.actions import Action +from metagpt_yusin.logs import logger +from metagpt_yusin.schema import Message, Plan + + +class ReviewConst: + TASK_REVIEW_TRIGGER = "task" + CODE_REVIEW_TRIGGER = "code" + CONTINUE_WORDS = ["confirm", "continue", "c", "yes", "y"] + CHANGE_WORDS = ["change"] + EXIT_WORDS = ["exit"] + TASK_REVIEW_INSTRUCTION = ( + f"If you want to change, add, delete a task or merge tasks in the plan, say '{CHANGE_WORDS[0]} task task_id or current task, ... 
(things to change)' " + f"If you confirm the output from the current task and wish to continue, type: {CONTINUE_WORDS[0]}" + ) + CODE_REVIEW_INSTRUCTION = ( + f"If you want the codes to be rewritten, say '{CHANGE_WORDS[0]} ... (your change advice)' " + f"If you want to leave it as is, type: {CONTINUE_WORDS[0]} or {CONTINUE_WORDS[1]}" + ) + EXIT_INSTRUCTION = f"If you want to terminate the process, type: {EXIT_WORDS[0]}" + + +class AskReview(Action): + async def run( + self, context: list[Message] = [], plan: Plan = None, trigger: str = ReviewConst.TASK_REVIEW_TRIGGER + ) -> Tuple[str, bool]: + if plan: + logger.info("Current overall plan:") + logger.info( + "\n".join( + [f"{task.task_id}: {task.instruction}, is_finished: {task.is_finished}" for task in plan.tasks] + ) + ) + + logger.info("Most recent context:") + latest_action = context[-1].cause_by if context and context[-1].cause_by else "" + review_instruction = ( + ReviewConst.TASK_REVIEW_INSTRUCTION + if trigger == ReviewConst.TASK_REVIEW_TRIGGER + else ReviewConst.CODE_REVIEW_INSTRUCTION + ) + prompt = ( + f"This is a <{trigger}> review. Please review output from {latest_action}\n" + f"{review_instruction}\n" + f"{ReviewConst.EXIT_INSTRUCTION}\n" + "Please type your review below:\n" + ) + + rsp = input(prompt) + + if rsp.lower() in ReviewConst.EXIT_WORDS: + exit() + + # Confirmation can be one of "confirm", "continue", "c", "yes", "y" exactly, or sentences containing "confirm". + # One could say "confirm this task, but change the next task to ..." 
+ confirmed = rsp.lower() in ReviewConst.CONTINUE_WORDS or ReviewConst.CONTINUE_WORDS[0] in rsp.lower() + + return rsp, confirmed diff --git a/notebook_dir/metagpt_yusin/actions/di/.ipynb_checkpoints/execute_nb_code-checkpoint.py b/notebook_dir/metagpt_yusin/actions/di/.ipynb_checkpoints/execute_nb_code-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..dac471b2df18df90de0fe104358bb064f99dbae1 --- /dev/null +++ b/notebook_dir/metagpt_yusin/actions/di/.ipynb_checkpoints/execute_nb_code-checkpoint.py @@ -0,0 +1,256 @@ +# -*- encoding: utf-8 -*- +""" +@Date : 2023/11/17 14:22:15 +@Author : orange-crow +@File : execute_nb_code.py +""" +from __future__ import annotations + +import asyncio +import base64 +import re +from typing import Literal, Tuple + +import nbformat +from nbclient import NotebookClient +from nbclient.exceptions import CellTimeoutError, DeadKernelError +from nbformat import NotebookNode +from nbformat.v4 import new_code_cell, new_markdown_cell, new_output +from rich.box import MINIMAL +from rich.console import Console, Group +from rich.live import Live +from rich.markdown import Markdown +from rich.panel import Panel +from rich.syntax import Syntax + +from metagpt_yusin.actions import Action +from metagpt_yusin.logs import logger + + +class ExecuteNbCode(Action): + """execute notebook code block, return result to llm, and display it.""" + + nb: NotebookNode + nb_client: NotebookClient + console: Console + interaction: str + timeout: int = 600 + + def __init__( + self, + nb=nbformat.v4.new_notebook(), + timeout=600, + ): + super().__init__( + nb=nb, + nb_client=NotebookClient(nb, timeout=timeout), + timeout=timeout, + console=Console(), + interaction=("ipython" if self.is_ipython() else "terminal"), + ) + + async def build(self): + if self.nb_client.kc is None or not await self.nb_client.kc.is_alive(): + self.nb_client.create_kernel_manager() + self.nb_client.start_new_kernel() + self.nb_client.start_new_kernel_client() + 
+ async def terminate(self): + """kill NotebookClient""" + if self.nb_client.km is not None and await self.nb_client.km.is_alive(): + await self.nb_client.km.shutdown_kernel(now=True) + await self.nb_client.km.cleanup_resources() + + channels = [ + self.nb_client.kc.stdin_channel, # The channel for handling standard input to the kernel. + self.nb_client.kc.hb_channel, # The channel for heartbeat communication between the kernel and client. + self.nb_client.kc.control_channel, # The channel for controlling the kernel. + ] + + # Stops all the running channels for this kernel + for channel in channels: + if channel.is_alive(): + channel.stop() + + self.nb_client.kc = None + self.nb_client.km = None + + async def reset(self): + """reset NotebookClient""" + await self.terminate() + + # sleep 1s to wait for the kernel to be cleaned up completely + await asyncio.sleep(1) + await self.build() + self.nb_client = NotebookClient(self.nb, timeout=self.timeout) + + def add_code_cell(self, code: str): + self.nb.cells.append(new_code_cell(source=code)) + + def add_markdown_cell(self, markdown: str): + self.nb.cells.append(new_markdown_cell(source=markdown)) + + def _display(self, code: str, language: Literal["python", "markdown"] = "python"): + if language == "python": + code = Syntax(code, "python", theme="paraiso-dark", line_numbers=True) + self.console.print(code) + elif language == "markdown": + display_markdown(code) + else: + raise ValueError(f"Only support for python, markdown, but got {language}") + + def add_output_to_cell(self, cell: NotebookNode, output: str): + """add outputs of code execution to notebook cell.""" + if "outputs" not in cell: + cell["outputs"] = [] + else: + cell["outputs"].append(new_output(output_type="stream", name="stdout", text=str(output))) + + def parse_outputs(self, outputs: list[str], keep_len: int = 2000) -> Tuple[bool, str]: + """Parses the outputs received from notebook execution.""" + assert isinstance(outputs, list) + parsed_output, 
is_success = [], True + for i, output in enumerate(outputs): + output_text = "" + if output["output_type"] == "stream" and not any( + tag in output["text"] + for tag in ["| INFO | metagpt", "| ERROR | metagpt", "| WARNING | metagpt", "DEBUG"] + ): + output_text = output["text"] + elif output["output_type"] == "display_data": + if "image/png" in output["data"]: + self.show_bytes_figure(output["data"]["image/png"], self.interaction) + else: + logger.info( + f"{i}th output['data'] from nbclient outputs dont have image/png, continue next output ..." + ) + elif output["output_type"] == "execute_result": + output_text = output["data"]["text/plain"] + elif output["output_type"] == "error": + output_text, is_success = "\n".join(output["traceback"]), False + + # handle coroutines that are not executed asynchronously + if output_text.strip().startswith(" bool: + try: + # 如果在Jupyter Notebook中运行,__file__ 变量不存在 + from IPython import get_ipython + + if get_ipython() is not None and "IPKernelApp" in get_ipython().config: + return True + else: + return False + except NameError: + return False + + async def run_cell(self, cell: NotebookNode, cell_index: int) -> Tuple[bool, str]: + """set timeout for run code. + returns the success or failure of the cell execution, and an optional error message. + """ + try: + await self.nb_client.async_execute_cell(cell, cell_index) + return self.parse_outputs(self.nb.cells[-1].outputs) + except CellTimeoutError: + assert self.nb_client.km is not None + await self.nb_client.km.interrupt_kernel() + await asyncio.sleep(1) + error_msg = "Cell execution timed out: Execution exceeded the time limit and was stopped; consider optimizing your code for better performance." 
+ return False, error_msg + except DeadKernelError: + await self.reset() + return False, "DeadKernelError" + except Exception: + return self.parse_outputs(self.nb.cells[-1].outputs) + + async def run(self, code: str, language: Literal["python", "markdown"] = "python") -> Tuple[str, bool]: + """ + return the output of code execution, and a success indicator (bool) of code execution. + """ + self._display(code, language) + + if language == "python": + # add code to the notebook + self.add_code_cell(code=code) + + # build code executor + await self.build() + + # run code + cell_index = len(self.nb.cells) - 1 + success, outputs = await self.run_cell(self.nb.cells[-1], cell_index) + + if "!pip" in code: + success = False + + return outputs, success + + elif language == "markdown": + # add markdown content to markdown cell in a notebook. + self.add_markdown_cell(code) + # return True, beacuse there is no execution failure for markdown cell. + return code, True + else: + raise ValueError(f"Only support for language: python, markdown, but got {language}, ") + + +def remove_escape_and_color_codes(input_str: str): + # 使用正则表达式去除jupyter notebook输出结果中的转义字符和颜色代码 + # Use regular expressions to get rid of escape characters and color codes in jupyter notebook output. + pattern = re.compile(r"\x1b\[[0-9;]*[mK]") + result = pattern.sub("", input_str) + return result + + +def display_markdown(content: str): + # Use regular expressions to match blocks of code one by one. + matches = re.finditer(r"```(.+?)```", content, re.DOTALL) + start_index = 0 + content_panels = [] + # Set the text background color and text color. + style = "black on white" + # Print the matching text and code one by one. 
+ for match in matches: + text_content = content[start_index : match.start()].strip() + code_content = match.group(0).strip()[3:-3] # Remove triple backticks + + if text_content: + content_panels.append(Panel(Markdown(text_content), style=style, box=MINIMAL)) + + if code_content: + content_panels.append(Panel(Markdown(f"```{code_content}"), style=style, box=MINIMAL)) + start_index = match.end() + + # Print remaining text (if any). + remaining_text = content[start_index:].strip() + if remaining_text: + content_panels.append(Panel(Markdown(remaining_text), style=style, box=MINIMAL)) + + # Display all panels in Live mode. + with Live(auto_refresh=False, console=Console(), vertical_overflow="visible") as live: + live.update(Group(*content_panels)) + live.refresh() diff --git a/notebook_dir/metagpt_yusin/actions/di/.ipynb_checkpoints/write_analysis_code-checkpoint.py b/notebook_dir/metagpt_yusin/actions/di/.ipynb_checkpoints/write_analysis_code-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..cb276e8b58cd4c044ee12627034ea19faea80481 --- /dev/null +++ b/notebook_dir/metagpt_yusin/actions/di/.ipynb_checkpoints/write_analysis_code-checkpoint.py @@ -0,0 +1,73 @@ +# -*- encoding: utf-8 -*- +""" +@Date : 2023/11/20 13:19:39 +@Author : orange-crow +@File : write_analysis_code.py +""" +from __future__ import annotations + +import json + +from metagpt_yusin.actions import Action +from metagpt_yusin.prompts.di.write_analysis_code import ( + CHECK_DATA_PROMPT, + DEBUG_REFLECTION_EXAMPLE, + INTERPRETER_SYSTEM_MSG, + REFLECTION_PROMPT, + REFLECTION_SYSTEM_MSG, + STRUCTUAL_PROMPT, +) +from metagpt_yusin.schema import Message, Plan +from metagpt_yusin.utils.common import CodeParser, remove_comments + + +class WriteAnalysisCode(Action): + async def _debug_with_reflection(self, context: list[Message], working_memory: list[Message]): + reflection_prompt = REFLECTION_PROMPT.format( + debug_example=DEBUG_REFLECTION_EXAMPLE, + context=context, + 
previous_impl=working_memory, + ) + + rsp = await self._aask(reflection_prompt, system_msgs=[REFLECTION_SYSTEM_MSG]) + reflection = json.loads(CodeParser.parse_code(block=None, text=rsp)) + + return reflection["improved_impl"] + + async def run( + self, + user_requirement: str, + plan_status: str = "", + tool_info: str = "", + working_memory: list[Message] = None, + use_reflection: bool = False, + **kwargs, + ) -> str: + structual_prompt = STRUCTUAL_PROMPT.format( + user_requirement=user_requirement, + plan_status=plan_status, + tool_info=tool_info, + ) + + working_memory = working_memory or [] + context = self.llm.format_msg([Message(content=structual_prompt, role="user")] + working_memory) + + # LLM call + if use_reflection: + code = await self._debug_with_reflection(context=context, working_memory=working_memory) + else: + rsp = await self.llm.aask(context, system_msgs=[INTERPRETER_SYSTEM_MSG], **kwargs) # also out from here + code = CodeParser.parse_code(block=None, text=rsp) + + return code + + +class CheckData(Action): + async def run(self, plan: Plan) -> dict: + finished_tasks = plan.get_finished_tasks() + code_written = [remove_comments(task.code) for task in finished_tasks] + code_written = "\n\n".join(code_written) + prompt = CHECK_DATA_PROMPT.format(code_written=code_written) + rsp = await self._aask(prompt) + code = CodeParser.parse_code(block=None, text=rsp) + return code diff --git a/notebook_dir/metagpt_yusin/actions/di/.ipynb_checkpoints/write_plan-checkpoint.py b/notebook_dir/metagpt_yusin/actions/di/.ipynb_checkpoints/write_plan-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..3f87ae2b9dee3b1be306af431d447e39a5d4524b --- /dev/null +++ b/notebook_dir/metagpt_yusin/actions/di/.ipynb_checkpoints/write_plan-checkpoint.py @@ -0,0 +1,90 @@ +# -*- encoding: utf-8 -*- +""" +@Date : 2023/11/20 11:24:03 +@Author : orange-crow +@File : plan.py +""" +from __future__ import annotations + +import json +from copy import deepcopy 
+from typing import Tuple + +from metagpt_yusin.actions import Action +from metagpt_yusin.logs import logger +from metagpt_yusin.schema import Message, Plan, Task +from metagpt_yusin.strategy.task_type import TaskType +from metagpt_yusin.utils.common import CodeParser + + +class WritePlan(Action): + PROMPT_TEMPLATE: str = """ + # Context: + {context} + # Available Task Types: + {task_type_desc} + # Task: + Based on the context, write a plan or modify an existing plan of what you should do to achieve the goal. A plan consists of one to {max_tasks} tasks. + If you are modifying an existing plan, carefully follow the instruction, don't make unnecessary changes. Give the whole plan unless instructed to modify only one task of the plan. + If you encounter errors on the current task, revise and output the current single task only. + Output a list of jsons following the format: + ```json + [ + {{ + "task_id": str = "unique identifier for a task in plan, can be an ordinal", + "dependent_task_ids": list[str] = "ids of tasks prerequisite to this task", + "instruction": "what you should do in this task, one short phrase or sentence", + "task_type": "type of this task, should be one of Available Task Types", + }}, + ... 
+ ] + ``` + """ + + async def run(self, context: list[Message], max_tasks: int = 5) -> str: + task_type_desc = "\n".join([f"- **{tt.type_name}**: {tt.value.desc}" for tt in TaskType]) + prompt = self.PROMPT_TEMPLATE.format( + context="\n".join([str(ct) for ct in context]), max_tasks=max_tasks, task_type_desc=task_type_desc + ) + #print('2333333333333333333333333333333333333333333333333333333333333333333333333333') + rsp = await self._aask(prompt) + #print('44444444444444444444444444444444444444444444444444444444444444444444444444') + rsp = CodeParser.parse_code(block=None, text=rsp) + return rsp + + +def update_plan_from_rsp(rsp: str, current_plan: Plan): + rsp = json.loads(rsp) + tasks = [Task(**task_config) for task_config in rsp] + + #print('-----------------------------------------------------------') + #print(tasks) + #print('-----------------------------------------------------------') + + if len(tasks) == 1 or tasks[0].dependent_task_ids: + if tasks[0].dependent_task_ids and len(tasks) > 1: + # tasks[0].dependent_task_ids means the generated tasks are not a complete plan + # for they depend on tasks in the current plan, in this case, we only support updating one task each time + logger.warning( + "Current plan will take only the first generated task if the generated tasks are not a complete plan" + ) + # handle a single task + if current_plan.has_task_id(tasks[0].task_id): + # replace an existing task + current_plan.replace_task(tasks[0]) + else: + # append one task + current_plan.append_task(tasks[0]) + + else: + # add tasks in general + current_plan.add_tasks(tasks) + + +def precheck_update_plan_from_rsp(rsp: str, current_plan: Plan) -> Tuple[bool, str]: + temp_plan = deepcopy(current_plan) + try: + update_plan_from_rsp(rsp, temp_plan) + return True, "" + except Exception as e: + return False, e diff --git a/notebook_dir/metagpt_yusin/actions/di/__init__.py b/notebook_dir/metagpt_yusin/actions/di/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/notebook_dir/metagpt_yusin/actions/di/__pycache__/__init__.cpython-39.pyc b/notebook_dir/metagpt_yusin/actions/di/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db983dc61a23530c0938907fd4a64f6bdb2ba5ee Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/di/__pycache__/__init__.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/di/__pycache__/ask_review.cpython-39.pyc b/notebook_dir/metagpt_yusin/actions/di/__pycache__/ask_review.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffa40bf3c11d689eb71aec73e616cbdeda303873 Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/di/__pycache__/ask_review.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/di/__pycache__/execute_nb_code.cpython-39.pyc b/notebook_dir/metagpt_yusin/actions/di/__pycache__/execute_nb_code.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83877c1c1196a4d530e1139d2a5fe406ee16a192 Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/di/__pycache__/execute_nb_code.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/di/__pycache__/write_analysis_code.cpython-39.pyc b/notebook_dir/metagpt_yusin/actions/di/__pycache__/write_analysis_code.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a4d2881b040f6f36a772dc33951cb66bd0763fa Binary files /dev/null and b/notebook_dir/metagpt_yusin/actions/di/__pycache__/write_analysis_code.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/di/__pycache__/write_plan.cpython-39.pyc b/notebook_dir/metagpt_yusin/actions/di/__pycache__/write_plan.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3cc95c1e32ecfa3e8d29dd38ff5df290d75d118 Binary files /dev/null and 
b/notebook_dir/metagpt_yusin/actions/di/__pycache__/write_plan.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/actions/di/ask_review.py b/notebook_dir/metagpt_yusin/actions/di/ask_review.py new file mode 100644 index 0000000000000000000000000000000000000000..850fc811795cac25fedd13f4b47b4901e875813e --- /dev/null +++ b/notebook_dir/metagpt_yusin/actions/di/ask_review.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from typing import Tuple + +from metagpt_yusin.actions import Action +from metagpt_yusin.logs import logger +from metagpt_yusin.schema import Message, Plan + + +class ReviewConst: + TASK_REVIEW_TRIGGER = "task" + CODE_REVIEW_TRIGGER = "code" + CONTINUE_WORDS = ["confirm", "continue", "c", "yes", "y"] + CHANGE_WORDS = ["change"] + EXIT_WORDS = ["exit"] + TASK_REVIEW_INSTRUCTION = ( + f"If you want to change, add, delete a task or merge tasks in the plan, say '{CHANGE_WORDS[0]} task task_id or current task, ... (things to change)' " + f"If you confirm the output from the current task and wish to continue, type: {CONTINUE_WORDS[0]}" + ) + CODE_REVIEW_INSTRUCTION = ( + f"If you want the codes to be rewritten, say '{CHANGE_WORDS[0]} ... 
(your change advice)' " + f"If you want to leave it as is, type: {CONTINUE_WORDS[0]} or {CONTINUE_WORDS[1]}" + ) + EXIT_INSTRUCTION = f"If you want to terminate the process, type: {EXIT_WORDS[0]}" + + +class AskReview(Action): + async def run( + self, context: list[Message] = [], plan: Plan = None, trigger: str = ReviewConst.TASK_REVIEW_TRIGGER + ) -> Tuple[str, bool]: + if plan: + logger.info("Current overall plan:") + logger.info( + "\n".join( + [f"{task.task_id}: {task.instruction}, is_finished: {task.is_finished}" for task in plan.tasks] + ) + ) + + logger.info("Most recent context:") + latest_action = context[-1].cause_by if context and context[-1].cause_by else "" + review_instruction = ( + ReviewConst.TASK_REVIEW_INSTRUCTION + if trigger == ReviewConst.TASK_REVIEW_TRIGGER + else ReviewConst.CODE_REVIEW_INSTRUCTION + ) + prompt = ( + f"This is a <{trigger}> review. Please review output from {latest_action}\n" + f"{review_instruction}\n" + f"{ReviewConst.EXIT_INSTRUCTION}\n" + "Please type your review below:\n" + ) + + rsp = input(prompt) + + if rsp.lower() in ReviewConst.EXIT_WORDS: + exit() + + # Confirmation can be one of "confirm", "continue", "c", "yes", "y" exactly, or sentences containing "confirm". + # One could say "confirm this task, but change the next task to ..." 
+ confirmed = rsp.lower() in ReviewConst.CONTINUE_WORDS or ReviewConst.CONTINUE_WORDS[0] in rsp.lower() + + return rsp, confirmed diff --git a/notebook_dir/metagpt_yusin/actions/di/execute_nb_code.py b/notebook_dir/metagpt_yusin/actions/di/execute_nb_code.py new file mode 100644 index 0000000000000000000000000000000000000000..dac471b2df18df90de0fe104358bb064f99dbae1 --- /dev/null +++ b/notebook_dir/metagpt_yusin/actions/di/execute_nb_code.py @@ -0,0 +1,256 @@ +# -*- encoding: utf-8 -*- +""" +@Date : 2023/11/17 14:22:15 +@Author : orange-crow +@File : execute_nb_code.py +""" +from __future__ import annotations + +import asyncio +import base64 +import re +from typing import Literal, Tuple + +import nbformat +from nbclient import NotebookClient +from nbclient.exceptions import CellTimeoutError, DeadKernelError +from nbformat import NotebookNode +from nbformat.v4 import new_code_cell, new_markdown_cell, new_output +from rich.box import MINIMAL +from rich.console import Console, Group +from rich.live import Live +from rich.markdown import Markdown +from rich.panel import Panel +from rich.syntax import Syntax + +from metagpt_yusin.actions import Action +from metagpt_yusin.logs import logger + + +class ExecuteNbCode(Action): + """execute notebook code block, return result to llm, and display it.""" + + nb: NotebookNode + nb_client: NotebookClient + console: Console + interaction: str + timeout: int = 600 + + def __init__( + self, + nb=nbformat.v4.new_notebook(), + timeout=600, + ): + super().__init__( + nb=nb, + nb_client=NotebookClient(nb, timeout=timeout), + timeout=timeout, + console=Console(), + interaction=("ipython" if self.is_ipython() else "terminal"), + ) + + async def build(self): + if self.nb_client.kc is None or not await self.nb_client.kc.is_alive(): + self.nb_client.create_kernel_manager() + self.nb_client.start_new_kernel() + self.nb_client.start_new_kernel_client() + + async def terminate(self): + """kill NotebookClient""" + if self.nb_client.km is not 
def add_output_to_cell(self, cell: NotebookNode, output: str):
    """Append ``output`` to ``cell`` as a stdout stream output.

    Args:
        cell: The notebook cell to extend; its ``"outputs"`` list is created when missing.
        output: The value to record; it is converted with ``str()`` before being stored.

    Previously the output was only appended when ``"outputs"`` already existed, so the
    first output of a cell without that key was silently dropped.
    """
    if "outputs" not in cell:
        cell["outputs"] = []
    cell["outputs"].append(new_output(output_type="stream", name="stdout", text=str(output)))
output["output_type"] == "stream" and not any( + tag in output["text"] + for tag in ["| INFO | metagpt", "| ERROR | metagpt", "| WARNING | metagpt", "DEBUG"] + ): + output_text = output["text"] + elif output["output_type"] == "display_data": + if "image/png" in output["data"]: + self.show_bytes_figure(output["data"]["image/png"], self.interaction) + else: + logger.info( + f"{i}th output['data'] from nbclient outputs dont have image/png, continue next output ..." + ) + elif output["output_type"] == "execute_result": + output_text = output["data"]["text/plain"] + elif output["output_type"] == "error": + output_text, is_success = "\n".join(output["traceback"]), False + + # handle coroutines that are not executed asynchronously + if output_text.strip().startswith(" bool: + try: + # 如果在Jupyter Notebook中运行,__file__ 变量不存在 + from IPython import get_ipython + + if get_ipython() is not None and "IPKernelApp" in get_ipython().config: + return True + else: + return False + except NameError: + return False + + async def run_cell(self, cell: NotebookNode, cell_index: int) -> Tuple[bool, str]: + """set timeout for run code. + returns the success or failure of the cell execution, and an optional error message. + """ + try: + await self.nb_client.async_execute_cell(cell, cell_index) + return self.parse_outputs(self.nb.cells[-1].outputs) + except CellTimeoutError: + assert self.nb_client.km is not None + await self.nb_client.km.interrupt_kernel() + await asyncio.sleep(1) + error_msg = "Cell execution timed out: Execution exceeded the time limit and was stopped; consider optimizing your code for better performance." 
def remove_escape_and_color_codes(input_str: str):
    """Return ``input_str`` without the escape/colour sequences found in Jupyter notebook output.

    A sequence is removed when it is ESC ``[`` followed by any run of digits and
    semicolons and terminated by ``m`` or ``K``; every other character is kept.
    """
    return re.sub(r"\x1b\[[0-9;]*[mK]", "", input_str)
class WriteAnalysisCode(Action):
    """Ask the LLM for analysis code, optionally through a reflection prompt over earlier attempts."""

    async def _debug_with_reflection(self, context: list[Message], working_memory: list[Message]):
        """Send the reflection prompt and return the ``improved_impl`` entry of the JSON reply."""
        prompt = REFLECTION_PROMPT.format(
            debug_example=DEBUG_REFLECTION_EXAMPLE,
            context=context,
            previous_impl=working_memory,
        )
        reply = await self._aask(prompt, system_msgs=[REFLECTION_SYSTEM_MSG])
        # The reply is expected to carry a JSON document inside a code block.
        parsed = json.loads(CodeParser.parse_code(block=None, text=reply))
        return parsed["improved_impl"]

    async def run(
        self,
        user_requirement: str,
        plan_status: str = "",
        tool_info: str = "",
        working_memory: list[Message] = None,
        use_reflection: bool = False,
        **kwargs,
    ) -> str:
        """Build the structured prompt and return the code extracted from the LLM answer.

        Args:
            user_requirement: Filled into the prompt template.
            plan_status: Filled into the prompt template.
            tool_info: Filled into the prompt template.
            working_memory: Earlier messages appended after the prompt; ``None`` means none.
            use_reflection: When true, the reflection path is used instead of a direct ask.
            **kwargs: Forwarded to ``self.llm.aask`` on the direct path only.

        Returns:
            The code produced by the chosen path.
        """
        prompt = STRUCTUAL_PROMPT.format(
            user_requirement=user_requirement,
            plan_status=plan_status,
            tool_info=tool_info,
        )
        history = working_memory or []
        context = self.llm.format_msg([Message(content=prompt, role="user")] + history)

        if use_reflection:
            return await self._debug_with_reflection(context=context, working_memory=history)

        reply = await self.llm.aask(context, system_msgs=[INTERPRETER_SYSTEM_MSG], **kwargs)
        return CodeParser.parse_code(block=None, text=reply)
metagpt_yusin.strategy.task_type import TaskType +from metagpt_yusin.utils.common import CodeParser + + +class WritePlan(Action): + PROMPT_TEMPLATE: str = """ + # Context: + {context} + # Available Task Types: + {task_type_desc} + # Task: + Based on the context, write a plan or modify an existing plan of what you should do to achieve the goal. A plan consists of one to {max_tasks} tasks. + If you are modifying an existing plan, carefully follow the instruction, don't make unnecessary changes. Give the whole plan unless instructed to modify only one task of the plan. + If you encounter errors on the current task, revise and output the current single task only. + Output a list of jsons following the format: + ```json + [ + {{ + "task_id": str = "unique identifier for a task in plan, can be an ordinal", + "dependent_task_ids": list[str] = "ids of tasks prerequisite to this task", + "instruction": "what you should do in this task, one short phrase or sentence", + "task_type": "type of this task, should be one of Available Task Types", + }}, + ... 
def precheck_update_plan_from_rsp(rsp: str, current_plan: Plan) -> Tuple[bool, str]:
    """Dry-run ``update_plan_from_rsp`` on a deep copy of ``current_plan``.

    Args:
        rsp: The plan response that would be applied.
        current_plan: The plan to validate against; it is never modified.

    Returns:
        ``(True, "")`` when the update succeeds on the copy, otherwise ``(False, message)``
        where ``message`` is the error text. The message is always a ``str`` as the
        signature promises; previously the exception object itself was returned.
    """
    temp_plan = deepcopy(current_plan)
    try:
        update_plan_from_rsp(rsp, temp_plan)
    except Exception as e:
        # Any failure means the response cannot be applied; fall back to repr() so the
        # message is never empty for exceptions raised without arguments.
        return False, str(e) or repr(e)
    return True, ""
class ExecuteTask(Action):
    """Placeholder action whose ``run`` accepts any arguments and does nothing."""

    name: str = "ExecuteTask"
    i_context: list[Message] = []

    async def run(self, *args, **kwargs):
        """Ignore all arguments and return ``None``."""
        return None
class GenerateQuestions(Action):
    """Let the LLM dig for further noteworthy details of a discussion.

    The context is expected to carry a "##TOPIC" (discussion topic) and a "##RECORD"
    (discussion records) section; the generated questions are meant to deepen that discussion.
    """

    name: str = "GenerateQuestions"

    async def run(self, context) -> ActionNode:
        """Fill the module-level ``QUESTIONS`` node from ``context`` using this action's LLM."""
        filled_node = await QUESTIONS.fill(context=context, llm=self.llm)
        return filled_node
@staticmethod
async def _unzip(file_path: Path) -> Path:
    """Unzip a file and return the path to the unzipped directory.

    Entries are written below ``<zip parent>/unzip_invoices/<timestamp>``. Only entries whose
    name has a suffix are written; entries without one (such as directories) are skipped.

    Args:
        file_path: The path to the zip file.

    Returns:
        The path to the unzipped directory.
    """
    file_directory = file_path.parent / "unzip_invoices" / datetime.now().strftime("%Y%m%d%H%M%S")
    with zipfile.ZipFile(file_path, "r") as zip_ref:
        for zip_info in zip_ref.infolist():
            raw_name = zip_info.filename
            if zip_info.flag_bits & 0x800:
                # Bit 11 set: the archive declares the name as UTF-8 and zipfile has already
                # decoded it correctly. Re-encoding it as CP437 would fail or garble it.
                decoded_name = raw_name
            else:
                # zipfile decoded the raw bytes as CP437; undo that and decode as GBK to
                # recover Chinese names written by tools that use the legacy encoding.
                try:
                    decoded_name = raw_name.encode("cp437").decode("gbk")
                except (UnicodeEncodeError, UnicodeDecodeError):
                    decoded_name = raw_name

            relative_name = Path(decoded_name)
            # Refuse entries that would escape the extraction directory (zip-slip).
            if relative_name.is_absolute() or ".." in relative_name.parts:
                logger.warning(f"skip unsafe zip entry: {decoded_name}")
                continue

            if relative_name.suffix:
                full_filename = file_directory / relative_name
                await File.write(full_filename.parent, relative_name.name, zip_ref.read(zip_info.filename))

    logger.info(f"unzip_path: {file_directory}")
    return file_directory
class GenerateTable(Action):
    """Turn invoice OCR results into a table and save it as an Excel file.

    ``language`` is passed to the extraction prompt and defaults to "ch".
    """

    name: str = "GenerateTable"
    i_context: Optional[str] = None
    language: str = "ch"

    async def run(self, ocr_results: list, filename: str, *args, **kwargs) -> dict[str, str]:
        """Extract invoice information from each OCR result and write the rows to an Excel file.

        Args:
            ocr_results: A list of OCR results obtained from invoice processing.
            filename: Name used for the output file; everything from the first "." onwards is
                replaced by ".xlsx".

        Returns:
            The collected invoice records.
            NOTE(review): the value returned is the list of records, not a single dict as the
            annotation suggests.
        """
        output_dir = INVOICE_OCR_TABLE_PATH
        output_dir.mkdir(parents=True, exist_ok=True)

        rows = []
        for single_result in ocr_results:
            # Ask the LLM for the main invoice fields of this OCR result.
            extraction_prompt = EXTRACT_OCR_MAIN_INFO_PROMPT.format(ocr_result=single_result, language=self.language)
            answer = await self._aask(prompt=extraction_prompt)
            record = OutputParser.extract_struct(answer, dict)
            if not record:
                continue
            rows.append(record)

        # Write the Excel file next to the other OCR tables.
        base_name = filename.partition(".")[0]
        target = f"{output_dir}/{base_name}.xlsx"
        pd.DataFrame(rows).to_excel(target, index=False)
        return rows
class PrepareDocuments(Action):
    """Set up the project folder and store the incoming requirement in the docs repository."""

    name: str = "PrepareDocuments"
    i_context: Optional[str] = None

    @property
    def config(self):
        """The configuration object attached to this action's context."""
        return self.context.config

    def _init_repo(self):
        """Resolve the project path, clear it unless running incrementally, and create the Git repo."""
        if self.config.project_path:
            project_dir = Path(self.config.project_path)
        else:
            # No explicit path: place the project inside the workspace under its name,
            # or under a freshly generated name when none is configured.
            folder = self.config.project_name or FileRepository.new_filename()
            project_dir = Path(self.config.workspace.path) / folder

        # A non-incremental run starts from an empty directory.
        if project_dir.exists() and not self.config.inc:
            shutil.rmtree(project_dir)

        self.config.project_path = project_dir
        self.context.git_repo = GitRepository(local_path=project_dir, auto_init=True)
        self.context.repo = ProjectRepo(self.context.git_repo)

    async def run(self, with_messages, **kwargs):
        """Initialize the workspace and save the first message's content as the requirement document.

        Returns:
            An ``ActionOutput`` whose content is the saved document's content and whose
            ``instruct_content`` is the saved document itself.
        """
        self._init_repo()

        # Persist the newly added requirement so later actions can read it from the docs repo.
        requirement = with_messages[0].content
        saved = await self.repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirement)
        # The returned output is what notifies the WritePRD action to process the requirement.
        return ActionOutput(content=saved.content, instruct_content=saved)
class WriteTasks(Action):
    """Create or refresh task documents from changed system designs and keep the package list in sync."""

    name: str = "CreateTasks"
    i_context: Optional[str] = None

    async def run(self, with_messages):
        """Update the task document of every changed system design or task file.

        Returns:
            An ``ActionOutput`` carrying the ``Documents`` collection of all task documents
            that were updated, keyed by filename (empty when nothing changed).
        """
        changed_system_designs = self.repo.docs.system_design.changed_files
        changed_tasks = self.repo.docs.task.changed_files
        change_files = Documents()
        # Rewrite the tasks of the system designs that have undergone changes based on the git head diff
        # under `docs/system_designs/`.
        for filename in changed_system_designs:
            task_doc = await self._update_tasks(filename=filename)
            change_files.docs[filename] = task_doc

        # Rewrite the task files that have undergone changes based on the git head diff under `docs/tasks/`.
        for filename in changed_tasks:
            if filename in change_files.docs:
                # Already handled through its system design above.
                continue
            task_doc = await self._update_tasks(filename=filename)
            change_files.docs[filename] = task_doc

        if not change_files.docs:
            logger.info("Nothing has changed.")
        # Wait until all files under `docs/tasks/` are processed before sending the publish_message, leaving room for
        # global optimization in subsequent steps.
        return ActionOutput(content=change_files.model_dump_json(), instruct_content=change_files)

    async def _update_tasks(self, filename):
        """Merge into the existing task document, or create a new one, for ``filename``."""
        system_design_doc = await self.repo.docs.system_design.get(filename)
        task_doc = await self.repo.docs.task.get(filename)
        if task_doc:
            task_doc = await self._merge(system_design_doc=system_design_doc, task_doc=task_doc)
            await self.repo.docs.task.save_doc(doc=task_doc, dependencies={system_design_doc.root_relative_path})
        else:
            rsp = await self._run_new_tasks(context=system_design_doc.content)
            task_doc = await self.repo.docs.task.save(
                filename=filename,
                content=rsp.instruct_content.model_dump_json(),
                dependencies={system_design_doc.root_relative_path},
            )
        await self._update_requirements(task_doc)
        return task_doc

    async def _run_new_tasks(self, context):
        """Fill ``PM_NODE`` from the system design content."""
        node = await PM_NODE.fill(context, self.llm, schema=self.prompt_schema)
        return node

    async def _merge(self, system_design_doc, task_doc) -> Document:
        """Fill ``REFINED_PM_NODE`` from old task plus new design and store the result in ``task_doc``."""
        context = NEW_REQ_TEMPLATE.format(context=system_design_doc.content, old_task=task_doc.content)
        node = await REFINED_PM_NODE.fill(context, self.llm, schema=self.prompt_schema)
        task_doc.content = node.instruct_content.model_dump_json()
        return task_doc

    async def _update_requirements(self, doc):
        """Union the packages named in ``doc`` with the saved requirements file and save it again.

        The packages are written in sorted order so that repeated runs over the same set
        produce an identical file; joining the set directly gave an arbitrary line order.
        """
        m = json.loads(doc.content)
        packages = set(m.get("Required Python packages", set()))
        requirement_doc = await self.repo.get(filename=PACKAGE_REQUIREMENTS_FILENAME)
        if not requirement_doc:
            requirement_doc = Document(filename=PACKAGE_REQUIREMENTS_FILENAME, root_path=".", content="")
        # Keep every non-empty line that is already in the requirements file.
        packages.update(pkg for pkg in requirement_doc.content.splitlines() if pkg != "")
        await self.repo.save(filename=PACKAGE_REQUIREMENTS_FILENAME, content="\n".join(sorted(packages)))
0000000000000000000000000000000000000000..0ebecf76e6f18f85483f65d130b67f0f87a5842e --- /dev/null +++ b/notebook_dir/metagpt_yusin/actions/project_management_an.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/14 15:28 +@Author : alexanderwu +@File : project_management_an.py +""" +from typing import List + +from metagpt_yusin.actions.action_node import ActionNode + +REQUIRED_PYTHON_PACKAGES = ActionNode( + key="Required Python packages", + expected_type=List[str], + instruction="Provide required Python packages in requirements.txt format.", + example=["flask==1.1.2", "bcrypt==3.2.0"], +) + +REQUIRED_OTHER_LANGUAGE_PACKAGES = ActionNode( + key="Required Other language third-party packages", + expected_type=List[str], + instruction="List down the required packages for languages other than Python.", + example=["No third-party dependencies required"], +) + +LOGIC_ANALYSIS = ActionNode( + key="Logic Analysis", + expected_type=List[List[str]], + instruction="Provide a list of files with the classes/methods/functions to be implemented, " + "including dependency analysis and imports.", + example=[ + ["game.py", "Contains Game class and ... functions"], + ["main.py", "Contains main function, from game import Game"], + ], +) + +REFINED_LOGIC_ANALYSIS = ActionNode( + key="Refined Logic Analysis", + expected_type=List[List[str]], + instruction="Review and refine the logic analysis by merging the Legacy Content and Incremental Content. " + "Provide a comprehensive list of files with classes/methods/functions to be implemented or modified incrementally. " + "Include dependency analysis, consider potential impacts on existing code, and document necessary imports.", + example=[ + ["game.py", "Contains Game class and ... 
# Node for the merged (legacy + incremental) task list. Every example entry is a filename,
# matching the "utils.py" spelling used by the refined logic analysis example; the former
# "utils" entry lacked its extension.
REFINED_TASK_LIST = ActionNode(
    key="Refined Task list",
    expected_type=List[str],
    instruction="Review and refine the combined task list after the merger of Legacy Content and Incremental Content, "
    "and consistent with Refined File List. Ensure that tasks are organized in a logical and prioritized order, "
    "considering dependencies for a streamlined and efficient development process. ",
    example=["new_feature.py", "utils.py", "game.py", "main.py"],
)
Retain content that is not related to " + "incremental development but important for consistency and clarity.", + example="`new_module.py` enhances shared utility functions for improved code reusability and collaboration.", +) + + +ANYTHING_UNCLEAR_PM = ActionNode( + key="Anything UNCLEAR", + expected_type=str, + instruction="Mention any unclear aspects in the project management context and try to clarify them.", + example="Clarification needed on how to start and initialize third-party libraries.", +) + +NODES = [ + REQUIRED_PYTHON_PACKAGES, + REQUIRED_OTHER_LANGUAGE_PACKAGES, + LOGIC_ANALYSIS, + TASK_LIST, + FULL_API_SPEC, + SHARED_KNOWLEDGE, + ANYTHING_UNCLEAR_PM, +] + +REFINED_NODES = [ + REQUIRED_PYTHON_PACKAGES, + REQUIRED_OTHER_LANGUAGE_PACKAGES, + REFINED_LOGIC_ANALYSIS, + REFINED_TASK_LIST, + FULL_API_SPEC, + REFINED_SHARED_KNOWLEDGE, + ANYTHING_UNCLEAR_PM, +] + +PM_NODE = ActionNode.from_children("PM_NODE", NODES) +REFINED_PM_NODE = ActionNode.from_children("REFINED_PM_NODE", REFINED_NODES) diff --git a/notebook_dir/metagpt_yusin/actions/rebuild_class_view.py b/notebook_dir/metagpt_yusin/actions/rebuild_class_view.py new file mode 100644 index 0000000000000000000000000000000000000000..121a120627631fefef976c6ae7d171f06ddaae46 --- /dev/null +++ b/notebook_dir/metagpt_yusin/actions/rebuild_class_view.py @@ -0,0 +1,235 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/19 +@Author : mashenquan +@File : rebuild_class_view.py +@Desc : Reconstructs class diagram from a source code project. 
class RebuildClassView(Action):
    """
    Reconstructs a graph repository about class diagram from a source code project.

    Attributes:
        graph_db (Optional[GraphRepository]): The optional graph repository.
    """

    graph_db: Optional[GraphRepository] = None

    async def run(self, with_messages=None, format=config.prompt_schema):
        """
        Implementation of `Action`'s `run` method.

        Loads the graph repository of the working git repo, fills it with class views,
        class relationships and per-file symbols parsed from `self.i_context`, writes the
        Mermaid class diagram and finally saves the graph repository.

        Args:
            with_messages (Optional[Type]): An optional argument specifying messages to react to.
            format (str): The format for the prompt schema.
        """
        workdir = self.context.git_repo.workdir
        graph_repo_pathname = workdir / GRAPH_REPO_FILE_REPO / workdir.name
        self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
        source_root = Path(self.i_context)
        repo_parser = RepoParser(base_directory=source_root)
        # use pylint
        class_views, relationship_views, package_root = await repo_parser.rebuild_class_views(path=source_root)
        await GraphRepository.update_graph_db_with_class_views(self.graph_db, class_views)
        await GraphRepository.update_graph_db_with_class_relationship_views(self.graph_db, relationship_views)
        await GraphRepository.rebuild_composition_relationship(self.graph_db)
        # use ast
        direction, diff_path = self._diff_path(path_root=source_root.resolve(), package_root=package_root)
        symbols = repo_parser.generate_symbols()
        for file_info in symbols:
            # Align to the same root directory in accordance with `class_views`.
            file_info.file = self._align_root(file_info.file, direction, diff_path)
            await GraphRepository.update_graph_db_with_file_info(self.graph_db, file_info)
        await self._create_mermaid_class_views()
        await self.graph_db.save()

    async def _create_mermaid_class_views(self) -> str:
        """Creates a Mermaid class diagram using data from the `graph_db` graph repository.

        Writes every class, then every relationship, to
        `<workdir>/<DATA_API_DESIGN_FILE_REPO>/<workdir name>.class_diagram.mmd`. When `self.i_context`
        is set, the diagram path (relative to the workdir) is also recorded in the graph repository.

        Returns:
            mermaid class diagram file name.
        """
        path = self.context.git_repo.workdir / DATA_API_DESIGN_FILE_REPO
        path.mkdir(parents=True, exist_ok=True)
        pathname = path / self.context.git_repo.workdir.name
        filename = str(pathname.with_suffix(".class_diagram.mmd"))
        async with aiofiles.open(filename, mode="w", encoding="utf-8") as writer:
            content = "classDiagram\n"
            logger.debug(content)
            await writer.write(content)
            # class names
            rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
            class_distinct = set()
            relationship_distinct = set()
            for r in rows:
                content = await self._create_mermaid_class(r.subject)
                if content:
                    await writer.write(content)
                    class_distinct.add(r.subject)
            for r in rows:
                content, distinct = await self._create_mermaid_relationship(r.subject)
                if content:
                    logger.debug(content)
                    await writer.write(content)
                    relationship_distinct.update(distinct)
            logger.info(f"classes: {len(class_distinct)}, relationship: {len(relationship_distinct)}")

        if self.i_context:
            r_filename = Path(filename).relative_to(self.context.git_repo.workdir)
            await self.graph_db.insert(
                subject=self.i_context, predicate="hasMermaidClassDiagramFile", object_=str(r_filename)
            )
            logger.info(f"{self.i_context} hasMermaidClassDiagramFile {filename}")
        return filename

    async def _create_mermaid_class(self, ns_class_name) -> str:
        """Generates a Mermaid class diagram for a specific class using data from the `graph_db` graph repository.

        Besides rendering, this stores the UML class view and the composition / aggregation
        relations of the class back into the graph repository.

        Args:
            ns_class_name (str): The namespace-prefixed name of the class for which the Mermaid class diagram is to be created.

        Returns:
            str: The Mermaid text of the class, or "" for sub-classes and classes without stored detail.
        """
        fields = split_namespace(ns_class_name)
        if len(fields) > 2:
            # Ignore sub-class
            return ""

        rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_DETAIL)
        if not rows:
            return ""
        dot_class_info = DotClassInfo.model_validate_json(rows[0].object_)
        class_view = UMLClassView.load_dot_class_info(dot_class_info)

        # update uml view
        await self.graph_db.insert(ns_class_name, GraphKeyword.HAS_CLASS_VIEW, class_view.model_dump_json())
        # update uml isCompositeOf
        for c in dot_class_info.compositions:
            await self.graph_db.insert(
                subject=ns_class_name,
                predicate=GraphKeyword.IS + COMPOSITION + GraphKeyword.OF,
                object_=concat_namespace("?", c),
            )

        # update uml isAggregateOf
        for a in dot_class_info.aggregations:
            await self.graph_db.insert(
                subject=ns_class_name,
                predicate=GraphKeyword.IS + AGGREGATION + GraphKeyword.OF,
                object_=concat_namespace("?", a),
            )

        content = class_view.get_mermaid(align=1)
        logger.debug(content)
        return content

    async def _create_mermaid_relationship(self, ns_class_name: str) -> Tuple[Optional[str], Optional[Set]]:
        """Generates a Mermaid class relationship diagram for a specific class using data from the `graph_db` graph repository.

        Args:
            ns_class_name (str): The namespace-prefixed class name for which the Mermaid relationship diagram is to be created.

        Returns:
            Tuple[str, Set]: The relationship lines as one string and the set of distinct links,
            or `(None, None)` when `ns_class_name` is a sub-class.
        """
        s_fields = split_namespace(ns_class_name)
        if len(s_fields) > 2:
            # Ignore sub-class
            return None, None

        predicates = {GraphKeyword.IS + v + GraphKeyword.OF: v for v in [GENERALIZATION, COMPOSITION, AGGREGATION]}
        mappings = {
            GENERALIZATION: " <|-- ",
            COMPOSITION: " *-- ",
            AGGREGATION: " o-- ",
        }
        content = ""
        distinct = set()
        for p, v in predicates.items():
            rows = await self.graph_db.select(subject=ns_class_name, predicate=p)
            for r in rows:
                o_fields = split_namespace(r.object_)
                if len(o_fields) > 2:
                    # Ignore sub-class
                    continue
                relationship = mappings.get(v, " .. ")
                link = f"{o_fields[1]}{relationship}{s_fields[1]}"
                distinct.add(link)
                content += f"\t{link}\n"

        return content, distinct

    @staticmethod
    def _diff_path(path_root: Path, package_root: Path) -> Tuple[str, str]:
        """Returns the difference between the root path and the path information represented in the package name.

        Args:
            path_root (Path): The root path.
            package_root (Path): The package root path.

        Returns:
            Tuple[str, str]: A tuple containing the representation of the difference ("+", "-", "=") and the path detail of the differing part.

        Raises:
            ValueError: If the two paths have different lengths but neither contains the other
                (raised by `Path.relative_to`).

        Example:
            >>> _diff_path(path_root=Path("/Users/x/github/metagpt_yusin"), package_root=Path("/Users/x/github/metagpt_yusin/metagpt_yusin"))
            "-", "metagpt_yusin"

            >>> _diff_path(path_root=Path("/Users/x/github/metagpt_yusin/metagpt_yusin"), package_root=Path("/Users/x/github/metagpt_yusin/metagpt_yusin"))
            "=", "."
        """
        # The longer path string is taken to be the deeper one.
        root_len = len(str(path_root))
        package_len = len(str(package_root))
        if root_len > package_len:
            return "+", str(path_root.relative_to(package_root))
        if root_len < package_len:
            return "-", str(package_root.relative_to(path_root))
        return "=", "."

    @staticmethod
    def _align_root(path: str, direction: str, diff_path: str) -> str:
        """Aligns the path to the same root represented by `diff_path`.

        Args:
            path (str): The path to be aligned.
            direction (str): The direction of alignment ('+', '-', '=').
            diff_path (str): The path representing the difference.

        Returns:
            str: The aligned path. For '-', a path that does not start with `diff_path/`
            is returned unchanged.

        Example:
            >>> _align_root(path="metagpt_yusin/software_company.py", direction="+", diff_path="metagpt_yusin")
            "metagpt_yusin/metagpt_yusin/software_company.py"

            >>> _align_root(path="metagpt_yusin/software_company.py", direction="-", diff_path="metagpt_yusin")
            "software_company.py"

            >>> _align_root(path="setup.py", direction="-", diff_path="metagpt_yusin")
            "setup.py"
        """
        if direction == "=":
            return path
        if direction == "+":
            return diff_path + "/" + path
        # Strip the leading `diff_path/` only when it is actually present; slicing a fixed
        # number of characters would mangle paths that live outside `diff_path`.
        return path.removeprefix(diff_path + "/")
class ReverseUseCase(BaseModel):
    """
    Represents a reverse engineered use case.

    Attributes:
        description (str): A description of the reverse use case.
        inputs (List[str]): List of inputs for the reverse use case.
        outputs (List[str]): List of outputs for the reverse use case.
        actors (List[str]): List of actors involved in the reverse use case.
        steps (List[str]): List of steps for the reverse use case.
        reason (str): The reason behind the reverse use case.
    """

    description: str
    inputs: List[str]
    outputs: List[str]
    actors: List[str]
    steps: List[str]
    reason: str


class ReverseUseCaseDetails(BaseModel):
    """
    Represents details of a reverse engineered use case.

    Attributes:
        description (str): A description of the reverse use case details.
        use_cases (List[ReverseUseCase]): List of reverse use cases.
        relationship (List[str]): List of relationships associated with the reverse use case details.
    """

    description: str
    use_cases: List[ReverseUseCase]
    relationship: List[str]


class RebuildSequenceView(Action):
    """
    Represents an action to reconstruct sequence view through reverse engineering.

    Attributes:
        graph_db (Optional[GraphRepository]): An optional instance of GraphRepository for graph database operations.
    """

    graph_db: Optional[GraphRepository] = None

    async def run(self, with_messages=None, format=config.prompt_schema):
        """
        Implementation of `Action`'s `run` method.

        Rebuilds the sequence view of every entry (either `self.i_context` or every
        `__name__:__main__` entry found in the graph repository), merges participants into it
        until none is left, then saves the graph repository.

        Args:
            with_messages (Optional[Type]): An optional argument specifying messages to react to.
            format (str): The format for the prompt schema.
        """
        workdir = self.context.git_repo.workdir
        graph_repo_pathname = workdir / GRAPH_REPO_FILE_REPO / workdir.name
        self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
        if not self.i_context:
            entries = await self._search_main_entry()
        else:
            entries = [SPO(subject=self.i_context, predicate="", object_="")]
        for entry in entries:
            await self._rebuild_main_sequence_view(entry)
            while await self._merge_sequence_view(entry):
                pass
        await self.graph_db.save()

    @retry(
        wait=wait_random_exponential(min=1, max=20),
        stop=stop_after_attempt(6),
        after=general_after_log(logger),
    )
    async def _rebuild_main_sequence_view(self, entry: SPO):
        """
        Reconstruct the sequence diagram for the __main__ entry of the source code through reverse engineering.

        Args:
            entry (SPO): The SPO (Subject, Predicate, Object) object in the graph database that is related to the
                subject `__name__:__main__`.
        """
        filename = entry.subject.split(":", 1)[0]
        rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
        classes = []
        prefix = filename + ":"
        for r in rows:
            if prefix in r.subject:
                classes.append(r)
                await self._rebuild_use_case(r.subject)
        participants = await self._search_participants(split_namespace(entry.subject)[0])
        class_views = []
        for c in classes:
            detail = await self._get_class_detail(c.subject)
            if not detail:
                continue
            view = await self._get_uml_class_view(c.subject)
            if view:
                class_views.append(view)

            actors = await self._get_participants(c.subject)
            participants.update(set(actors))

        use_case_blocks = []
        for c in classes:
            use_cases = await self._get_class_use_cases(c.subject)
            use_case_blocks.append(use_cases)
        prompt_blocks = ["## Use Cases\n" + "\n".join(use_case_blocks)]
        block = "## Participants\n"
        for p in participants:
            block += f"- {p}\n"
        prompt_blocks.append(block)
        block = "## Mermaid Class Views\n```mermaid\n"
        block += "\n\n".join([c.get_mermaid() for c in class_views])
        block += "\n```\n"
        prompt_blocks.append(block)
        block = "## Source Code\n```python\n"
        block += await self._get_source_code(filename)
        block += "\n```\n"
        prompt_blocks.append(block)
        prompt = "\n---\n".join(prompt_blocks)

        rsp = await self.llm.aask(
            msg=prompt,
            system_msgs=[
                "You are a python code to Mermaid Sequence Diagram translator in function detail.",
                "Translate the given markdown text to a Mermaid Sequence Diagram.",
                "Return the merged Mermaid sequence diagram in a markdown code block format.",
            ],
            stream=False,
        )
        sequence_view = rsp.removeprefix("```mermaid").removesuffix("```")
        # Replace any previously stored sequence view of this entry.
        rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
        for r in rows:
            if r.predicate == GraphKeyword.HAS_SEQUENCE_VIEW:
                await self.graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
        await self.graph_db.insert(
            subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
        )
        # Keep a timestamped (millisecond resolution) version of the view as well.
        await self.graph_db.insert(
            subject=entry.subject,
            predicate=GraphKeyword.HAS_SEQUENCE_VIEW_VER,
            object_=concat_namespace(datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3], add_affix(sequence_view)),
        )
        for c in classes:
            await self.graph_db.insert(
                subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(c.subject)
            )
        await self._save_sequence_view(subject=entry.subject, content=sequence_view)

    async def _merge_sequence_view(self, entry: SPO) -> bool:
        """
        Augments additional information to the provided SPO (Subject, Predicate, Object) entry in the sequence diagram.

        Args:
            entry (SPO): The SPO object representing the relationship in the graph database.

        Returns:
            bool: True if additional information has been augmented, otherwise False.
        """
        new_participant = await self._search_new_participant(entry)
        if not new_participant:
            return False

        await self._merge_participant(entry, new_participant)
        return True

    async def _search_main_entry(self) -> List:
        """
        Asynchronously searches for the SPO object that is related to `__name__:__main__`.

        Returns:
            List: The page-info rows whose subject or object contains `__name__:__main__`.
        """
        rows = await self.graph_db.select(predicate=GraphKeyword.HAS_PAGE_INFO)
        tag = "__name__:__main__"
        return [r for r in rows if tag in r.subject or tag in r.object_]

    @retry(
        wait=wait_random_exponential(min=1, max=20),
        stop=stop_after_attempt(6),
        after=general_after_log(logger),
    )
    async def _rebuild_use_case(self, ns_class_name: str):
        """
        Asynchronously reconstructs the use case for the provided namespace-prefixed class name.

        Does nothing when a use case is already stored for the class or when the class has no
        stored detail.

        Args:
            ns_class_name (str): The namespace-prefixed class name for which the use case is to be reconstructed.
        """
        rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE)
        if rows:
            return

        detail = await self._get_class_detail(ns_class_name)
        if not detail:
            return
        participants = set()
        participants.update(set(detail.compositions))
        participants.update(set(detail.aggregations))
        class_view = await self._get_uml_class_view(ns_class_name)
        source_code = await self._get_source_code(ns_class_name)

        prompt_blocks = []
        block = "## Participants\n"
        for p in participants:
            block += f"- {p}\n"
        prompt_blocks.append(block)
        block = "## Mermaid Class Views\n```mermaid\n"
        # The class view is optional: `_get_uml_class_view` returns None when none is stored.
        block += class_view.get_mermaid() if class_view else ""
        block += "\n```\n"
        prompt_blocks.append(block)
        block = "## Source Code\n```python\n"
        block += source_code
        block += "\n```\n"
        prompt_blocks.append(block)
        prompt = "\n---\n".join(prompt_blocks)

        rsp = await self.llm.aask(
            msg=prompt,
            system_msgs=[
                "You are a python code to UML 2.0 Use Case translator.",
                'The generated UML 2.0 Use Case must include the roles or entities listed in "Participants".',
                "The functional descriptions of Actors and Use Cases in the generated UML 2.0 Use Case must not "
                'conflict with the information in "Mermaid Class Views".',
                'The section under `if __name__ == "__main__":` of "Source Code" contains information about external '
                "system interactions with the internal system.",
                "Return a markdown JSON object with:\n"
                '- a "description" key to explain what the whole source code want to do;\n'
                '- a "use_cases" key list all use cases, each use case in the list should including a `description` '
                "key describes about what the use case to do, a `inputs` key lists the input names of the use case "
                "from external sources, a `outputs` key lists the output names of the use case to external sources, "
                "a `actors` key lists the participant actors of the use case, a `steps` key lists the steps about how "
                "the use case works step by step, a `reason` key explaining under what circumstances would the "
                "external system execute this use case.\n"
                '- a "relationship" key lists all the descriptions of relationship among these use cases.\n',
            ],
            stream=False,
        )

        code_blocks = parse_json_code_block(rsp)
        for block in code_blocks:
            detail = ReverseUseCaseDetails.model_validate_json(block)
            await self.graph_db.insert(
                subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE, object_=detail.model_dump_json()
            )

    @retry(
        wait=wait_random_exponential(min=1, max=20),
        stop=stop_after_attempt(6),
        after=general_after_log(logger),
    )
    async def _rebuild_sequence_view(self, ns_class_name: str):
        """
        Asynchronously reconstructs the sequence diagram for the provided namespace-prefixed class name.

        A class without any stored use case is treated as external: an empty sequence view is
        stored for it and no LLM request is made.

        Args:
            ns_class_name (str): The namespace-prefixed class name for which the sequence diagram is to be reconstructed.
        """
        await self._rebuild_use_case(ns_class_name)

        prompts_blocks = []
        use_case_markdown = await self._get_class_use_cases(ns_class_name)
        # `_get_class_use_cases` always ends its result with "\n", so test the stripped text.
        if not use_case_markdown.strip():  # external class
            await self.graph_db.insert(subject=ns_class_name, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_="")
            return
        block = f"## Use Cases\n{use_case_markdown}"
        prompts_blocks.append(block)

        participants = await self._get_participants(ns_class_name)
        block = "## Participants\n" + "\n".join([f"- {s}" for s in participants])
        prompts_blocks.append(block)

        view = await self._get_uml_class_view(ns_class_name)
        block = "## Mermaid Class Views\n```mermaid\n"
        # The class view is optional: `_get_uml_class_view` returns None when none is stored.
        block += view.get_mermaid() if view else ""
        block += "\n```\n"
        prompts_blocks.append(block)

        block = "## Source Code\n```python\n"
        block += await self._get_source_code(ns_class_name)
        block += "\n```\n"
        prompts_blocks.append(block)
        prompt = "\n---\n".join(prompts_blocks)

        rsp = await self.llm.aask(
            prompt,
            system_msgs=[
                "You are a Mermaid Sequence Diagram translator in function detail.",
                "Translate the markdown text to a Mermaid Sequence Diagram.",
                "Return a markdown mermaid code block.",
            ],
            stream=False,
        )

        sequence_view = rsp.removeprefix("```mermaid").removesuffix("```")
        await self.graph_db.insert(
            subject=ns_class_name, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
        )

    async def _get_participants(self, ns_class_name: str) -> List[str]:
        """
        Asynchronously returns the participants list of the sequence diagram for the provided namespace-prefixed SPO
        object.

        Args:
            ns_class_name (str): The namespace-prefixed class name for which to retrieve the participants list.

        Returns:
            List[str]: The distinct compositions and aggregations of the class; empty when the
            class has no stored detail.
        """
        detail = await self._get_class_detail(ns_class_name)
        if not detail:
            return []
        participants = set(detail.compositions)
        participants.update(detail.aggregations)
        return list(participants)

    async def _get_class_use_cases(self, ns_class_name: str) -> str:
        """
        Asynchronously assembles the context about the use case information of the namespace-prefixed SPO object.

        Args:
            ns_class_name (str): The namespace-prefixed class name for which to retrieve use case information.

        Returns:
            str: Markdown describing the stored use cases. The result always ends with a newline,
            so it is "\\n" (not "") when nothing is stored.
        """
        block = ""
        rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE)
        for i, r in enumerate(rows):
            detail = ReverseUseCaseDetails.model_validate_json(r.object_)
            block += f"\n### {i + 1}. {detail.description}"
            for j, use_case in enumerate(detail.use_cases):
                block += f"\n#### {i + 1}.{j + 1}. {use_case.description}\n"
                block += "\n##### Inputs\n" + "\n".join([f"- {s}" for s in use_case.inputs])
                block += "\n##### Outputs\n" + "\n".join([f"- {s}" for s in use_case.outputs])
                block += "\n##### Actors\n" + "\n".join([f"- {s}" for s in use_case.actors])
                block += "\n##### Steps\n" + "\n".join([f"- {s}" for s in use_case.steps])
            block += "\n#### Use Case Relationship\n" + "\n".join([f"- {s}" for s in detail.relationship])
        return block + "\n"

    async def _get_class_detail(self, ns_class_name: str) -> DotClassInfo | None:
        """
        Asynchronously retrieves the dot format class details of the namespace-prefixed SPO object.

        Args:
            ns_class_name (str): The namespace-prefixed class name for which to retrieve class details.

        Returns:
            Union[DotClassInfo, None]: A DotClassInfo object representing the dot format class details,
                or None if the details are not available.
        """
        rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_DETAIL)
        if not rows:
            return None
        return DotClassInfo.model_validate_json(rows[0].object_)

    async def _get_uml_class_view(self, ns_class_name: str) -> UMLClassView | None:
        """
        Asynchronously retrieves the UML 2.0 format class details of the namespace-prefixed SPO object.

        Args:
            ns_class_name (str): The namespace-prefixed class name for which to retrieve UML class details.

        Returns:
            Union[UMLClassView, None]: A UMLClassView object representing the UML 2.0 format class details,
                or None if the details are not available.
        """
        rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_VIEW)
        if not rows:
            return None
        return UMLClassView.model_validate_json(rows[0].object_)

    async def _get_source_code(self, ns_class_name: str) -> str:
        """
        Asynchronously retrieves the source code of the namespace-prefixed SPO object.

        When page info is stored for the object, only the recorded line range of its file is
        read; otherwise the whole file is read, or "" is returned if the file cannot be located.

        Args:
            ns_class_name (str): The namespace-prefixed class name for which to retrieve the source code.

        Returns:
            str: A string containing the source code of the specified namespace-prefixed class.
        """
        rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_PAGE_INFO)
        filename = split_namespace(ns_class_name=ns_class_name)[0]
        if not rows:
            src_filename = RebuildSequenceView._get_full_filename(root=self.i_context, pathname=filename)
            if not src_filename:
                return ""
            return await aread(filename=src_filename, encoding="utf-8")
        code_block_info = CodeBlockInfo.model_validate_json(rows[0].object_)
        # NOTE(review): this branch reads `filename` as given, without resolving it against
        # `self.i_context` like the branch above -- confirm that this is intended.
        return await read_file_block(
            filename=filename, lineno=code_block_info.lineno, end_lineno=code_block_info.end_lineno
        )

    @staticmethod
    def _get_full_filename(root: str | Path, pathname: str | Path) -> Path | None:
        """
        Convert package name to the full path of the module.

        Args:
            root (Union[str, Path]): The root path or string representing the package.
            pathname (Union[str, Path]): The pathname or string representing the module.

        Returns:
            Union[Path, None]: The full path of the module, or None if the path cannot be determined.
                A `pathname` that already starts with "/" is returned as given.

        Examples:
            If `root`(workdir) is "/User/xxx/github/metagpt_yusin/metagpt_yusin", and the `pathname` is
            "metagpt_yusin/management/skill_manager.py", then the returned value will be
            "/User/xxx/github/metagpt_yusin/metagpt_yusin/management/skill_manager.py"
        """
        # `re.match` only accepts str/bytes, while `pathname` may be a `Path`.
        pathname_text = str(pathname)
        if re.match(r"^/.+", pathname_text):
            return pathname
        postfix = "/" + pathname_text
        for candidate in list_files(root=root):
            if str(candidate).endswith(postfix):
                return candidate
        return None

    @staticmethod
    def parse_participant(mermaid_sequence_diagram: str) -> List[str]:
        """
        Parses the provided Mermaid sequence diagram and returns the list of participants.

        Args:
            mermaid_sequence_diagram (str): The Mermaid sequence diagram string to be parsed.

        Returns:
            List[str]: A list of participants extracted from the sequence diagram.
        """
        # Names consist of word characters and dots only, so no quote or slash clean-up is needed.
        pattern = r"participant ([\w\.]+)"
        return re.findall(pattern, mermaid_sequence_diagram)

    async def _search_new_participant(self, entry: SPO) -> str | None:
        """
        Asynchronously retrieves a participant whose sequence diagram has not been augmented.

        Args:
            entry (SPO): The SPO object representing the relationship in the graph database.

        Returns:
            Union[str, None]: A participant whose sequence diagram has not been augmented, or None if not found.
        """
        rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
        if not rows:
            return None
        sequence_view = rows[0].object_
        rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT)
        merged_participants = {split_namespace(r.object_)[-1] for r in rows}
        for p in self.parse_participant(sequence_view):
            if p not in merged_participants:
                return p
        return None

    @retry(
        wait=wait_random_exponential(min=1, max=20),
        stop=stop_after_attempt(6),
        after=general_after_log(logger),
    )
    async def _merge_participant(self, entry: SPO, class_name: str):
        """
        Augments the sequence diagram of `class_name` to the sequence diagram of `entry`.

        Args:
            entry (SPO): The SPO object representing the base sequence diagram.
            class_name (str): The class name whose sequence diagram is to be augmented.
        """
        rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
        participants = [r for r in rows if split_namespace(r.subject)[-1] == class_name]
        if len(participants) == 0:  # external participants
            await self.graph_db.insert(
                subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=concat_namespace("?", class_name)
            )
            return
        if len(participants) > 1:
            # Several classes share the name: record them all as participants without merging.
            for r in participants:
                await self.graph_db.insert(
                    subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(r.subject)
                )
            return

        participant = participants[0]
        await self._rebuild_sequence_view(participant.subject)
        sequence_views = await self.graph_db.select(
            subject=participant.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW
        )
        if not sequence_views:  # external class
            return
        rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
        prompt = f"```mermaid\n{sequence_views[0].object_}\n```\n---\n```mermaid\n{rows[0].object_}\n```"

        rsp = await self.llm.aask(
            prompt,
            system_msgs=[
                "You are a tool to merge sequence diagrams into one.",
                "Participants with the same name are considered identical.",
                "Return the merged Mermaid sequence diagram in a markdown code block format.",
            ],
            stream=False,
        )

        sequence_view = rsp.removeprefix("```mermaid").removesuffix("```")
        # Replace the stored sequence view of the entry with the merged one.
        rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
        for r in rows:
            await self.graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
        await self.graph_db.insert(
            subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
        )
        # Keep a timestamped (millisecond resolution) version of the view as well.
        await self.graph_db.insert(
            subject=entry.subject,
            predicate=GraphKeyword.HAS_SEQUENCE_VIEW_VER,
            object_=concat_namespace(datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3], add_affix(sequence_view)),
        )
        await self.graph_db.insert(
            subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(participant.subject)
        )
        await self._save_sequence_view(subject=entry.subject, content=sequence_view)

    async def _save_sequence_view(self, subject: str, content: str):
        """
        Saves a sequence diagram to the data-API-design resources of the project repo.

        The file name is `subject` with every non-alphanumeric character replaced by "_",
        given the suffix ".sequence_diagram.mmd".

        Args:
            subject (str): The graph subject the diagram belongs to.
            content (str): The Mermaid sequence diagram text.
        """
        pattern = re.compile(r"[^a-zA-Z0-9]")
        name = re.sub(pattern, "_", subject)
        filename = Path(name).with_suffix(".sequence_diagram.mmd")
        await self.context.repo.resources.data_api_design.save(filename=str(filename), content=content)

    async def _search_participants(self, filename: str) -> Set:
        """
        Asks the LLM to list the class names used in a source file.

        Args:
            filename (str): The name passed to `_get_source_code` to load the source text.

        Returns:
            Set: The class names reported in the first JSON code block of the LLM response.
        """
        content = await self._get_source_code(filename)

        rsp = await self.llm.aask(
            msg=content,
            system_msgs=[
                "You are a tool for listing all class names used in a source file.",
                "Return a markdown JSON object with: "
                '- a "class_names" key containing the list of class names used in the file; '
                '- a "reasons" key lists all reason objects, each object containing a "class_name" key for class name, a "reference" key explaining the line where the class has been used.',
            ],
        )

        class _Data(BaseModel):
            class_names: List[str]
            reasons: List

        json_blocks = parse_json_code_block(rsp)
        data = _Data.model_validate_json(json_blocks[0])
        return set(data.class_names)
Please respond in the following JSON format: ["query1", "query2", "query3", ...]. + +### Search Result Information +{search_results} +""" + +COLLECT_AND_RANKURLS_PROMPT = """### Topic +{topic} +### Query +{query} + +### The online search results +{results} + +### Requirements +Please remove irrelevant search results that are not related to the query or topic. Then, sort the remaining search results \ +based on the link credibility. If two results have equal credibility, prioritize them based on the relevance. Provide the +ranked results' indices in JSON format, like [0, 1, 3, 4, ...], without including other words. +""" + +WEB_BROWSE_AND_SUMMARIZE_PROMPT = """### Requirements +1. Utilize the text in the "Reference Information" section to respond to the question "{query}". +2. If the question cannot be directly answered using the text, but the text is related to the research topic, please provide \ +a comprehensive summary of the text. +3. If the text is entirely unrelated to the research topic, please reply with a simple text "Not relevant." +4. Include all relevant factual information, numbers, statistics, etc., if available. + +### Reference Information +{content} +""" + + +CONDUCT_RESEARCH_PROMPT = """### Reference Information +{content} + +### Requirements +Please provide a detailed research report in response to the following topic: "{topic}", using the information provided \ +above. The report must meet the following requirements: + +- Focus on directly addressing the chosen topic. +- Ensure a well-structured and in-depth presentation, incorporating relevant facts and figures where available. +- Present data and findings in an intuitive manner, utilizing feature comparative tables, if applicable. +- The report should have a minimum word count of 2,000 and be formatted with Markdown syntax following APA style guidelines. +- Include all source URLs in APA format at the end of the report. 
class CollectLinks(Action):
    """Action class to collect links from a search engine.

    ``run`` asks the LLM for search keywords, searches them, asks the LLM to
    derive follow-up queries from the search results, and finally returns the
    ranked URLs for every query.
    """

    name: str = "CollectLinks"
    i_context: Optional[str] = None
    desc: str = "Collect links from a search engine."
    search_func: Optional[Any] = None
    search_engine: Optional[SearchEngine] = None
    # Optional post-processing hook: its return value replaces the ranked results.
    rank_func: Optional[Callable[[list[str]], None]] = None

    @model_validator(mode="after")
    def validate_engine_and_run_func(self):
        """Create a search engine from the action config when none was supplied."""
        if self.search_engine is None:
            self.search_engine = SearchEngine.from_search_config(self.config.search, proxy=self.config.proxy)
        return self

    async def run(
        self,
        topic: str,
        decomposition_nums: int = 4,
        url_per_query: int = 4,
        system_text: str | None = None,
    ) -> dict[str, list[str]]:
        """Run the action to collect links.

        Args:
            topic: The research topic.
            decomposition_nums: The number of search questions to generate.
            url_per_query: The number of URLs to collect per search question.
            system_text: The system text.

        Returns:
            A dictionary containing the search questions as keys and the collected URLs as values.
        """
        system_text = system_text if system_text else RESEARCH_TOPIC_SYSTEM.format(topic=topic)
        keywords = await self._aask(SEARCH_TOPIC_PROMPT, [system_text])
        try:
            keywords = OutputParser.extract_struct(keywords, list)
            keywords = TypeAdapter(list[str]).validate_python(keywords)
        except Exception as e:
            # Best effort: fall back to searching the topic itself.
            logger.exception(f"fail to get keywords related to the research topic '{topic}' for {e}")
            keywords = [topic]
        results = await asyncio.gather(*(self.search_engine.run(i, as_string=False) for i in keywords))

        def gen_msg():
            """Yield progressively shorter prompts by dropping one search hit per round."""
            while True:
                search_results = "\n".join(
                    f"#### Keyword: {i}\n Search Result: {j}\n" for (i, j) in zip(keywords, results)
                )
                prompt = SUMMARIZE_SEARCH_PROMPT.format(
                    decomposition_nums=decomposition_nums, search_results=search_results
                )
                yield prompt
                # Shrink the longest result list; stop once nothing is left to drop
                # (this also covers an empty keyword list or empty search results).
                longest = max(results, key=len, default=None)
                if not longest:
                    break
                longest.pop()
                if not longest:
                    break

        model_name = config.llm.model
        prompt = reduce_message_length(gen_msg(), model_name, system_text, config.llm.max_token)
        logger.debug(prompt)
        queries = await self._aask(prompt, [system_text])
        try:
            queries = OutputParser.extract_struct(queries, list)
            queries = TypeAdapter(list[str]).validate_python(queries)
        except Exception as e:
            # Best effort: reuse the keywords as queries.
            logger.exception(f"fail to break down the research question due to {e}")
            queries = keywords
        ret = {}
        for query in queries:
            ret[query] = await self._search_and_rank_urls(topic, query, url_per_query)
        return ret

    async def _search_and_rank_urls(self, topic: str, query: str, num_results: int = 4) -> list[str]:
        """Search and rank URLs based on a query.

        Args:
            topic: The research topic.
            query: The search query.
            num_results: The number of URLs to collect.

        Returns:
            A list of ranked URLs (at most ``num_results``); empty when the search returns nothing.
        """
        max_results = max(num_results * 2, 6)
        results = await self.search_engine.run(query, max_results=max_results, as_string=False)
        if not results:
            # Nothing to rank, so skip the LLM round trip.
            return []
        _results = "\n".join(f"{i}: {j}" for i, j in zip(range(max_results), results))
        prompt = COLLECT_AND_RANKURLS_PROMPT.format(topic=topic, query=query, results=_results)
        logger.debug(prompt)
        indices = await self._aask(prompt)
        try:
            indices = OutputParser.extract_struct(indices, list)
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if not all(isinstance(i, int) for i in indices):
                raise ValueError(f"non-integer index in {indices}")
        except Exception as e:
            logger.exception(f"fail to rank results for {e}")
            # The engine may return fewer than `max_results` hits, so fall back to
            # the indices that actually exist.
            indices = list(range(len(results)))
        # Drop indices the LLM invented (out of range or negative) instead of raising IndexError.
        results = [results[i] for i in indices if 0 <= i < len(results)]
        if self.rank_func:
            results = self.rank_func(results)
        return [i["link"] for i in results[:num_results]]
class ConductResearch(Action):
    """Action class to conduct research and generate a research report."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    async def run(
        self,
        topic: str,
        content: str,
        system_text: str = RESEARCH_BASE_SYSTEM,
    ) -> str:
        """Generate a research report on a topic from collected reference material.

        Args:
            topic: The research topic.
            content: The reference information the report is based on.
            system_text: The system text sent along with the prompt.

        Returns:
            The report text produced by the LLM.
        """
        report_prompt = CONDUCT_RESEARCH_PROMPT.format(content=content, topic=topic)
        logger.debug(report_prompt)
        # Set on the LLM before asking; the flag stays set afterwards.
        self.llm.auto_max_tokens = True
        report = await self._aask(report_prompt, [system_text])
        return report
+ """ + return " ".join((RESEARCH_TOPIC_SYSTEM.format(topic=topic), LANG_PROMPT.format(language=language))) diff --git a/notebook_dir/metagpt_yusin/actions/run_code.py b/notebook_dir/metagpt_yusin/actions/run_code.py new file mode 100644 index 0000000000000000000000000000000000000000..89d7307bb59869e7e5aeb363ede9662f7643edee --- /dev/null +++ b/notebook_dir/metagpt_yusin/actions/run_code.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/11 17:46 +@Author : alexanderwu +@File : run_code.py +@Modified By: mashenquan, 2023/11/27. + 1. Mark the location of Console logs in the PROMPT_TEMPLATE with markdown code-block formatting to enhance + the understanding for the LLM. + 2. Fix bug: Add the "install dependency" operation. + 3. Encapsulate the input of RunCode into RunCodeContext and encapsulate the output of RunCode into + RunCodeResult to standardize and unify parameter passing between WriteCode, RunCode, and DebugError. + 4. According to section 2.2.3.5.7 of RFC 135, change the method of transferring file content + (code files, unit test files, log files) from using the message to using the file name. + 5. Merged the `Config` class of send18:dev branch to take over the set/get operations of the Environment + class. +""" +import subprocess +from pathlib import Path +from typing import Tuple + +from pydantic import Field + +from metagpt_yusin.actions.action import Action +from metagpt_yusin.logs import logger +from metagpt_yusin.schema import RunCodeContext, RunCodeResult +from metagpt_yusin.utils.exceptions import handle_exception + +PROMPT_TEMPLATE = """ +Role: You are a senior development and qa engineer, your role is summarize the code running result. +If the running result does not include an error, you should explicitly approve the result. 
class RunCode(Action):
    """Run code (as a script or as inline text) and let the LLM summarize the outcome."""

    name: str = "RunCode"
    i_context: RunCodeContext = Field(default_factory=RunCodeContext)

    @classmethod
    async def run_text(cls, code) -> Tuple[str, str]:
        """Execute ``code`` in-process and return ``(result, error)``.

        Args:
            code: Python source text to execute.

        Returns:
            The value bound to the name ``result`` by the executed code (``""`` when
            unset) and an empty error string, or ``("", str(exception))`` on failure.
        """
        try:
            # We will store the result in this dictionary
            namespace = {}
            # SECURITY NOTE(review): `exec` runs arbitrary code inside this process;
            # only pass code from a trusted source.
            exec(code, namespace)
        except Exception as e:
            return "", str(e)
        return namespace.get("result", ""), ""

    async def run_script(self, working_directory, additional_python_paths=None, command=None) -> Tuple[str, str]:
        """Run ``command`` in a subprocess and capture its output.

        Args:
            working_directory: Directory the command runs in; also placed first on PYTHONPATH.
            additional_python_paths: Extra entries appended to PYTHONPATH.
            command: The argument list handed to ``subprocess.Popen``.

        Returns:
            The decoded ``(stdout, stderr)`` of the process.
        """
        import os  # local import: only needed for the platform path separator

        working_directory = str(working_directory)
        command = [] if command is None else command
        search_paths = [working_directory]
        search_paths.extend(str(path) for path in (additional_python_paths or []))

        # Copy the current environment variables
        env = self.context.new_environ()

        # Modify the PYTHONPATH environment variable. Use the platform separator and
        # only append the inherited value when it is non-empty, so no empty entry is produced.
        inherited_python_path = env.get("PYTHONPATH", "")
        if inherited_python_path:
            search_paths.append(inherited_python_path)
        env["PYTHONPATH"] = os.pathsep.join(search_paths)
        RunCode._install_dependencies(working_directory=working_directory, env=env)

        # Start the subprocess
        process = subprocess.Popen(
            command, cwd=working_directory, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
        )
        logger.info(" ".join(command))

        try:
            # Wait for the process to complete, with a timeout
            stdout, stderr = process.communicate(timeout=10)
        except subprocess.TimeoutExpired:
            logger.info("The command did not complete within the given timeout.")
            process.kill()  # Kill the process if it times out
            stdout, stderr = process.communicate()
        # Child output is not guaranteed to be valid UTF-8; never fail on decoding.
        return stdout.decode("utf-8", errors="replace"), stderr.decode("utf-8", errors="replace")

    async def run(self, *args, **kwargs) -> RunCodeResult:
        """Run the code described by ``i_context`` and summarize the result with the LLM.

        Returns:
            A ``RunCodeResult`` holding the LLM summary plus the raw stdout and stderr.

        Raises:
            ValueError: If ``i_context.mode`` is neither ``"script"`` nor ``"text"``.
        """
        logger.info(f"Running {' '.join(self.i_context.command)}")
        if self.i_context.mode == "script":
            outs, errs = await self.run_script(
                command=self.i_context.command,
                working_directory=self.i_context.working_directory,
                additional_python_paths=self.i_context.additional_python_paths,
            )
        elif self.i_context.mode == "text":
            outs, errs = await self.run_text(code=self.i_context.code)
        else:
            # Previously fell through and failed later with an unbound local variable.
            raise ValueError(f"Unsupported run mode: {self.i_context.mode!r}")

        logger.info(f"{outs=}")
        logger.info(f"{errs=}")

        context = TEMPLATE_CONTEXT.format(
            code=self.i_context.code,
            code_file_name=self.i_context.code_filename,
            test_code=self.i_context.test_code,
            test_file_name=self.i_context.test_filename,
            command=" ".join(self.i_context.command),
            outs=outs[:500],  # outs might be long but they are not important, truncate them to avoid token overflow
            errs=errs[:10000],  # truncate errors to avoid token overflow
        )

        prompt = PROMPT_TEMPLATE.format(context=context)
        rsp = await self._aask(prompt)
        return RunCodeResult(summary=rsp, stdout=outs, stderr=errs)

    @staticmethod
    @handle_exception(exception_type=subprocess.CalledProcessError)
    def _install_via_subprocess(cmd, check, cwd, env):
        """Run an install command; ``CalledProcessError`` is routed through ``handle_exception``."""
        return subprocess.run(cmd, check=check, cwd=cwd, env=env)

    @staticmethod
    def _install_requirements(working_directory, env):
        """Install ``requirements.txt`` from the working directory when it exists and is non-empty."""
        file_path = Path(working_directory) / "requirements.txt"
        if not file_path.exists():
            return
        if file_path.stat().st_size == 0:
            return
        install_command = ["python", "-m", "pip", "install", "-r", "requirements.txt"]
        logger.info(" ".join(install_command))
        RunCode._install_via_subprocess(install_command, check=True, cwd=working_directory, env=env)

    @staticmethod
    def _install_pytest(working_directory, env):
        """Install pytest with pip."""
        install_pytest_command = ["python", "-m", "pip", "install", "pytest"]
        logger.info(" ".join(install_pytest_command))
        RunCode._install_via_subprocess(install_pytest_command, check=True, cwd=working_directory, env=env)

    @staticmethod
    def _install_dependencies(working_directory, env):
        """Install the project requirements first, then pytest."""
        RunCode._install_requirements(working_directory, env)
        RunCode._install_pytest(working_directory, env)
Do not include text that is irrelevant to the conversation. +- The context is for reference only. If it is irrelevant to the user's search request history, please reduce its reference and usage. +2. If there are citable links in the context, annotate them in the main text in the format [main text](citation link). If there are none in the context, do not write links. +3. The reply should be graceful, clear, non-repetitive, smoothly written, and of moderate length, in {LANG}. + +### Dialogue History (For example) +A: MLOps competitors + +### Current Question (For example) +A: MLOps competitors + +### Current Reply (For example) +1. Alteryx Designer: etc. if any +2. Matlab: ditto +3. IBM SPSS Statistics +4. RapidMiner Studio +5. DataRobot AI Platform +6. Databricks Lakehouse Platform +7. Amazon SageMaker +8. Dataiku +""" + +SEARCH_AND_SUMMARIZE_SYSTEM_EN_US = SEARCH_AND_SUMMARIZE_SYSTEM.format(LANG="en-us") + +SEARCH_AND_SUMMARIZE_PROMPT = """ +### Reference Information +{CONTEXT} + +### Dialogue History +{QUERY_HISTORY} +{QUERY} + +### Current Question +{QUERY} + +### Current Reply: Based on the information, please write the reply to the Question + + +""" + +SEARCH_AND_SUMMARIZE_SALES_SYSTEM = """## Requirements +1. Please summarize the latest dialogue based on the reference information (secondary) and dialogue history (primary). Do not include text that is irrelevant to the conversation. +- The context is for reference only. If it is irrelevant to the user's search request history, please reduce its reference and usage. +2. If there are citable links in the context, annotate them in the main text in the format [main text](citation link). If there are none in the context, do not write links. +3. The reply should be graceful, clear, non-repetitive, smoothly written, and of moderate length, in Simplified Chinese. + +# Example +## Reference Information +... + +## Dialogue History +user: Which facial cleanser is good for oily skin? 
+Salesperson: Hello, for oily skin, it is suggested to choose a product that can deeply cleanse, control oil, and is gentle and skin-friendly. According to customer feedback and market reputation, the following facial cleansers are recommended:... +user: Do you have any by L'Oreal? +> Salesperson: ... + +## Ideal Answer +Yes, I've selected the following for you: +1. L'Oreal Men's Facial Cleanser: Oil control, anti-acne, balance of water and oil, pore purification, effectively against blackheads, deep exfoliation, refuse oil shine. Dense foam, not tight after washing. +2. L'Oreal Age Perfect Hydrating Cleanser: Added with sodium cocoyl glycinate and Centella Asiatica, two effective ingredients, it can deeply cleanse, tighten the skin, gentle and not tight. +""" + +SEARCH_AND_SUMMARIZE_SALES_PROMPT = """ +## Reference Information +{CONTEXT} + +## Dialogue History +{QUERY_HISTORY} +{QUERY} +> {ROLE}: + +""" + +SEARCH_FOOD = """ +# User Search Request +What are some delicious foods in Xiamen? + +# Requirements +You are a member of a professional butler team and will provide helpful suggestions: +1. Please summarize the user's search request based on the context and avoid including unrelated text. +2. Use [main text](reference link) in markdown format to **naturally annotate** 3-5 textual elements (such as product words or similar text sections) within the main text for easy navigation. +3. The response should be elegant, clear, **without any repetition of text**, smoothly written, and of moderate length. 
class SearchAndSummarize(Action):
    """Search for the latest message of a dialogue and let the LLM summarize the findings."""

    name: str = ""
    content: Optional[str] = None
    # None when no search engine could be configured; `run` then returns "".
    search_engine: Optional[SearchEngine] = None
    # Raw search response of the most recent `run` call.
    result: str = ""

    @model_validator(mode="after")
    def validate_search_engine(self):
        """Build a search engine from the action config when none was supplied."""
        if self.search_engine is None:
            try:
                config = self.config
                search_engine = SearchEngine.from_search_config(config.search, proxy=config.proxy)
            except pydantic.ValidationError:
                # Invalid search config: leave the engine unset.
                search_engine = None

            self.search_engine = search_engine
        return self

    async def run(self, context: list[Message], system_text=SEARCH_AND_SUMMARIZE_SYSTEM) -> str:
        """Search for the last message in ``context`` and summarize the result.

        Args:
            context: Dialogue history; the last entry is the current question.
            system_text: System prompt for the summarization request.

        Returns:
            The LLM summary, or ``""`` when no search engine is configured, the
            context is empty, or the search returned nothing.
        """
        if self.search_engine is None:
            logger.warning("Configure one of SERPAPI_API_KEY, SERPER_API_KEY, GOOGLE_API_KEY to unlock full feature")
            return ""
        if not context:
            logger.warning("empty context, nothing to search")
            return ""

        # The default is the raw template, which still carries the `{LANG}`
        # placeholder; use the rendered en-us variant so it is not sent verbatim.
        if system_text is SEARCH_AND_SUMMARIZE_SYSTEM:
            system_text = SEARCH_AND_SUMMARIZE_SYSTEM_EN_US

        query = context[-1].content
        rsp = await self.search_engine.run(query)
        self.result = rsp
        if not rsp:
            logger.error("empty rsp...")
            return ""

        system_prompt = [system_text]

        prompt = SEARCH_AND_SUMMARIZE_PROMPT.format(
            ROLE=self.prefix,
            CONTEXT=rsp,
            QUERY_HISTORY="\n".join([str(i) for i in context[:-1]]),
            QUERY=str(context[-1]),
        )
        # Log before asking so the prompt is available even if the LLM call fails.
        logger.debug(prompt)
        result = await self._aask(prompt, system_prompt)
        logger.debug(result)
        return result
class ArgumentsParingAction(Action):
    """Ask the LLM to turn a spoken request into keyword arguments for a skill call."""

    skill: Skill
    ask: str
    rsp: Optional[Message] = None
    args: Optional[Dict] = None

    @property
    def prompt(self):
        """Build the prompt from the skill's parameter descriptions and examples."""
        parts = [f"{self.skill.name} function parameters description:\n"]
        parts.extend(f"parameter `{k}`: {v}\n" for k, v in self.skill.arguments.items())
        parts.append("\n---\n")
        parts.append("Examples:\n")
        parts.extend(
            f"If want you to do `{e.ask}`, return `{e.answer}` brief and clear.\n" for e in self.skill.examples
        )
        parts.append("\n---\n")
        parts.append(
            f"\nRefer to the `{self.skill.name}` function description, and fill in the function parameters according "
            'to the example "I want you to do xx" in the Examples section.'
            f"\nNow I want you to do `{self.ask}`, return function parameters in Examples format above, brief and "
            "clear."
        )
        return "".join(parts)

    async def run(self, with_message=None, **kwargs) -> Message:
        """Query the LLM and store the parsed arguments in ``self.args``.

        Returns:
            The assistant message carrying the raw reply; ``instruct_content`` holds the
            parsed arguments (``None`` when the reply could not be parsed).
        """
        prompt = self.prompt
        rsp = await self.llm.aask(
            msg=prompt,
            system_msgs=["You are a function parser.", "You can convert spoken words into function parameters."],
            stream=False,
        )
        logger.debug(f"SKILL:{prompt}\n, RESULT:{rsp}")
        self.args = ArgumentsParingAction.parse_arguments(skill_name=self.skill.name, txt=rsp)
        self.rsp = Message(content=rsp, role="assistant", instruct_content=self.args, cause_by=self)
        return self.rsp

    @staticmethod
    def parse_arguments(skill_name, txt) -> Optional[dict]:
        """Extract the keyword arguments of a ``skill_name(...)`` call found in ``txt``.

        Args:
            skill_name: Name of the called function.
            txt: Text (typically an LLM reply) containing the call.

        Returns:
            A dict mapping keyword names to their literal values, or ``None`` when no
            well-formed call with literal keyword arguments is found.
        """
        prefix = skill_name + "("
        begin_ix = txt.find(prefix)
        if begin_ix < 0:
            logger.error(f"{skill_name} not in {txt}")
            return None
        args_begin = begin_ix + len(prefix)
        end_ix = txt.rfind(")")
        # The closing parenthesis must come after the opening one; a ')' that only
        # appears earlier in the text would otherwise yield an empty slice.
        if end_ix < args_begin:
            logger.error(f"')' not found after `{prefix}` in {txt}")
            return None
        args_txt = txt[args_begin:end_ix]
        logger.info(args_txt)
        fake_expression = f"dict({args_txt})"
        try:
            parsed_expression = ast.parse(fake_expression, mode="eval")
            call = parsed_expression.body
            if not isinstance(call, ast.Call):
                raise ValueError("arguments do not form a single call")
            args = {}
            for keyword in call.keywords:
                # literal_eval only accepts literals, so no code from the text is executed.
                args[keyword.arg] = ast.literal_eval(keyword.value)
        except (SyntaxError, ValueError, TypeError) as e:
            # Malformed LLM output is reported like the other failures above.
            logger.error(f"fail to parse arguments `{args_txt}`: {e}")
            return None
        return args
+""" +from pathlib import Path + +from pydantic import Field +from tenacity import retry, stop_after_attempt, wait_random_exponential + +from metagpt_yusin.actions.action import Action +from metagpt_yusin.logs import logger +from metagpt_yusin.schema import CodeSummarizeContext + +PROMPT_TEMPLATE = """ +NOTICE +Role: You are a professional software engineer, and your main task is to review the code. +Language: Please use the same language as the user requirement, but the title and code should be still in English. For example, if the user speaks Chinese, the specific text of your answer should also be in Chinese. +ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenced "Format example". + +----- +# System Design +```text +{system_design} +``` +----- +# Task +```text +{task} +``` +----- +{code_blocks} + +## Code Review All: Please read all historical files and find possible bugs in the files, such as unimplemented functions, calling errors, unreferences, etc. + +## Call flow: mermaid code, based on the implemented function, use mermaid to draw a complete call chain + +## Summary: Summary based on the implementation of historical files + +## TODOs: Python dict[str, str], write down the list of files that need to be modified and the reasons. We will modify them later. + +""" + +FORMAT_EXAMPLE = """ + +## Code Review All + +### a.py +- It fulfills less of xxx requirements... +- Field yyy is not given... +-... + +### b.py +... + +### c.py +... + +## Call flow +```mermaid +flowchart TB + c1-->a2 + subgraph one + a1-->a2 + end + subgraph two + b1-->b2 + end + subgraph three + c1-->c2 + end +``` + +## Summary +- a.py:... +- b.py:... +- c.py:... +- ... 
class SummarizeCode(Action):
    """Review the code files of a task against its design and task documents."""

    name: str = "SummarizeCode"
    i_context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext)

    @retry(stop=stop_after_attempt(2), wait=wait_random_exponential(min=1, max=60))
    async def summarize_code(self, prompt):
        """Send the prompt to the LLM; retried according to the decorator."""
        return await self._aask(prompt)

    async def run(self):
        """Assemble the review prompt from the design, task and source files and return the LLM reply."""
        design_name = Path(self.i_context.design_filename).name
        design_doc = await self.repo.docs.system_design.get(filename=design_name)
        task_name = Path(self.i_context.task_filename).name
        task_doc = await self.repo.docs.task.get(filename=task_name)

        sources = self.repo.with_src_path(self.context.src_workspace).srcs
        fenced_sources = []
        for code_filename in self.i_context.codes_filenames:
            code_doc = await sources.get(code_filename)
            fenced_sources.append(f"```python\n{code_doc.content}\n```\n-----")

        # NOTE(review): PROMPT_TEMPLATE has no `{format_example}` placeholder, so this
        # keyword is ignored by `str.format` and FORMAT_EXAMPLE never reaches the prompt.
        prompt = PROMPT_TEMPLATE.format(
            system_design=design_doc.content,
            task=task_doc.content,
            code_blocks="\n".join(fenced_sources),
            format_example=FORMAT_EXAMPLE,
        )
        logger.info("Summarize code..")
        return await self.summarize_code(prompt)
+ knowledge: str = "" + rsp: Optional[Message] = None + + @property + def agent_description(self): + return self.context.kwargs.agent_description + + @property + def language(self): + return self.context.kwargs.language or config.language + + @property + def prompt(self): + prompt = "" + if self.agent_description: + prompt = ( + f"You are {self.agent_description}. Your responses should align with the role-play agreement, " + f"maintaining the character's persona and habits. When faced with unrelated questions, playfully " + f"decline to answer without revealing your AI nature to preserve the character's image.\n\n" + ) + prompt += f"Knowledge:\n{self.knowledge}\n\n" if self.knowledge else "" + prompt += f"{self.history_summary}\n\n" + prompt += ( + "If the information is insufficient, you can search in the historical conversation or knowledge above.\n" + ) + language = self.language + prompt += ( + f"Answer the following questions strictly in {language}, and the answers must follow the Markdown format.\n " + f"{self.i_context}" + ) + logger.debug(f"PROMPT: {prompt}") + return prompt + + @property + def prompt_gpt4(self): + kvs = { + "{role}": self.agent_description or "", + "{history}": self.history_summary or "", + "{knowledge}": self.knowledge or "", + "{language}": self.language, + "{ask}": self.i_context, + } + prompt = TalkActionPrompt.FORMATION_LOOSE + for k, v in kvs.items(): + prompt = prompt.replace(k, v) + logger.info(f"PROMPT: {prompt}") + return prompt + + # async def run_old(self, *args, **kwargs) -> ActionOutput: + # prompt = self.prompt + # rsp = await self.llm.aask(msg=prompt, system_msgs=[]) + # logger.debug(f"PROMPT:{prompt}\nRESULT:{rsp}\n") + # self._rsp = ActionOutput(content=rsp) + # return self._rsp + + @property + def aask_args(self): + language = self.language + system_msgs = [ + f"You are {self.agent_description}.", + "Your responses should align with the role-play agreement, " + "maintaining the character's persona and habits. 
When faced with unrelated questions, playfully " + "decline to answer without revealing your AI nature to preserve the character's image.", + "If the information is insufficient, you can search in the context or knowledge.", + f"Answer the following questions strictly in {language}, and the answers must follow the Markdown format.", + ] + format_msgs = [] + if self.knowledge: + format_msgs.append({"role": "assistant", "content": self.knowledge}) + if self.history_summary: + format_msgs.append({"role": "assistant", "content": self.history_summary}) + return self.i_context, format_msgs, system_msgs + + async def run(self, with_message=None, **kwargs) -> Message: + msg, format_msgs, system_msgs = self.aask_args + rsp = await self.llm.aask(msg=msg, format_msgs=format_msgs, system_msgs=system_msgs, stream=False) + self.rsp = Message(content=rsp, role="assistant", cause_by=self) + return self.rsp + + +class TalkActionPrompt: + FORMATION = """Formation: "Capacity and role" defines the role you are currently playing; + "[HISTORY_BEGIN]" and "[HISTORY_END]" tags enclose the historical conversation; + "[KNOWLEDGE_BEGIN]" and "[KNOWLEDGE_END]" tags enclose the knowledge may help for your responses; + "Statement" defines the work detail you need to complete at this stage; + "[ASK_BEGIN]" and [ASK_END] tags enclose the questions; + "Constraint" defines the conditions that your responses must comply with. + "Personality" defines your language style。 + "Insight" provides a deeper understanding of the characters' inner traits. + "Initial" defines the initial setup of a character. + +Capacity and role: {role} +Statement: Your responses should align with the role-play agreement, maintaining the + character's persona and habits. When faced with unrelated questions, playfully decline to answer without revealing + your AI nature to preserve the character's image. 
+ +[HISTORY_BEGIN] + +{history} + +[HISTORY_END] + +[KNOWLEDGE_BEGIN] + +{knowledge} + +[KNOWLEDGE_END] + +Statement: If the information is insufficient, you can search in the historical conversation or knowledge. +Statement: Unless you are a language professional, answer the following questions strictly in {language} +, and the answers must follow the Markdown format. Strictly excluding any tag likes "[HISTORY_BEGIN]" +, "[HISTORY_END]", "[KNOWLEDGE_BEGIN]", "[KNOWLEDGE_END]" in responses. + + +{ask} +""" + + FORMATION_LOOSE = """Formation: "Capacity and role" defines the role you are currently playing; + "[HISTORY_BEGIN]" and "[HISTORY_END]" tags enclose the historical conversation; + "[KNOWLEDGE_BEGIN]" and "[KNOWLEDGE_END]" tags enclose the knowledge may help for your responses; + "Statement" defines the work detail you need to complete at this stage; + "Constraint" defines the conditions that your responses must comply with. + "Personality" defines your language style。 + "Insight" provides a deeper understanding of the characters' inner traits. + "Initial" defines the initial setup of a character. + +Capacity and role: {role} +Statement: Your responses should maintaining the character's persona and habits. When faced with unrelated questions +, playfully decline to answer without revealing your AI nature to preserve the character's image. + +[HISTORY_BEGIN] + +{history} + +[HISTORY_END] + +[KNOWLEDGE_BEGIN] + +{knowledge} + +[KNOWLEDGE_END] + +Statement: If the information is insufficient, you can search in the historical conversation or knowledge. +Statement: Unless you are a language professional, answer the following questions strictly in {language} +, and the answers must follow the Markdown format. Strictly excluding any tag likes "[HISTORY_BEGIN]" +, "[HISTORY_END]", "[KNOWLEDGE_BEGIN]", "[KNOWLEDGE_END]" in responses. 
+ + +{ask} +""" diff --git a/notebook_dir/metagpt_yusin/actions/write_code.py b/notebook_dir/metagpt_yusin/actions/write_code.py new file mode 100644 index 0000000000000000000000000000000000000000..b3f81e3511cdab42df8acf4d0c6fe6b011e6db7e --- /dev/null +++ b/notebook_dir/metagpt_yusin/actions/write_code.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/11 17:45 +@Author : alexanderwu +@File : write_code.py +@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.1.3 of RFC 116, modify the data type of the `cause_by` + value of the `Message` object. +@Modified By: mashenquan, 2023-11-27. + 1. Mark the location of Design, Tasks, Legacy Code and Debug logs in the PROMPT_TEMPLATE with markdown + code-block formatting to enhance the understanding for the LLM. + 2. Following the think-act principle, solidify the task parameters when creating the WriteCode object, rather + than passing them in when calling the run function. + 3. Encapsulate the input of RunCode into RunCodeContext and encapsulate the output of RunCode into + RunCodeResult to standardize and unify parameter passing between WriteCode, RunCode, and DebugError. 
+""" + +import json + +from pydantic import Field +from tenacity import retry, stop_after_attempt, wait_random_exponential + +from metagpt_yusin.actions.action import Action +from metagpt_yusin.actions.project_management_an import REFINED_TASK_LIST, TASK_LIST +from metagpt_yusin.actions.write_code_plan_and_change_an import REFINED_TEMPLATE +from metagpt_yusin.const import BUGFIX_FILENAME, REQUIREMENT_FILENAME +from metagpt_yusin.logs import logger +from metagpt_yusin.schema import CodingContext, Document, RunCodeResult +from metagpt_yusin.utils.common import CodeParser +from metagpt_yusin.utils.project_repo import ProjectRepo + +PROMPT_TEMPLATE = """ +NOTICE +Role: You are a professional engineer; the main goal is to write google-style, elegant, modular, easy to read and maintain code +Language: Please use the same language as the user requirement, but the title and code should be still in English. For example, if the user speaks Chinese, the specific text of your answer should also be in Chinese. +ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenced "Format example". + +# Context +## Design +{design} + +## Task +{task} + +## Legacy Code +```Code +{code} +``` + +## Debug logs +```text +{logs} + +{summary_log} +``` + +## Bug Feedback logs +```text +{feedback} +``` + +# Format example +## Code: {filename} +```python +## {filename} +... +``` + +# Instruction: Based on the context, follow "Format example", write code. + +## Code: {filename}. Write code with triple quoto, based on the following attentions and context. +1. Only One file: do your best to implement THIS ONLY ONE FILE. +2. COMPLETE CODE: Your code will be part of the entire project, so please implement complete, reliable, reusable code snippets. +3. Set default value: If there is any setting, ALWAYS SET A DEFAULT VALUE, ALWAYS USE STRONG TYPE AND EXPLICIT VARIABLE. AVOID circular import. +4. Follow design: YOU MUST FOLLOW "Data structures and interfaces". 
class WriteCode(Action):
    """Action that asks the LLM to write (or rewrite) exactly one source file.

    ``i_context`` is a ``Document`` whose ``content`` is a serialized ``CodingContext`` describing the
    file to produce together with its design/task documents.
    """

    name: str = "WriteCode"
    i_context: Document = Field(default_factory=Document)

    @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
    async def write_code(self, prompt) -> str:
        """Send ``prompt`` to the LLM and return the code block parsed from its reply.

        The call is retried (random exponential back-off, ``min=1``/``max=60``) for at most 6 attempts.
        """
        code_rsp = await self._aask(prompt)
        code = CodeParser.parse_code(block="", text=code_rsp)
        return code

    async def run(self, *args, **kwargs) -> CodingContext:
        """Assemble the prompt for the target file, ask the LLM for the code, and store it.

        Returns:
            CodingContext: The context loaded from ``i_context`` with ``code_doc.content`` set to the
            generated code (``code_doc`` is created when it was missing).
        """
        bug_feedback = await self.repo.docs.get(filename=BUGFIX_FILENAME)
        coding_context = CodingContext.loads(self.i_context.content)
        test_doc = await self.repo.test_outputs.get(filename="test_" + coding_context.filename + ".json")
        requirement_doc = await self.repo.docs.get(filename=REQUIREMENT_FILENAME)
        summary_doc = None
        if coding_context.design_doc and coding_context.design_doc.filename:
            summary_doc = await self.repo.docs.code_summary.get(filename=coding_context.design_doc.filename)
        logs = ""
        if test_doc:
            test_detail = RunCodeResult.loads(test_doc.content)
            logs = test_detail.stderr

        if bug_feedback:
            # Bug-fix round: the legacy code is the current version of the file itself. ``code_doc`` may be
            # missing (it is created further down in that case), so do not dereference it blindly.
            code_context = coding_context.code_doc.content if coding_context.code_doc else ""
        elif self.config.inc:
            code_context = await self.get_codes(
                coding_context.task_doc, exclude=self.i_context.filename, project_repo=self.repo, use_inc=True
            )
        else:
            code_context = await self.get_codes(
                coding_context.task_doc,
                exclude=self.i_context.filename,
                project_repo=self.repo.with_src_path(self.context.src_workspace),
            )

        # Placeholders shared by both templates.
        template_args = {
            "design": coding_context.design_doc.content if coding_context.design_doc else "",
            "task": coding_context.task_doc.content if coding_context.task_doc else "",
            "code": code_context,
            "logs": logs,
            "feedback": bug_feedback.content if bug_feedback else "",
            "filename": self.i_context.filename,
            "summary_log": summary_doc.content if summary_doc else "",
        }
        if self.config.inc:
            prompt = REFINED_TEMPLATE.format(
                user_requirement=requirement_doc.content if requirement_doc else "",
                code_plan_and_change=str(coding_context.code_plan_and_change_doc),
                **template_args,
            )
        else:
            prompt = PROMPT_TEMPLATE.format(**template_args)
        logger.info(f"Writing {coding_context.filename}..")
        code = await self.write_code(prompt)
        if not coding_context.code_doc:
            # avoid root_path pydantic ValidationError if use WriteCode alone
            root_path = self.context.src_workspace if self.context.src_workspace else ""
            coding_context.code_doc = Document(filename=coding_context.filename, root_path=str(root_path))
        coding_context.code_doc.content = code
        return coding_context

    @staticmethod
    async def get_codes(task_doc: Document, exclude: str, project_repo: ProjectRepo, use_inc: bool = False) -> str:
        """Collect the code snippets used as "Legacy Code" context when generating the ``exclude`` file.

        Args:
            task_doc (Document): Document object of the task file. When its content is empty it is
                reloaded from ``project_repo.docs.task`` by filename.
            exclude (str): The file to be generated; it is left out of the ordinary snippets.
            project_repo (ProjectRepo): ProjectRepo object of the project.
            use_inc (bool): Whether the scenario is incremental development. Defaults to False.

        Returns:
            str: The snippets joined by newlines, or ``""`` when no usable task document exists.
        """
        if not task_doc:
            return ""
        if not task_doc.content:
            # The repository getter is a coroutine (it is awaited everywhere else); without ``await`` the
            # name would be bound to a coroutine object and ``.content`` below would fail.
            task_doc = await project_repo.docs.task.get(filename=task_doc.filename)
            if not task_doc or not task_doc.content:
                return ""
        m = json.loads(task_doc.content)
        # The plain task list drives the normal scenario; the refined list belongs to incremental mode.
        code_filenames = m.get(TASK_LIST.key, []) if not use_inc else m.get(REFINED_TASK_LIST.key, [])
        codes = []
        src_file_repo = project_repo.srcs

        # Incremental development scenario
        if use_inc:
            src_files = src_file_repo.all_files
            # The old workspace holds the previous code; it is created in the earlier CodePlanAndChange step.
            old_file_repo = project_repo.git_repo.new_file_repository(relative_path=project_repo.old_workspace)
            old_files = set(old_file_repo.all_files)
            # Union of both workspaces, sorted so the prompt is identical from run to run.
            for filename in sorted(set(src_files) | old_files):
                if filename == exclude:
                    # The file being rewritten is shown with its OLD content, except main.py, which is left
                    # out to keep it clean and focused on the new requirements.
                    if filename not in old_files or filename == "main.py":
                        continue
                    doc = await old_file_repo.get(filename=filename)
                    if not doc:
                        continue
                    codes.insert(0, f"-----Now, {filename} to be rewritten\n```{doc.content}```\n=====")
                else:
                    # Every other file is taken from the src workspace; skip it when it does not exist there.
                    doc = await src_file_repo.get(filename=filename)
                    if not doc:
                        continue
                    codes.append(f"----- {filename}\n```{doc.content}```")

        # Normal scenario
        else:
            for filename in code_filenames:
                # Exclude the current file to get the code snippets for generating the current file
                if filename == exclude:
                    continue
                doc = await src_file_repo.get(filename=filename)
                if not doc:
                    continue
                codes.append(f"----- {filename}\n```{doc.content}```")

        return "\n".join(codes)
+ ```python + def handle_events(self): + for event in pygame.event.get(): + if event.type == pygame.QUIT: + return False + if event.type == pygame.KEYDOWN: + moved = False + if event.key == pygame.K_UP: + moved = self.game.move('UP') + elif event.key == pygame.K_DOWN: + moved = self.game.move('DOWN') + elif event.key == pygame.K_LEFT: + moved = self.game.move('LEFT') + elif event.key == pygame.K_RIGHT: + moved = self.game.move('RIGHT') + if moved: + # Update the game state only if a move was successful + self.render() + return True + ``` +""", +) + +WRITE_DRAFT = ActionNode( + key="WriteDraft", + expected_type=str, + instruction="Could you write draft code for move function in order to implement it?", + example="Draft: ...", +) + + +WRITE_FUNCTION = ActionNode( + key="WriteFunction", + expected_type=str, + instruction="write code for the function not implemented.", + example=""" +```Code +... +``` +""", +) + + +REWRITE_CODE = ActionNode( + key="RewriteCode", + expected_type=str, + instruction="""rewrite code based on the Review and Actions""", + example=""" +```python +## example.py +def calculate_total(price, quantity): + total = price * quantity +``` +""", +) + + +CODE_REVIEW_CONTEXT = """ +# System +Role: You are a professional software engineer, and your main task is to review and revise the code. You need to ensure that the code conforms to the google-style standards, is elegantly designed and modularized, easy to read and maintain. +Language: Please use the same language as the user requirement, but the title and code should be still in English. For example, if the user speaks Chinese, the specific text of your answer should also be in Chinese. 
+ +# Context +## System Design +{"Implementation approach": "我们将使用HTML、CSS和JavaScript来实现这个单机的响应式2048游戏。为了确保游戏性能流畅和响应式设计,我们会选择使用Vue.js框架,因为它易于上手且适合构建交互式界面。我们还将使用localStorage来记录玩家的最高分。", "File list": ["index.html", "styles.css", "main.js", "game.js", "storage.js"], "Data structures and interfaces": "classDiagram\ + class Game {\ + -board Array\ + -score Number\ + -bestScore Number\ + +constructor()\ + +startGame()\ + +move(direction: String)\ + +getBoard() Array\ + +getScore() Number\ + +getBestScore() Number\ + +setBestScore(score: Number)\ + }\ + class Storage {\ + +getBestScore() Number\ + +setBestScore(score: Number)\ + }\ + class Main {\ + +init()\ + +bindEvents()\ + }\ + Game --> Storage : uses\ + Main --> Game : uses", "Program call flow": "sequenceDiagram\ + participant M as Main\ + participant G as Game\ + participant S as Storage\ + M->>G: init()\ + G->>S: getBestScore()\ + S-->>G: return bestScore\ + M->>G: bindEvents()\ + M->>G: startGame()\ + loop Game Loop\ + M->>G: move(direction)\ + G->>S: setBestScore(score)\ + S-->>G: return\ + end", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"} + +## Tasks +{"Required Python packages": ["无需Python包"], "Required Other language third-party packages": ["vue.js"], "Logic Analysis": [["index.html", "作为游戏的入口文件和主要的HTML结构"], ["styles.css", "包含所有的CSS样式,确保游戏界面美观"], ["main.js", "包含Main类,负责初始化游戏和绑定事件"], ["game.js", "包含Game类,负责游戏逻辑,如开始游戏、移动方块等"], ["storage.js", "包含Storage类,用于获取和设置玩家的最高分"]], "Task list": ["index.html", "styles.css", "storage.js", "game.js", "main.js"], "Full API spec": "", "Shared Knowledge": "\'game.js\' 包含游戏逻辑相关的函数,被 \'main.js\' 调用。", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"} + +## Code Files +----- index.html + + + + + + 2048游戏 + + + + +
+

2048

+
+
+
分数
+
{{ score }}
+
+
+
最高分
+
{{ bestScore }}
+
+
+
+
+
+ {{ cell !== 0 ? cell : \'\' }} +
+
+
+ +
+ + + + + + + + +----- styles.css +/* styles.css */ +body, html { + margin: 0; + padding: 0; + font-family: \'Arial\', sans-serif; +} + +#app { + text-align: center; + font-size: 18px; + color: #776e65; +} + +h1 { + color: #776e65; + font-size: 72px; + font-weight: bold; + margin: 20px 0; +} + +.scores-container { + display: flex; + justify-content: center; + margin-bottom: 20px; +} + +.score-container, .best-container { + background: #bbada0; + padding: 10px; + border-radius: 5px; + margin: 0 10px; + min-width: 100px; + text-align: center; +} + +.score-header, .best-header { + color: #eee4da; + font-size: 18px; + margin-bottom: 5px; +} + +.game-container { + max-width: 500px; + margin: 0 auto 20px; + background: #bbada0; + padding: 15px; + border-radius: 10px; + position: relative; +} + +.grid-row { + display: flex; +} + +.grid-cell { + background: #cdc1b4; + width: 100px; + height: 100px; + margin: 5px; + display: flex; + justify-content: center; + align-items: center; + font-size: 35px; + font-weight: bold; + color: #776e65; + border-radius: 3px; +} + +/* Dynamic classes for different number cells */ +.number-cell-2 { + background: #eee4da; +} + +.number-cell-4 { + background: #ede0c8; +} + +.number-cell-8 { + background: #f2b179; + color: #f9f6f2; +} + +.number-cell-16 { + background: #f59563; + color: #f9f6f2; +} + +.number-cell-32 { + background: #f67c5f; + color: #f9f6f2; +} + +.number-cell-64 { + background: #f65e3b; + color: #f9f6f2; +} + +.number-cell-128 { + background: #edcf72; + color: #f9f6f2; +} + +.number-cell-256 { + background: #edcc61; + color: #f9f6f2; +} + +.number-cell-512 { + background: #edc850; + color: #f9f6f2; +} + +.number-cell-1024 { + background: #edc53f; + color: #f9f6f2; +} + +.number-cell-2048 { + background: #edc22e; + color: #f9f6f2; +} + +/* Larger numbers need smaller font sizes */ +.number-cell-1024, .number-cell-2048 { + font-size: 30px; +} + +button { + background-color: #8f7a66; + color: #f9f6f2; + border: none; + 
border-radius: 3px; + padding: 10px 20px; + font-size: 18px; + cursor: pointer; + outline: none; +} + +button:hover { + background-color: #9f8b76; +} + +----- storage.js +## storage.js +class Storage { + // 获取最高分 + getBestScore() { + // 尝试从localStorage中获取最高分,如果不存在则默认为0 + const bestScore = localStorage.getItem(\'bestScore\'); + return bestScore ? Number(bestScore) : 0; + } + + // 设置最高分 + setBestScore(score) { + // 将最高分设置到localStorage中 + localStorage.setItem(\'bestScore\', score.toString()); + } +} + + + +## Code to be Reviewed: game.js +```Code +## game.js +class Game { + constructor() { + this.board = this.createEmptyBoard(); + this.score = 0; + this.bestScore = 0; + } + + createEmptyBoard() { + const board = []; + for (let i = 0; i < 4; i++) { + board[i] = [0, 0, 0, 0]; + } + return board; + } + + startGame() { + this.board = this.createEmptyBoard(); + this.score = 0; + this.addRandomTile(); + this.addRandomTile(); + } + + addRandomTile() { + let emptyCells = []; + for (let r = 0; r < 4; r++) { + for (let c = 0; c < 4; c++) { + if (this.board[r][c] === 0) { + emptyCells.push({ r, c }); + } + } + } + if (emptyCells.length > 0) { + let randomCell = emptyCells[Math.floor(Math.random() * emptyCells.length)]; + this.board[randomCell.r][randomCell.c] = Math.random() < 0.9 ? 
2 : 4; + } + } + + move(direction) { + // This function will handle the logic for moving tiles + // in the specified direction and merging them + // It will also update the score and add a new random tile if the move is successful + // The actual implementation of this function is complex and would require + // a significant amount of code to handle all the cases for moving and merging tiles + // For the purposes of this example, we will not implement the full logic + // Instead, we will just call addRandomTile to simulate a move + this.addRandomTile(); + } + + getBoard() { + return this.board; + } + + getScore() { + return this.score; + } + + getBestScore() { + return this.bestScore; + } + + setBestScore(score) { + this.bestScore = score; + } +} + +``` +""" + + +CODE_REVIEW_SMALLEST_CONTEXT = """ +## Code to be Reviewed: game.js +```Code +// game.js +class Game { + constructor() { + this.board = this.createEmptyBoard(); + this.score = 0; + this.bestScore = 0; + } + + createEmptyBoard() { + const board = []; + for (let i = 0; i < 4; i++) { + board[i] = [0, 0, 0, 0]; + } + return board; + } + + startGame() { + this.board = this.createEmptyBoard(); + this.score = 0; + this.addRandomTile(); + this.addRandomTile(); + } + + addRandomTile() { + let emptyCells = []; + for (let r = 0; r < 4; r++) { + for (let c = 0; c < 4; c++) { + if (this.board[r][c] === 0) { + emptyCells.push({ r, c }); + } + } + } + if (emptyCells.length > 0) { + let randomCell = emptyCells[Math.floor(Math.random() * emptyCells.length)]; + this.board[randomCell.r][randomCell.c] = Math.random() < 0.9 ? 
2 : 4; + } + } + + move(direction) { + // This function will handle the logic for moving tiles + // in the specified direction and merging them + // It will also update the score and add a new random tile if the move is successful + // The actual implementation of this function is complex and would require + // a significant amount of code to handle all the cases for moving and merging tiles + // For the purposes of this example, we will not implement the full logic + // Instead, we will just call addRandomTile to simulate a move + this.addRandomTile(); + } + + getBoard() { + return this.board; + } + + getScore() { + return this.score; + } + + getBestScore() { + return this.bestScore; + } + + setBestScore(score) { + this.bestScore = score; + } +} + +``` +""" + + +CODE_REVIEW_SAMPLE = """ +## Code Review: game.js +1. The code partially implements the requirements. The `Game` class is missing the full implementation of the `move` method, which is crucial for the game\'s functionality. +2. The code logic is not completely correct. The `move` method is not implemented, which means the game cannot process player moves. +3. The existing code follows the "Data structures and interfaces" in terms of class structure but lacks full method implementations. +4. Not all functions are implemented. The `move` method is incomplete and does not handle the logic for moving and merging tiles. +5. All necessary pre-dependencies seem to be imported since the code does not indicate the need for additional imports. +6. The methods from other files (such as `Storage`) are not being used in the provided code snippet, but the class structure suggests that they will be used correctly. + +## Actions +1. Implement the `move` method to handle tile movements and merging. This is a complex task that requires careful consideration of the game\'s rules and logic. 
Here is a simplified version of how one might begin to implement the `move` method: + ```javascript + move(direction) { + // Simplified logic for moving tiles up + if (direction === \'up\') { + for (let col = 0; col < 4; col++) { + let tiles = this.board.map(row => row[col]).filter(val => val !== 0); + let merged = []; + for (let i = 0; i < tiles.length; i++) { + if (tiles[i] === tiles[i + 1]) { + tiles[i] *= 2; + this.score += tiles[i]; + tiles[i + 1] = 0; + merged.push(i); + } + } + tiles = tiles.filter(val => val !== 0); + while (tiles.length < 4) { + tiles.push(0); + } + for (let row = 0; row < 4; row++) { + this.board[row][col] = tiles[row]; + } + } + } + // Additional logic needed for \'down\', \'left\', \'right\' + // ... + this.addRandomTile(); + } + ``` +2. Integrate the `Storage` class methods to handle the best score. This means updating the `startGame` and `setBestScore` methods to use `Storage` for retrieving and setting the best score: + ```javascript + startGame() { + this.board = this.createEmptyBoard(); + this.score = 0; + this.bestScore = new Storage().getBestScore(); // Retrieve the best score from storage + this.addRandomTile(); + this.addRandomTile(); + } + + setBestScore(score) { + if (score > this.bestScore) { + this.bestScore = score; + new Storage().setBestScore(score); // Set the new best score in storage + } + } + ``` + +## Code Review Result +LBTM + +``` +""" + + +WRITE_CODE_NODE = ActionNode.from_children("WRITE_REVIEW_NODE", [REVIEW, REVIEW_RESULT, NEXT_STEPS]) +WRITE_MOVE_NODE = ActionNode.from_children("WRITE_MOVE_NODE", [WRITE_DRAFT, WRITE_FUNCTION]) + + +CR_FOR_MOVE_FUNCTION_BY_3 = """ +The move function implementation provided appears to be well-structured and follows a clear logic for moving and merging tiles in the specified direction. However, there are a few potential improvements that could be made to enhance the code: + +1. 
class WriteCodeAN(Action):
    """Action that has the LLM fill ``WRITE_MOVE_NODE`` (draft + function) for a given review context."""

    async def run(self, context):
        """Set the engineer system prompt and return the result of filling the node in JSON schema."""
        self.llm.system_prompt = "You are an outstanding engineer and can implement any code"
        filled_node = await WRITE_MOVE_NODE.fill(context=context, llm=self.llm, schema="json")
        return filled_node


async def main():
    """Run ``WriteCodeAN`` once against the minimal sample context."""
    action = WriteCodeAN()
    await action.run(CODE_REVIEW_SMALLEST_CONTEXT)


if __name__ == "__main__":
    asyncio.run(main())
ActionNode +from metagpt_yusin.logs import logger +from metagpt_yusin.schema import CodePlanAndChangeContext + +DEVELOPMENT_PLAN = ActionNode( + key="Development Plan", + expected_type=List[str], + instruction="Develop a comprehensive and step-by-step incremental development plan, providing the detail " + "changes to be implemented at each step based on the order of 'Task List'", + example=[ + "Enhance the functionality of `calculator.py` by extending it to incorporate methods for subtraction, ...", + "Update the existing codebase in main.py to incorporate new API endpoints for subtraction, ...", + ], +) + +INCREMENTAL_CHANGE = ActionNode( + key="Incremental Change", + expected_type=List[str], + instruction="Write Incremental Change by making a code draft that how to implement incremental development " + "including detailed steps based on the context. Note: Track incremental changes using the marks `+` and `-` to " + "indicate additions and deletions, and ensure compliance with the output format of `git diff`", + example=[ + '''```diff +--- Old/calculator.py ++++ New/calculator.py + +class Calculator: + self.result = number1 + number2 + return self.result + +- def sub(self, number1, number2) -> float: ++ def subtract(self, number1: float, number2: float) -> float: ++ """ ++ Subtracts the second number from the first and returns the result. ++ ++ Args: ++ number1 (float): The number to be subtracted from. ++ number2 (float): The number to subtract. ++ ++ Returns: ++ float: The difference of number1 and number2. ++ """ ++ self.result = number1 - number2 ++ return self.result ++ + def multiply(self, number1: float, number2: float) -> float: +- pass ++ """ ++ Multiplies two numbers and returns the result. ++ ++ Args: ++ number1 (float): The first number to multiply. ++ number2 (float): The second number to multiply. ++ ++ Returns: ++ float: The product of number1 and number2. 
++ """ ++ self.result = number1 * number2 ++ return self.result ++ + def divide(self, number1: float, number2: float) -> float: +- pass ++ """ ++ ValueError: If the second number is zero. ++ """ ++ if number2 == 0: ++ raise ValueError('Cannot divide by zero') ++ self.result = number1 / number2 ++ return self.result ++ +- def reset_result(self): ++ def clear(self): ++ if self.result != 0.0: ++ print("Result is not zero, clearing...") ++ else: ++ print("Result is already zero, no need to clear.") ++ + self.result = 0.0 +```''', + """```diff +--- Old/main.py ++++ New/main.py + +def add_numbers(): + result = calculator.add_numbers(num1, num2) + return jsonify({'result': result}), 200 + +-# TODO: Implement subtraction, multiplication, and division operations ++@app.route('/subtract_numbers', methods=['POST']) ++def subtract_numbers(): ++ data = request.get_json() ++ num1 = data.get('num1', 0) ++ num2 = data.get('num2', 0) ++ result = calculator.subtract_numbers(num1, num2) ++ return jsonify({'result': result}), 200 ++ ++@app.route('/multiply_numbers', methods=['POST']) ++def multiply_numbers(): ++ data = request.get_json() ++ num1 = data.get('num1', 0) ++ num2 = data.get('num2', 0) ++ try: ++ result = calculator.divide_numbers(num1, num2) ++ except ValueError as e: ++ return jsonify({'error': str(e)}), 400 ++ return jsonify({'result': result}), 200 ++ + if __name__ == '__main__': + app.run() +```""", + ], +) + +CODE_PLAN_AND_CHANGE_CONTEXT = """ +## User New Requirements +{requirement} + +## Issue +{issue} + +## PRD +{prd} + +## Design +{design} + +## Task +{task} + +## Legacy Code +{code} +""" + +REFINED_TEMPLATE = """ +NOTICE +Role: You are a professional engineer; The main goal is to complete incremental development by combining legacy code and plan and Incremental Change, ensuring the integration of new features. 
+ +# Context +## User New Requirements +{user_requirement} + +## Code Plan And Change +{code_plan_and_change} + +## Design +{design} + +## Task +{task} + +## Legacy Code +```Code +{code} +``` + +## Debug logs +```text +{logs} + +{summary_log} +``` + +## Bug Feedback logs +```text +{feedback} +``` + +# Format example +## Code: {filename} +```python +## {filename} +... +``` + +# Instruction: Based on the context, follow "Format example", write or rewrite code. +## Write/Rewrite Code: Only write one file {filename}, write or rewrite complete code using triple quotes based on the following attentions and context. +1. Only One file: do your best to implement THIS ONLY ONE FILE. +2. COMPLETE CODE: Your code will be part of the entire project, so please implement complete, reliable, reusable code snippets. +3. Set default value: If there is any setting, ALWAYS SET A DEFAULT VALUE, ALWAYS USE STRONG TYPE AND EXPLICIT VARIABLE. AVOID circular import. +4. Follow design: YOU MUST FOLLOW "Data structures and interfaces". DONT CHANGE ANY DESIGN. Do not use public member functions that do not exist in your design. +5. Follow Code Plan And Change: If there is any "Incremental Change" that is marked by the git diff format with '+' and '-' symbols, or Legacy Code files contain "{filename} to be rewritten", you must merge it into the code file according to the "Development Plan". +6. CAREFULLY CHECK THAT YOU DONT MISS ANY NECESSARY CLASS/FUNCTION IN THIS FILE. +7. Before using a external variable/module, make sure you import it first. +8. Write out EVERY CODE DETAIL, DON'T LEAVE TODO. +9. Attention: Retain details that are not related to incremental development but are important for maintaining the consistency and clarity of the old code. 
CODE_PLAN_AND_CHANGE = [DEVELOPMENT_PLAN, INCREMENTAL_CHANGE]

WRITE_CODE_PLAN_AND_CHANGE_NODE = ActionNode.from_children("WriteCodePlanAndChange", CODE_PLAN_AND_CHANGE)


class WriteCodePlanAndChange(Action):
    """Action that asks the LLM for an incremental development plan and the matching code changes."""

    name: str = "WriteCodePlanAndChange"
    i_context: CodePlanAndChangeContext = Field(default_factory=CodePlanAndChangeContext)

    async def run(self, *args, **kwargs):
        """Build the planning context from the PRD, design, task and legacy code, then fill the node.

        Returns:
            The result of ``WRITE_CODE_PLAN_AND_CHANGE_NODE.fill`` with ``schema="json"``.
        """
        # The parentheses are required: without them the second literal is a separate, discarded
        # expression statement and the system prompt is cut off after "is to ".
        self.llm.system_prompt = (
            "You are a professional software engineer, your primary responsibility is to "
            "meticulously craft comprehensive incremental development plan and deliver detailed incremental change"
        )
        prd_doc = await self.repo.docs.prd.get(filename=self.i_context.prd_filename)
        design_doc = await self.repo.docs.system_design.get(filename=self.i_context.design_filename)
        task_doc = await self.repo.docs.task.get(filename=self.i_context.task_filename)
        context = CODE_PLAN_AND_CHANGE_CONTEXT.format(
            requirement=f"```text\n{self.i_context.requirement}\n```",
            issue=f"```text\n{self.i_context.issue}\n```",
            # A missing document contributes an empty section instead of raising AttributeError.
            prd=prd_doc.content if prd_doc else "",
            design=design_doc.content if design_doc else "",
            task=task_doc.content if task_doc else "",
            code=await self.get_old_codes(),
        )
        logger.info("Writing code plan and change..")
        return await WRITE_CODE_PLAN_AND_CHANGE_NODE.fill(context=context, llm=self.llm, schema="json")

    async def get_old_codes(self) -> str:
        """Return every file of the old workspace rendered as ``----- <filename>`` plus a fenced body.

        Side effect: sets ``self.repo.old_workspace`` to ``<git workdir>/<basename of config.project_path>``.
        """
        self.repo.old_workspace = self.repo.git_repo.workdir / os.path.basename(self.config.project_path)
        old_file_repo = self.repo.git_repo.new_file_repository(relative_path=self.repo.old_workspace)
        old_codes = await old_file_repo.get_all()
        codes = [f"----- {code.filename}\n```{code.content}```" for code in old_codes]
        return "\n".join(codes)
@@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/11 17:45 +@Author : alexanderwu +@File : write_code_review.py +@Modified By: mashenquan, 2023/11/27. Following the think-act principle, solidify the task parameters when creating the + WriteCode object, rather than passing them in when calling the run function. +""" + +from pydantic import Field +from tenacity import retry, stop_after_attempt, wait_random_exponential + +from metagpt_yusin.actions import WriteCode +from metagpt_yusin.actions.action import Action +from metagpt_yusin.const import REQUIREMENT_FILENAME +from metagpt_yusin.logs import logger +from metagpt_yusin.schema import CodingContext +from metagpt_yusin.utils.common import CodeParser + +PROMPT_TEMPLATE = """ +# System +Role: You are a professional software engineer, and your main task is to review and revise the code. You need to ensure that the code conforms to the google-style standards, is elegantly designed and modularized, easy to read and maintain. +Language: Please use the same language as the user requirement, but the title and code should be still in English. For example, if the user speaks Chinese, the specific text of your answer should also be in Chinese. +ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenced "Format example". + +# Context +{context} + +## Code to be Reviewed: {filename} +```Code +{code} +``` +""" + +EXAMPLE_AND_INSTRUCTION = """ + +{format_example} + + +# Instruction: Based on the actual code situation, follow one of the "Format example". Return only 1 file under review. + +## Code Review: Ordered List. Based on the "Code to be Reviewed", provide key, clear, concise, and specific answer. If any answer is no, explain how to fix it step by step. +1. Is the code implemented as per the requirements? If not, how to achieve it? Analyse it step by step. +2. Is the code logic completely correct? If there are errors, please indicate how to correct them. +3. 
Does the existing code follow the "Data structures and interfaces"? +4. Are all functions implemented? If there is no implementation, please indicate how to achieve it step by step. +5. Have all necessary pre-dependencies been imported? If not, indicate which ones need to be imported +6. Are methods from other files being reused correctly? + +## Actions: Ordered List. Things that should be done after CR, such as implementing class A and function B + +## Code Review Result: str. If the code doesn't have bugs, we don't need to rewrite it, so answer LGTM and stop. ONLY ANSWER LGTM/LBTM. +LGTM/LBTM + +""" + +FORMAT_EXAMPLE = """ +# Format example 1 +## Code Review: {filename} +1. No, we should fix the logic of class A due to ... +2. ... +3. ... +4. No, function B is not implemented, ... +5. ... +6. ... + +## Actions +1. Fix the `handle_events` method to update the game state only if a move is successful. + ```python + def handle_events(self): + for event in pygame.event.get(): + if event.type == pygame.QUIT: + return False + if event.type == pygame.KEYDOWN: + moved = False + if event.key == pygame.K_UP: + moved = self.game.move('UP') + elif event.key == pygame.K_DOWN: + moved = self.game.move('DOWN') + elif event.key == pygame.K_LEFT: + moved = self.game.move('LEFT') + elif event.key == pygame.K_RIGHT: + moved = self.game.move('RIGHT') + if moved: + # Update the game state only if a move was successful + self.render() + return True + ``` +2. Implement function B + +## Code Review Result +LBTM + +# Format example 2 +## Code Review: {filename} +1. Yes. +2. Yes. +3. Yes. +4. Yes. +5. Yes. +6. Yes. + +## Actions +pass + +## Code Review Result +LGTM +""" + +REWRITE_CODE_TEMPLATE = """ +# Instruction: rewrite code based on the Code Review and Actions +## Rewrite Code: CodeBlock. If it still has some bugs, rewrite {filename} with triple quotes. Do your utmost to optimize THIS SINGLE FILE. Return all completed codes and prohibit the return of unfinished codes. 
class WriteCodeReview(Action):
    """Review one source file with the LLM and rewrite it while the review fails.

    Attributes:
        name: Action name.
        i_context: Coding context holding the code document under review and the
            design, task and code-plan documents used to build the prompt.
    """

    name: str = "WriteCodeReview"
    i_context: CodingContext = Field(default_factory=CodingContext)

    @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
    async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, filename):
        """Run one review request and, unless it passes, one rewrite request.

        Args:
            context_prompt: Prompt part carrying the context and the code under review.
            cr_prompt: Prompt part carrying the format example and review instructions.
            filename: Name of the reviewed file, inserted into the rewrite prompt.

        Returns:
            ``(result, code)``: ``result`` is the parsed "Code Review Result" block;
            ``code`` is the rewritten code, or ``None`` when ``result`` contains "LGTM".
        """
        review_rsp = await self._aask(context_prompt + cr_prompt)
        verdict = CodeParser.parse_block("Code Review Result", review_rsp)
        if "LGTM" not in verdict:
            # Anything other than an LGTM verdict triggers a rewrite request.
            rewrite_instruction = REWRITE_CODE_TEMPLATE.format(filename=filename)
            rewrite_rsp = await self._aask(f"{context_prompt}\n{review_rsp}\n{rewrite_instruction}")
            return verdict, CodeParser.parse_code(block="", text=rewrite_rsp)
        return verdict, None

    async def run(self, *args, **kwargs) -> CodingContext:
        """Review the code up to ``code_review_k_times`` times (at least once).

        Each "LBTM" verdict replaces the working copy with the rewritten code; an
        "LGTM" verdict ends the loop early. The working copy is then stored in
        ``i_context.code_doc.content``.

        Returns:
            ``self.i_context`` with the (possibly rewritten) code.
        """
        current_code = self.i_context.code_doc.content
        rounds = self.context.config.code_review_k_times or 1

        for round_no in range(1, rounds + 1):
            target_name = self.i_context.code_doc.filename
            example_text = FORMAT_EXAMPLE.format(filename=target_name)
            task_text = self.i_context.task_doc.content if self.i_context.task_doc else ""
            sibling_code = await WriteCode.get_codes(
                self.i_context.task_doc,
                exclude=self.i_context.filename,
                project_repo=self.repo.with_src_path(self.context.src_workspace),
                use_inc=self.config.inc,
            )

            sections = [
                "## System Design\n" + str(self.i_context.design_doc) + "\n",
                "## Task\n" + task_text + "\n",
                "## Code Files\n" + sibling_code + "\n",
            ]
            if self.config.inc:
                # Incremental mode: the new requirement and the code plan go first.
                requirement_doc = await self.repo.docs.get(filename=REQUIREMENT_FILENAME)
                sections[:0] = [
                    "## User New Requirements\n" + str(requirement_doc) + "\n",
                    "## Code Plan And Change\n" + str(self.i_context.code_plan_and_change_doc) + "\n",
                ]

            context_prompt = PROMPT_TEMPLATE.format(
                context="\n".join(sections),
                code=current_code,
                filename=target_name,
            )
            cr_prompt = EXAMPLE_AND_INSTRUCTION.format(format_example=example_text)

            stored_code = self.i_context.code_doc.content
            working_len = len(current_code) if current_code else 0
            stored_len = len(stored_code) if stored_code else 0
            logger.info(
                f"Code review and rewrite {target_name}: {round_no}/{rounds} | len(iterative_code)={working_len}, "
                f"len(self.i_context.code_doc.content)={stored_len}"
            )

            verdict, rewritten = await self.write_code_review_and_rewrite(context_prompt, cr_prompt, target_name)
            if "LBTM" in verdict:
                current_code = rewritten
            elif "LGTM" in verdict:
                break

        # Reached on LGTM and when the rounds are exhausted: persist the current
        # working copy (unchanged if no rewrite ever happened).
        self.i_context.code_doc.content = current_code
        return self.i_context
+ Default: 'google' + +Example: + python3 -m metagpt_yusin.actions.write_docstring ./metagpt_yusin/software_company.py --overwrite False --style=numpy + +This script uses the 'fire' library to create a command-line interface. It generates docstrings for the given Python code using +the specified docstring style and adds them to the code. +""" +from __future__ import annotations + +import ast +from pathlib import Path +from typing import Literal, Optional + +from metagpt_yusin.actions.action import Action +from metagpt_yusin.utils.common import OutputParser, aread, awrite +from metagpt_yusin.utils.pycst import merge_docstring + +PYTHON_DOCSTRING_SYSTEM = """### Requirements +1. Add docstrings to the given code following the {style} style. +2. Replace the function body with an Ellipsis object(...) to reduce output. +3. If the types are already annotated, there is no need to include them in the docstring. +4. Extract only class, function or the docstrings for the module parts from the given Python code, avoiding any other text. + +### Input Example +```python +def function_with_pep484_type_annotations(param1: int) -> bool: + return isinstance(param1, int) + +class ExampleError(Exception): + def __init__(self, msg: str): + self.msg = msg +``` + +### Output Example +```python +{example} +``` +""" + +# https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html + +PYTHON_DOCSTRING_EXAMPLE_GOOGLE = ''' +def function_with_pep484_type_annotations(param1: int) -> bool: + """Example function with PEP 484 type annotations. + + Extended description of function. + + Args: + param1: The first parameter. + + Returns: + The return value. True for success, False otherwise. + """ + ... + +class ExampleError(Exception): + """Exceptions are documented in the same way as classes. + + The __init__ method was documented in the class level docstring. + + Args: + msg: Human readable string describing the exception. 
+ + Attributes: + msg: Human readable string describing the exception. + """ + ... +''' + +PYTHON_DOCSTRING_EXAMPLE_NUMPY = ''' +def function_with_pep484_type_annotations(param1: int) -> bool: + """ + Example function with PEP 484 type annotations. + + Extended description of function. + + Parameters + ---------- + param1 + The first parameter. + + Returns + ------- + bool + The return value. True for success, False otherwise. + """ + ... + +class ExampleError(Exception): + """ + Exceptions are documented in the same way as classes. + + The __init__ method was documented in the class level docstring. + + Parameters + ---------- + msg + Human readable string describing the exception. + + Attributes + ---------- + msg + Human readable string describing the exception. + """ + ... +''' + +PYTHON_DOCSTRING_EXAMPLE_SPHINX = ''' +def function_with_pep484_type_annotations(param1: int) -> bool: + """Example function with PEP 484 type annotations. + + Extended description of function. + + :param param1: The first parameter. + :type param1: int + + :return: The return value. True for success, False otherwise. + :rtype: bool + """ + ... + +class ExampleError(Exception): + """Exceptions are documented in the same way as classes. + + The __init__ method was documented in the class level docstring. + + :param msg: Human-readable string describing the exception. + :type msg: str + """ + ... +''' + +_python_docstring_style = { + "google": PYTHON_DOCSTRING_EXAMPLE_GOOGLE.strip(), + "numpy": PYTHON_DOCSTRING_EXAMPLE_NUMPY.strip(), + "sphinx": PYTHON_DOCSTRING_EXAMPLE_SPHINX.strip(), +} + + +class WriteDocstring(Action): + """This class is used to write docstrings for code. + + Attributes: + desc: A string describing the action. + """ + + desc: str = "Write docstring for code." 
+ i_context: Optional[str] = None + + async def run( + self, + code: str, + system_text: str = PYTHON_DOCSTRING_SYSTEM, + style: Literal["google", "numpy", "sphinx"] = "google", + ) -> str: + """Writes docstrings for the given code and system text in the specified style. + + Args: + code: A string of Python code. + system_text: A string of system text. + style: A string specifying the style of the docstring. Can be 'google', 'numpy', or 'sphinx'. + + Returns: + The Python code with docstrings added. + """ + system_text = system_text.format(style=style, example=_python_docstring_style[style]) + simplified_code = _simplify_python_code(code) + documented_code = await self._aask(f"```python\n{simplified_code}\n```", [system_text]) + documented_code = OutputParser.parse_python_code(documented_code) + return merge_docstring(code, documented_code) + + @staticmethod + async def write_docstring( + filename: str | Path, overwrite: bool = False, style: Literal["google", "numpy", "sphinx"] = "google" + ) -> str: + data = await aread(str(filename)) + code = await WriteDocstring().run(data, style=style) + if overwrite: + await awrite(filename, code) + return code + + +def _simplify_python_code(code: str) -> None: + """Simplifies the given Python code by removing expressions and the last if statement. + + Args: + code: A string of Python code. + + Returns: + The simplified Python code. 
+ """ + code_tree = ast.parse(code) + code_tree.body = [i for i in code_tree.body if not isinstance(i, ast.Expr)] + if isinstance(code_tree.body[-1], ast.If): + code_tree.body.pop() + return ast.unparse(code_tree) + + +if __name__ == "__main__": + import fire + + fire.Fire(WriteDocstring.write_docstring) diff --git a/notebook_dir/metagpt_yusin/actions/write_prd.py b/notebook_dir/metagpt_yusin/actions/write_prd.py new file mode 100644 index 0000000000000000000000000000000000000000..9305cba3e6f04a0f4a5691e74b2465bc6243c847 --- /dev/null +++ b/notebook_dir/metagpt_yusin/actions/write_prd.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/11 17:45 +@Author : alexanderwu +@File : write_prd.py +@Modified By: mashenquan, 2023/11/27. + 1. According to Section 2.2.3.1 of RFC 135, replace file data in the message with the file name. + 2. According to the design in Section 2.2.3.5.2 of RFC 135, add incremental iteration functionality. + 3. Move the document storage operations related to WritePRD from the save operation of WriteDesign. +@Modified By: mashenquan, 2023/12/5. Move the generation logic of the project name to WritePRD. 
class WritePRD(Action):
    """Create or refresh the PRD for the current requirement.

    Three situations are handled:
    1. Bugfix: the requirement is stored as a bugfix document and handed on.
    2. New requirement: a new PRD document is generated.
    3. Requirement update: every related PRD document is updated.
    """

    async def run(self, with_messages, *args, **kwargs) -> ActionOutput | Message:
        """Classify the requirement and dispatch to the matching handler.

        Raises:
            FileNotFoundError: If no requirement document is available.
        """
        requirement: Document = await self.repo.requirement
        existing_prds: list[Document] = await self.repo.docs.prd.get_all()
        if not requirement:
            raise FileNotFoundError("No requirement document found.")

        if await self._is_bugfix(requirement.content):
            logger.info(f"Bugfix detected: {requirement.content}")
            return await self._handle_bugfix(requirement)

        # A bugfix file left over from the previous round would conflict with this one.
        await self.repo.docs.delete(filename=BUGFIX_FILENAME)

        # Related PRDs are updated; with none, a brand-new PRD is written.
        related = await self.get_related_docs(requirement, existing_prds)
        if not related:
            logger.info(f"New requirement detected: {requirement.content}")
            return await self._handle_new_requirement(requirement)
        logger.info(f"Requirement update detected: {requirement.content}")
        return await self._handle_requirement_update(requirement, related)

    async def _handle_bugfix(self, req: Document) -> Message:
        """Save the requirement as the bugfix file, blank the requirement file, notify the engineer."""
        await self.repo.docs.save(filename=BUGFIX_FILENAME, content=req.content)
        await self.repo.docs.save(filename=REQUIREMENT_FILENAME, content="")
        bug_fix = BugFixContext(filename=BUGFIX_FILENAME)
        return Message(
            content=bug_fix.model_dump_json(),
            instruct_content=bug_fix,
            role="",
            cause_by=FixBug,
            sent_from=self,
            send_to="Alex",  # the name of Engineer
        )

    async def _handle_new_requirement(self, req: Document) -> ActionOutput:
        """Generate a new PRD, save it with its chart and PDF, and return it as an action output."""
        known_name = self.project_name
        prompt_context = CONTEXT_TEMPLATE.format(requirements=req, project_name=known_name)
        # With a project name already known, the LLM is not asked for one.
        skipped_keys = [PROJECT_NAME.key] if known_name else []
        node = await WRITE_PRD_NODE.fill(context=prompt_context, llm=self.llm, exclude=skipped_keys)
        await self._rename_workspace(node)
        prd_doc = await self.repo.docs.prd.save(
            filename=FileRepository.new_filename() + ".json",
            content=node.instruct_content.model_dump_json(),
        )
        await self._save_competitive_analysis(prd_doc)
        await self.repo.resources.prd.save_pdf(doc=prd_doc)
        return Documents.from_iterable(documents=[prd_doc]).to_action_output()

    async def _handle_requirement_update(self, req: Document, related_docs: list[Document]) -> ActionOutput:
        """Update each related PRD in turn and return them as an action output."""
        for prd_doc in related_docs:
            await self._update_prd(req, prd_doc)
        return Documents.from_iterable(documents=related_docs).to_action_output()

    async def _is_bugfix(self, context: str) -> bool:
        """Return True when code files exist and the LLM classifies the requirement as "BUG"."""
        if self.repo.code_files_exists():
            node = await WP_ISSUE_TYPE_NODE.fill(context, self.llm)
            return node.get("issue_type") == "BUG"
        return False

    async def get_related_docs(self, req: Document, docs: list[Document]) -> list[Document]:
        """Return the documents in ``docs`` that the LLM judges related to ``req``.

        The documents are checked one after another.
        """
        # refine: use gather to speed up
        related: list[Document] = []
        for candidate in docs:
            if await self._is_related(req, candidate):
                related.append(candidate)
        return related

    async def _is_related(self, req: Document, old_prd: Document) -> bool:
        """Return True when the LLM answers "YES" for this requirement / PRD pair."""
        prompt_context = NEW_REQ_TEMPLATE.format(old_prd=old_prd.content, requirements=req.content)
        node = await WP_IS_RELATIVE_NODE.fill(prompt_context, self.llm)
        return node.get("is_relative") == "YES"

    async def _merge(self, req: Document, related_doc: Document) -> Document:
        """Refine ``related_doc`` in place with the new requirement and return it."""
        if not self.project_name:
            self.project_name = Path(self.project_path).name
        merge_prompt = NEW_REQ_TEMPLATE.format(requirements=req.content, old_prd=related_doc.content)
        node = await REFINED_PRD_NODE.fill(context=merge_prompt, llm=self.llm, schema=self.prompt_schema)
        related_doc.content = node.instruct_content.model_dump_json()
        await self._rename_workspace(node)
        return related_doc

    async def _update_prd(self, req: Document, prd_doc: Document) -> Document:
        """Merge the requirement into ``prd_doc``, then save the document, its chart and its PDF."""
        merged_doc: Document = await self._merge(req, prd_doc)
        await self.repo.docs.prd.save_doc(doc=merged_doc)
        await self._save_competitive_analysis(merged_doc)
        await self.repo.resources.prd.save_pdf(doc=merged_doc)
        return merged_doc

    async def _save_competitive_analysis(self, prd_doc: Document):
        """Render the PRD's competitive quadrant chart to a file, if the PRD has one."""
        prd_fields = json.loads(prd_doc.content)
        chart_source = prd_fields.get(COMPETITIVE_QUADRANT_CHART.key)
        if chart_source:
            target = self.repo.workdir / COMPETITIVE_ANALYSIS_FILE_REPO / Path(prd_doc.filename).stem
            target.parent.mkdir(parents=True, exist_ok=True)
            await mermaid_to_file(self.config.mermaid.engine, chart_source, target)

    async def _rename_workspace(self, prd):
        """Adopt the project name found in ``prd`` when none is set, then rename the git root."""
        if not self.project_name:
            if isinstance(prd, (ActionOutput, ActionNode)):
                candidate = prd.instruct_content.model_dump()["Project Name"]
            else:
                candidate = CodeParser.parse_str(block="Project Name", text=prd)
            if candidate:
                self.project_name = candidate
        self.repo.git_repo.rename_root(self.project_name)
+PRODUCT_GOALS = ActionNode( + key="Product Goals", + expected_type=List[str], + instruction="Provide up to three clear, orthogonal product goals.", + example=["Create an engaging user experience", "Improve accessibility, be responsive", "More beautiful UI"], +) + +REFINED_PRODUCT_GOALS = ActionNode( + key="Refined Product Goals", + expected_type=List[str], + instruction="Update and expand the original product goals to reflect the evolving needs due to incremental " + "development. Ensure that the refined goals align with the current project direction and contribute to its success.", + example=[ + "Enhance user engagement through new features", + "Optimize performance for scalability", + "Integrate innovative UI enhancements", + ], +) + +USER_STORIES = ActionNode( + key="User Stories", + expected_type=List[str], + instruction="Provide up to 3 to 5 scenario-based user stories.", + example=[ + "As a player, I want to be able to choose difficulty levels", + "As a player, I want to see my score after each game", + "As a player, I want to get restart button when I lose", + "As a player, I want to see beautiful UI that make me feel good", + "As a player, I want to play game via mobile phone", + ], +) + +REFINED_USER_STORIES = ActionNode( + key="Refined User Stories", + expected_type=List[str], + instruction="Update and expand the original scenario-based user stories to reflect the evolving needs due to " + "incremental development. Ensure that the refined user stories capture incremental features and improvements. 
", + example=[ + "As a player, I want to choose difficulty levels to challenge my skills", + "As a player, I want a visually appealing score display after each game for a better gaming experience", + "As a player, I want a convenient restart button displayed when I lose to quickly start a new game", + "As a player, I want an enhanced and aesthetically pleasing UI to elevate the overall gaming experience", + "As a player, I want the ability to play the game seamlessly on my mobile phone for on-the-go entertainment", + ], +) + +COMPETITIVE_ANALYSIS = ActionNode( + key="Competitive Analysis", + expected_type=List[str], + instruction="Provide 5 to 7 competitive products.", + example=[ + "2048 Game A: Simple interface, lacks responsive features", + "play2048.co: Beautiful and responsive UI with my best score shown", + "2048game.com: Responsive UI with my best score shown, but many ads", + ], +) + +COMPETITIVE_QUADRANT_CHART = ActionNode( + key="Competitive Quadrant Chart", + expected_type=str, + instruction="Use mermaid quadrantChart syntax. 
Distribute scores evenly between 0 and 1", + example="""quadrantChart + title "Reach and engagement of campaigns" + x-axis "Low Reach" --> "High Reach" + y-axis "Low Engagement" --> "High Engagement" + quadrant-1 "We should expand" + quadrant-2 "Need to promote" + quadrant-3 "Re-evaluate" + quadrant-4 "May be improved" + "Campaign A": [0.3, 0.6] + "Campaign B": [0.45, 0.23] + "Campaign C": [0.57, 0.69] + "Campaign D": [0.78, 0.34] + "Campaign E": [0.40, 0.34] + "Campaign F": [0.35, 0.78] + "Our Target Product": [0.5, 0.6]""", +) + +REQUIREMENT_ANALYSIS = ActionNode( + key="Requirement Analysis", + expected_type=str, + instruction="Provide a detailed analysis of the requirements.", + example="", +) + +REFINED_REQUIREMENT_ANALYSIS = ActionNode( + key="Refined Requirement Analysis", + expected_type=List[str], + instruction="Review and refine the existing requirement analysis into a string list to align with the evolving needs of the project " + "due to incremental development. Ensure the analysis comprehensively covers the new features and enhancements " + "required for the refined project scope.", + example=["Require add ...", "Require modify ..."], +) + +REQUIREMENT_POOL = ActionNode( + key="Requirement Pool", + expected_type=List[List[str]], + instruction="List down the top-5 requirements with their priority (P0, P1, P2).", + example=[["P0", "The main code ..."], ["P0", "The game algorithm ..."]], +) + +REFINED_REQUIREMENT_POOL = ActionNode( + key="Refined Requirement Pool", + expected_type=List[List[str]], + instruction="List down the top 5 to 7 requirements with their priority (P0, P1, P2). " + "Cover both legacy content and incremental content. 
Retain content unrelated to incremental development", + example=[["P0", "The main code ..."], ["P0", "The game algorithm ..."]], +) + +UI_DESIGN_DRAFT = ActionNode( + key="UI Design draft", + expected_type=str, + instruction="Provide a simple description of UI elements, functions, style, and layout.", + example="Basic function description with a simple style and layout.", +) + +ANYTHING_UNCLEAR = ActionNode( + key="Anything UNCLEAR", + expected_type=str, + instruction="Mention any aspects of the project that are unclear and try to clarify them.", + example="", +) + +ISSUE_TYPE = ActionNode( + key="issue_type", + expected_type=str, + instruction="Answer BUG/REQUIREMENT. If it is a bugfix, answer BUG, otherwise answer Requirement", + example="BUG", +) + +IS_RELATIVE = ActionNode( + key="is_relative", + expected_type=str, + instruction="Answer YES/NO. If the requirement is related to the old PRD, answer YES, otherwise NO", + example="YES", +) + +REASON = ActionNode( + key="reason", expected_type=str, instruction="Explain the reasoning process from question to answer", example="..." 
class WritePRDReview(Action):
    """Ask the LLM to review a Product Requirement Document.

    Attributes:
        name: Action name (empty by default).
        i_context: Optional context string.
        prd: The PRD passed to the most recent ``run`` call.
        desc: Human-readable description of the action.
        prd_review_prompt_template: Prompt template with a ``{prd}`` placeholder.
    """

    name: str = ""
    i_context: Optional[str] = None

    prd: Optional[str] = None
    desc: str = "Based on the PRD, conduct a PRD Review, providing clear and detailed feedback"
    prd_review_prompt_template: str = """
Given the following Product Requirement Document (PRD):
{prd}

As a project manager, please review it and provide your feedback and suggestions.
"""

    async def run(self, prd):
        """Remember ``prd`` on the instance and return the LLM's review of it."""
        self.prd = prd
        return await self._aask(self.prd_review_prompt_template.format(prd=self.prd))
class WriteTeachingPlanPart(Action):
    """Write one part (topic) of a teaching plan.

    Attributes:
        i_context: The lesson text inserted into the prompt.
        topic: The teaching-plan topic this action writes.
        language: Language inserted into the prompt.
        rsp: The cleaned-up result of the last ``run`` call.
    """

    i_context: Optional[str] = None
    topic: str = ""
    language: str = "Chinese"
    rsp: Optional[str] = None

    async def run(self, with_message=None, **kwargs):
        """Build the prompt for ``self.topic``, query the LLM and return the cleaned answer."""
        patterns = TeachingPlanBlock.TOPIC_STATEMENTS.get(self.topic, [])
        statements = [self.format_value(pattern, context=self.context) for pattern in patterns]

        # The course title uses its own prompt template.
        if self.topic == TeachingPlanBlock.COURSE_TITLE:
            template = TeachingPlanBlock.PROMPT_TITLE_TEMPLATE
        else:
            template = TeachingPlanBlock.PROMPT_TEMPLATE
        prompt = template.format(
            formation=TeachingPlanBlock.FORMATION,
            role=self.prefix,
            statements="\n".join(statements),
            lesson=self.i_context,
            topic=self.topic,
            language=self.language,
        )

        logger.debug(prompt)
        answer = await self._aask(prompt=prompt)
        logger.debug(answer)
        self._set_result(answer)
        return self.rsp

    def _set_result(self, rsp):
        """Store the text between the data tags in ``self.rsp``.

        Text up to and including the first begin tag and text from the first end
        tag onward is discarded; a missing tag leaves that side untouched. For
        the course-title topic the result is made to start with "#".
        """
        begin_tag = TeachingPlanBlock.DATA_BEGIN_TAG
        end_tag = TeachingPlanBlock.DATA_END_TAG

        begin_at = rsp.find(begin_tag)
        if begin_at != -1:
            rsp = rsp[begin_at + len(begin_tag) :]
        end_at = rsp.find(end_tag)
        if end_at != -1:
            rsp = rsp[:end_at]

        self.rsp = rsp.strip()
        if self.topic == TeachingPlanBlock.COURSE_TITLE and not self.rsp.startswith("#"):
            self.rsp = "# " + self.rsp

    def __str__(self):
        """Return `topic` value when str()"""
        return self.topic

    def __repr__(self):
        """Show `topic` value when debug"""
        return self.topic

    @staticmethod
    def format_value(value, context: Context):
        """Fill the ``{name}`` placeholders of ``value`` from the context.

        Values come from ``context.config.model_dump()`` overlaid with
        ``context.kwargs``; entries whose value is ``None`` are left out.
        Non-string values and strings without "{" are returned unchanged. When
        ``str.format`` hits an unknown placeholder, a warning is logged and only
        the known placeholders are substituted.
        """
        if not isinstance(value, str) or "{" not in value:
            return value

        merged = context.config.model_dump()
        for key, item in context.kwargs:
            merged[key] = item  # None value is allowed to override and disable the value from config.
        usable = {key: item for key, item in merged.items() if item is not None}

        try:
            return value.format(**usable)
        except KeyError as missing:
            logger.warning(f"Parameter is missing:{missing}")

        # Fallback: replace the known placeholders one at a time.
        for key, item in usable.items():
            value = value.replace("{" + f"{key}" + "}", str(item))
        return value
+ ) + + COURSE_TITLE = "Title" + TOPICS = [ + COURSE_TITLE, + "Teaching Hours", + "Teaching Objectives", + "Teaching Content", + "Teaching Methods and Strategies", + "Learning Activities", + "Teaching Time Allocation", + "Assessment and Feedback", + "Teaching Summary and Improvement", + "Vocabulary Cloze", + "Choice Questions", + "Grammar Questions", + "Translation Questions", + ] + + TOPIC_STATEMENTS = { + COURSE_TITLE: [ + "Statement: Find and return the title of the lesson only in markdown first-level header format, " + "without anything else." + ], + "Teaching Content": [ + 'Statement: "Teaching Content" must include vocabulary, analysis, and examples of various grammar ' + "structures that appear in the textbook, as well as the listening materials and key points.", + 'Statement: "Teaching Content" must include more examples.', + ], + "Teaching Time Allocation": [ + 'Statement: "Teaching Time Allocation" must include how much time is allocated to each ' + "part of the textbook content." + ], + "Teaching Methods and Strategies": [ + 'Statement: "Teaching Methods and Strategies" must include teaching focus, difficulties, materials, ' + "procedures, in detail." + ], + "Vocabulary Cloze": [ + 'Statement: Based on the content of the textbook enclosed by "[LESSON_BEGIN]" and "[LESSON_END]", ' + "create vocabulary cloze. The cloze should include 10 {language} questions with {teaching_language} " + "answers, and it should also include 10 {teaching_language} questions with {language} answers. " + "The key-related vocabulary and phrases in the textbook content must all be included in the exercises.", + ], + "Grammar Questions": [ + 'Statement: Based on the content of the textbook enclosed by "[LESSON_BEGIN]" and "[LESSON_END]", ' + "create grammar questions. 10 questions." + ], + "Choice Questions": [ + 'Statement: Based on the content of the textbook enclosed by "[LESSON_BEGIN]" and "[LESSON_END]", ' + "create choice questions. 10 questions." 
+ ], + "Translation Questions": [ + 'Statement: Based on the content of the textbook enclosed by "[LESSON_BEGIN]" and "[LESSON_END]", ' + "create translation questions. The translation should include 10 {language} questions with " + "{teaching_language} answers, and it should also include 10 {teaching_language} questions with " + "{language} answers." + ], + } + + # Teaching plan title + PROMPT_TITLE_TEMPLATE = ( + "Do not refer to the context of the previous conversation records, " + "start the conversation anew.\n\n" + "Formation: {formation}\n\n" + "{statements}\n" + "Constraint: Writing in {language}.\n" + 'Answer options: Encloses the lesson title with "[TEACHING_PLAN_BEGIN]" ' + 'and "[TEACHING_PLAN_END]" tags.\n' + "[LESSON_BEGIN]\n" + "{lesson}\n" + "[LESSON_END]" + ) + + # Teaching plan parts: + PROMPT_TEMPLATE = ( + "Do not refer to the context of the previous conversation records, " + "start the conversation anew.\n\n" + "Formation: {formation}\n\n" + "Capacity and role: {role}\n" + 'Statement: Write the "{topic}" part of teaching plan, ' + 'WITHOUT ANY content unrelated to "{topic}"!!\n' + "{statements}\n" + 'Answer options: Enclose the teaching plan content with "[TEACHING_PLAN_BEGIN]" ' + 'and "[TEACHING_PLAN_END]" tags.\n' + "Answer options: Using proper markdown format from second-level header format.\n" + "Constraint: Writing in {language}.\n" + "[LESSON_BEGIN]\n" + "{lesson}\n" + "[LESSON_END]" + ) + + DATA_BEGIN_TAG = "[TEACHING_PLAN_BEGIN]" + DATA_END_TAG = "[TEACHING_PLAN_END]" diff --git a/notebook_dir/metagpt_yusin/actions/write_test.py b/notebook_dir/metagpt_yusin/actions/write_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e2e5949d3e1b4622c200ffeceb8195250f99c005 --- /dev/null +++ b/notebook_dir/metagpt_yusin/actions/write_test.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/11 22:12 +@Author : alexanderwu +@File : write_test.py +@Modified By: mashenquan, 2023-11-27. 
Following the think-act principle, solidify the task parameters when creating the + WriteTest object, rather than passing them in when calling the run function. +""" + +from typing import Optional + +from metagpt_yusin.actions.action import Action +from metagpt_yusin.const import TEST_CODES_FILE_REPO +from metagpt_yusin.logs import logger +from metagpt_yusin.schema import Document, TestingContext +from metagpt_yusin.utils.common import CodeParser + +PROMPT_TEMPLATE = """ +NOTICE +1. Role: You are a QA engineer; the main goal is to design, develop, and execute PEP8 compliant, well-structured, maintainable test cases and scripts for Python 3.9. Your focus should be on ensuring the product quality of the entire project through systematic testing. +2. Requirement: Based on the context, develop a comprehensive test suite that adequately covers all relevant aspects of the code file under review. Your test suite will be part of the overall project QA, so please develop complete, robust, and reusable test cases. +3. Attention1: Use '##' to split sections, not '#', and '## ' SHOULD WRITE BEFORE the test case or script. +4. Attention2: If there are any settings in your tests, ALWAYS SET A DEFAULT VALUE, ALWAYS USE STRONG TYPE AND EXPLICIT VARIABLE. +5. Attention3: YOU MUST FOLLOW "Data structures and interfaces". DO NOT CHANGE ANY DESIGN. Make sure your tests respect the existing design and ensure its validity. +6. Think before writing: What should be tested and validated in this document? What edge cases could exist? What might fail? +7. CAREFULLY CHECK THAT YOU DON'T MISS ANY NECESSARY TEST CASES/SCRIPTS IN THIS FILE. +Attention: Use '##' to split sections, not '#', and '## ' SHOULD WRITE BEFORE the test case or script and triple quotes. 
class WriteTest(Action):
    """Action that asks the LLM to write a unittest file for the code held in ``i_context``."""

    name: str = "WriteTest"
    i_context: Optional[TestingContext] = None

    async def write_code(self, prompt):
        """Send ``prompt`` to the LLM and return the code extracted from its reply.

        When the reply cannot be parsed as a code block, the error is logged and
        the raw reply is returned instead, on the assumption that the LLM answered
        with bare code that was not wrapped in ``` fences.
        """
        reply = await self._aask(prompt)
        try:
            return CodeParser.parse_code(block="", text=reply)
        except Exception:
            logger.error(f"Can't parse the code: {reply}")
            return reply

    async def run(self, *args, **kwargs) -> TestingContext:
        """Fill ``i_context.test_doc.content`` with generated test code and return the context."""
        ctx = self.i_context
        if not ctx.test_doc:
            # No test document yet: create one named after the source file.
            test_name = "test_" + ctx.code_doc.filename
            ctx.test_doc = Document(filename=test_name, root_path=TEST_CODES_FILE_REPO)

        # The prompt presents paths under a fixed placeholder root rather than the real location.
        fake_root = "/data"
        source_path = fake_root + "/" + ctx.code_doc.root_relative_path
        prompt = PROMPT_TEMPLATE.format(
            code_to_test=ctx.code_doc.content,
            test_file_name=ctx.test_doc.filename,
            source_file_path=source_path,
            workspace=fake_root,
        )
        ctx.test_doc.content = await self.write_code(prompt)
        return ctx
class WriteDirectory(Action):
    """Action that produces the directory (outline) of a tutorial.

    Attributes:
        name: The name of the action.
        language: The language to output, default is "Chinese".
    """

    name: str = "WriteDirectory"
    language: str = "Chinese"

    async def run(self, topic: str, *args, **kwargs) -> Dict:
        """Ask the LLM for a tutorial directory on ``topic`` and parse it into a dict.

        Args:
            topic: The tutorial topic.

        Returns:
            The tutorial directory information, shaped like
            {"title": "xxx", "directory": [{"dir 1": ["sub dir 1", "sub dir 2"]}]}.
        """
        answer = await self._aask(prompt=DIRECTORY_PROMPT.format(topic=topic, language=self.language))
        return OutputParser.extract_struct(answer, dict)
class CLIParams(BaseModel):
    """Parameters supplied on the command line."""

    project_path: str = ""
    project_name: str = ""
    inc: bool = False
    reqa_file: str = ""
    max_auto_summarize_code: int = 0
    git_reinit: bool = False

    @model_validator(mode="after")
    def check_project_path(self):
        """Derive ``inc`` and ``project_name`` once a ``project_path`` is given.

        With a non-empty ``project_path``, ``inc`` is forced to True and an empty
        ``project_name`` falls back to the last component of the path.
        """
        if not self.project_path:
            return self

        self.inc = True
        self.project_name = self.project_name or Path(self.project_path).name
        return self
Will be used if llm.proxy is not set + proxy: str = "" + + # Tool Parameters + search: SearchConfig = SearchConfig() + browser: BrowserConfig = BrowserConfig() + mermaid: MermaidConfig = MermaidConfig() + + # Storage Parameters + s3: Optional[S3Config] = None + redis: Optional[RedisConfig] = None + + # Misc Parameters + repair_llm_output: bool = False + prompt_schema: Literal["json", "markdown", "raw"] = "json" + workspace: WorkspaceConfig = WorkspaceConfig() + enable_longterm_memory: bool = False + code_review_k_times: int = 2 + + # Will be removed in the future + metagpt_yusin_tti_url: str = "" + language: str = "English" + redis_key: str = "placeholder" + iflytek_app_id: str = "" + iflytek_api_secret: str = "" + iflytek_api_key: str = "" + azure_tts_subscription_key: str = "" + azure_tts_region: str = "" + _extra: dict = dict() # extra config dict + + @classmethod + def from_home(cls, path): + """Load config from ~/.metagpt_yusin/config2.yaml""" + pathname = CONFIG_ROOT / path + if not pathname.exists(): + return None + return Config.from_yaml_file(pathname) + + @classmethod + def default(cls): + """Load default config + - Priority: env < default_config_paths + - Inside default_config_paths, the latter one overwrites the former one + """ + + #default_config_paths: List[Path] = [ + # metagpt_yusin_ROOT / "config/config2.yaml", + # CONFIG_ROOT / "config2.yaml", + #] + + default_config_paths: List[Path] = [ + CONFIG_ROOT / "config2.yaml", + ] + + dicts = [dict(os.environ)] + dicts += [Config.read_yaml(path) for path in default_config_paths] + final = merge_dict(dicts) + config_init = Config(**final) + + + # appended new + if 'api_type' in os.environ: + if os.environ.get('api_type') == 'openai': + config_init.llm.api_type = LLMType.OPENAI + elif os.environ.get('api_type') == 'groq': + config_init.llm.api_type = LLMType.OPENAI + config_init.llm.base_url = 'https://api.groq.com/openai/v1' + elif os.environ.get('api_type') == 'openrouter': + config_init.llm.api_type = 
LLMType.OPENROUTER + config_init.llm.base_url = 'https://openrouter.ai/api/v1' + else: + logger.debug('The API Type is not supported!!') + else: + logger.debug('Provide your api type!!') + if 'model' in os.environ: + config_init.llm.model = os.environ.get('model') + else: + logger.debug('Provide your model!!') + if 'api_key' in os.environ: + config_init.llm.api_key = os.environ.get('api_key') + else: + logger.debug('Provide your api key!!') + + + return config_init + + @classmethod + def from_llm_config(cls, llm_config: dict): + """user config llm + example: + llm_config = {"api_type": "xxx", "api_key": "xxx", "model": "xxx"} + gpt4 = Config.from_llm_config(llm_config) + A = Role(name="A", profile="Democratic candidate", goal="Win the election", actions=[a1], watch=[a2], config=gpt4) + """ + llm_config = LLMConfig.model_validate(llm_config) + dicts = [dict(os.environ)] + dicts += [{"llm": llm_config}] + final = merge_dict(dicts) + return Config(**final) + + def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code): + """update config via cli""" + + # Use in the PrepareDocuments action according to Section 2.2.3.5.1 of RFC 135. + if project_path: + inc = True + project_name = project_name or Path(project_path).name + self.project_path = project_path + self.project_name = project_name + self.inc = inc + self.reqa_file = reqa_file + self.max_auto_summarize_code = max_auto_summarize_code + + @property + def extra(self): + return self._extra + + @extra.setter + def extra(self, value: dict): + self._extra = value + + def get_openai_llm(self) -> Optional[LLMConfig]: + """Get OpenAI LLMConfig by name. If no OpenAI, raise Exception""" + if self.llm.api_type == LLMType.OPENAI: + return self.llm + return None + + def get_azure_llm(self) -> Optional[LLMConfig]: + """Get Azure LLMConfig by name. 
def merge_dict(dicts: Iterable[Dict]) -> Dict:
    """Fold ``dicts`` into a single new dict.

    The inputs are applied in order, so a key present in several of them ends
    up with the value from the last one. An empty iterable gives ``{}``.
    """
    merged: Dict = {}
    for layer in dicts:
        merged |= layer
    return merged
If it is Selenium, the value + should be either "chrome", "firefox", "edge", or "ie".""" diff --git a/notebook_dir/metagpt_yusin/configs/.ipynb_checkpoints/config2-checkpoint.yaml b/notebook_dir/metagpt_yusin/configs/.ipynb_checkpoints/config2-checkpoint.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e0e8a51d25a8d89005ffcb32932c17b29a522625 --- /dev/null +++ b/notebook_dir/metagpt_yusin/configs/.ipynb_checkpoints/config2-checkpoint.yaml @@ -0,0 +1,9 @@ +# Full Example: https://github.com/geekan/MetaGPT/blob/main/config/config2.example.yaml +# Reflected Code: https://github.com/geekan/MetaGPT/blob/main/metagpt/config2.py +llm: + api_type: 'openai' # or azure / ollama / open_llm etc. Check LLMType for more options + model: 'gpt-3.5-turbo-1106' # or gpt-3.5-turbo-1106 / gpt-4-1106-preview + base_url: 'https://api.openai.com/v1' # or forward url / other llm url + api_key: 'sk-' + # proxy: 'YOUR_LLM_PROXY_IF_NEEDED' # Optional. If you want to use a proxy, set it here. + # pricing_plan: 'YOUR_PRICING_PLAN' # Optional. If your pricing plan uses a different name than the `model`. 
class LLMType(Enum):
    """Identifiers of the LLM providers/back ends that can be configured.

    Looking up a value that matches no member (``LLMType("something-else")``)
    falls back to ``LLMType.OPENAI`` instead of raising ``ValueError``.
    """

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    CLAUDE = "claude"  # alias name of anthropic
    SPARK = "spark"
    ZHIPUAI = "zhipuai"
    FIREWORKS = "fireworks"
    OPEN_LLM = "open_llm"
    GEMINI = "gemini"
    metagpt_yusin = "metagpt_yusin"
    AZURE = "azure"
    OLLAMA = "ollama"
    QIANFAN = "qianfan"  # Baidu BCE
    DASHSCOPE = "dashscope"  # Aliyun LingJi DashScope
    MOONSHOT = "moonshot"
    MISTRAL = "mistral"
    YI = "yi"  # lingyiwanwu
    GROQ = "groq"
    OPENROUTER = "openrouter"

    @classmethod
    def _missing_(cls, value):
        """Return the fallback member for a value that matches no member.

        ``Enum`` calls the classmethod ``_missing_`` for failed value lookups;
        the previous ``__missing__`` instance method was never invoked, so the
        intended OpenAI fallback did not take effect.
        """
        return cls.OPENAI
+ + # For Cloud Service Provider like Baidu/ Alibaba + access_key: Optional[str] = None + secret_key: Optional[str] = None + endpoint: Optional[str] = None # for self-deployed model on the cloud + + # For Spark(Xunfei), maybe remove later + app_id: Optional[str] = None + api_secret: Optional[str] = None + domain: Optional[str] = None + + # For Chat Completion + max_token: int = 4096 + temperature: float = 0.0 + top_p: float = 1.0 + top_k: int = 0 + repetition_penalty: float = 1.0 + stop: Optional[str] = None + presence_penalty: float = 0.0 + frequency_penalty: float = 0.0 + best_of: Optional[int] = None + n: Optional[int] = None + stream: bool = False + logprobs: Optional[bool] = None # https://cookbook.openai.com/examples/using_logprobs + top_logprobs: Optional[int] = None + timeout: int = 600 + + # For Network + proxy: Optional[str] = None + + # Cost Control + calc_usage: bool = True + + @field_validator("api_key") + @classmethod + def check_llm_key(cls, v): + if v in ["", None, "YOUR_API_KEY"]: + raise ValueError("Please set your API key in config2.yaml") + return v + + @field_validator("timeout") + @classmethod + def check_timeout(cls, v): + return v or LLM_API_TIMEOUT diff --git a/notebook_dir/metagpt_yusin/configs/.ipynb_checkpoints/search_config-checkpoint.py b/notebook_dir/metagpt_yusin/configs/.ipynb_checkpoints/search_config-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..7aa2cb6c76f8ea79db9a0257e5c93c2af53901ad --- /dev/null +++ b/notebook_dir/metagpt_yusin/configs/.ipynb_checkpoints/search_config-checkpoint.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/1/4 19:06 +@Author : alexanderwu +@File : search_config.py +""" +from typing import Callable, Optional + +from pydantic import Field + +from metagpt_yusin.tools import SearchEngineType +from metagpt_yusin.utils.yaml_model import YamlModel + + +class SearchConfig(YamlModel): + """Config for Search""" + + api_type: SearchEngineType 
= SearchEngineType.DUCK_DUCK_GO + api_key: str = "" + cse_id: str = "" # for google + search_func: Optional[Callable] = None + params: dict = Field( + default_factory=lambda: { + "engine": "google", + "google_domain": "google.com", + "gl": "us", + "hl": "en", + } + ) diff --git a/notebook_dir/metagpt_yusin/configs/__init__.py b/notebook_dir/metagpt_yusin/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e42e6788f240b7df0abbf07410554d66641313ba --- /dev/null +++ b/notebook_dir/metagpt_yusin/configs/__init__.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/1/4 16:33 +@Author : alexanderwu +@File : __init__.py +""" diff --git a/notebook_dir/metagpt_yusin/configs/__pycache__/__init__.cpython-39.pyc b/notebook_dir/metagpt_yusin/configs/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f41ba77cc06c4fc8b6f83086052a6ad4abe6f96 Binary files /dev/null and b/notebook_dir/metagpt_yusin/configs/__pycache__/__init__.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/configs/__pycache__/browser_config.cpython-39.pyc b/notebook_dir/metagpt_yusin/configs/__pycache__/browser_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e927a2f02bbc8fc3418a176f5eab6ff29826586 Binary files /dev/null and b/notebook_dir/metagpt_yusin/configs/__pycache__/browser_config.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/configs/__pycache__/llm_config.cpython-39.pyc b/notebook_dir/metagpt_yusin/configs/__pycache__/llm_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1f55d6fa7588d980acd1397184cae2b575f1279 Binary files /dev/null and b/notebook_dir/metagpt_yusin/configs/__pycache__/llm_config.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/configs/__pycache__/mermaid_config.cpython-39.pyc 
b/notebook_dir/metagpt_yusin/configs/__pycache__/mermaid_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e1725d56179dac0549377d0c52eda1804953616 Binary files /dev/null and b/notebook_dir/metagpt_yusin/configs/__pycache__/mermaid_config.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/configs/__pycache__/redis_config.cpython-39.pyc b/notebook_dir/metagpt_yusin/configs/__pycache__/redis_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6ed15ea3cb0628700e6e719e4a82fc1680545b9 Binary files /dev/null and b/notebook_dir/metagpt_yusin/configs/__pycache__/redis_config.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/configs/__pycache__/s3_config.cpython-39.pyc b/notebook_dir/metagpt_yusin/configs/__pycache__/s3_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9dd5035103fca575882dc7d525bd7e6bba0fa432 Binary files /dev/null and b/notebook_dir/metagpt_yusin/configs/__pycache__/s3_config.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/configs/__pycache__/search_config.cpython-39.pyc b/notebook_dir/metagpt_yusin/configs/__pycache__/search_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24f735d395f07b41cb83cf4ab360ba3175f89bde Binary files /dev/null and b/notebook_dir/metagpt_yusin/configs/__pycache__/search_config.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/configs/__pycache__/workspace_config.cpython-39.pyc b/notebook_dir/metagpt_yusin/configs/__pycache__/workspace_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..357d1d1a4a12d45be1377eee3610fcb105d2313c Binary files /dev/null and b/notebook_dir/metagpt_yusin/configs/__pycache__/workspace_config.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/configs/browser_config.py b/notebook_dir/metagpt_yusin/configs/browser_config.py new file mode 100644 index 
class BrowserConfig(YamlModel):
    """Config for Browser"""

    # Browser automation engine; Playwright unless configured otherwise.
    engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT
    # Union of the Playwright and Selenium browser names. "firefox" is valid for
    # both engines and is listed once; the accepted set of values is unchanged.
    browser_type: Literal["chromium", "firefox", "webkit", "chrome", "edge", "ie"] = "chromium"
    """If the engine is Playwright, the value should be one of "chromium", "firefox", or "webkit". If it is Selenium, the value
    should be either "chrome", "firefox", "edge", or "ie"."""
class LLMType(Enum):
    """Identifiers of the LLM providers/back ends that can be configured.

    Looking up a value that matches no member (``LLMType("something-else")``)
    falls back to ``LLMType.OPENAI`` instead of raising ``ValueError``.
    """

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    CLAUDE = "claude"  # alias name of anthropic
    SPARK = "spark"
    ZHIPUAI = "zhipuai"
    FIREWORKS = "fireworks"
    OPEN_LLM = "open_llm"
    GEMINI = "gemini"
    metagpt_yusin = "metagpt_yusin"
    AZURE = "azure"
    OLLAMA = "ollama"
    QIANFAN = "qianfan"  # Baidu BCE
    DASHSCOPE = "dashscope"  # Aliyun LingJi DashScope
    MOONSHOT = "moonshot"
    MISTRAL = "mistral"
    YI = "yi"  # lingyiwanwu
    GROQ = "groq"
    OPENROUTER = "openrouter"

    @classmethod
    def _missing_(cls, value):
        """Return the fallback member for a value that matches no member.

        ``Enum`` calls the classmethod ``_missing_`` for failed value lookups;
        the previous ``__missing__`` instance method was never invoked, so the
        intended OpenAI fallback did not take effect.
        """
        return cls.OPENAI
class MermaidConfig(YamlModel):
    """Config for Mermaid"""

    # Rendering backend; must be one of the listed literals ("none" is an accepted value).
    engine: Literal["nodejs", "ink", "playwright", "pyppeteer", "none"] = "nodejs"
    # Defaults to "mmdc" — presumably the mermaid-cli executable name or path; confirm against the renderer.
    path: str = "mmdc"  # mmdc
    # Empty by default; presumably a puppeteer config file location — TODO confirm with the consumer.
    puppeteer_config: str = ""
    # Defaults to a Chrome binary path; by its name intended for the pyppeteer engine — verify usage.
    pyppeteer_path: str = "/usr/bin/google-chrome-stable"
class RedisConfig(YamlModelWithoutDefault):
    """Connection settings for a Redis server.

    ``host``, ``port``, ``password`` and ``db`` are declared without defaults;
    ``username`` defaults to an empty string.
    """

    host: str
    port: int
    username: str = ""
    password: str
    db: str

    def to_url(self):
        """Return the ``redis://host:port`` URL; credentials and db are not embedded."""
        url = f"redis://{self.host}:{self.port}"
        return url

    def to_kwargs(self):
        """Return username, password and db as a mapping of keyword arguments."""
        credential_fields = ("username", "password", "db")
        return {field: getattr(self, field) for field in credential_fields}
class WorkspaceConfig(YamlModel):
    """Location of the workspace directory.

    ``path`` defaults to ``DEFAULT_WORKSPACE_ROOT``. When ``use_uid`` is set and
    no ``uid`` was supplied, a uid is generated and appended to ``path``.
    """

    path: Path = DEFAULT_WORKSPACE_ROOT
    use_uid: bool = False
    uid: str = ""

    @field_validator("path")
    @classmethod
    def check_workspace_path(cls, v):
        """Turn a string path into a ``Path``; any other value is returned as is."""
        return Path(v) if isinstance(v, str) else v

    @model_validator(mode="after")
    def check_uid_and_update_path(self):
        """Generate the uid if requested and missing, then make sure the directory exists."""
        uid_requested_but_missing = self.use_uid and not self.uid
        if uid_requested_but_missing:
            # Timestamp plus the last 8 hex digits of a random UUID.
            stamp = datetime.now().strftime("%Y%m%d%H%M%S")
            self.uid = f"{stamp}-{uuid4().hex[-8:]}"
            self.path = self.path / self.uid

        # Create workspace path if not exists
        self.path.mkdir(parents=True, exist_ok=True)
        return self
def get_metagpt_yusin_package_root():
    """Get the root directory of the installed package.

    The directory above the ``metagpt_yusin`` package is used when it contains
    one of the marker entries (``.git``, ``.project_root``, ``.gitignore``);
    otherwise the current working directory is used. The chosen directory is logged.
    """
    markers = (".git", ".project_root", ".gitignore")
    package_root = Path(metagpt_yusin.__file__).parent.parent
    # any() stops at the first marker found, like the break in a for/else loop.
    has_marker = any((package_root / marker).exists() for marker in markers)
    if not has_marker:
        package_root = Path.cwd()

    logger.info(f"Package root set to {str(package_root)}")
    return package_root
"invoice_table" + +UT_PATH = DATA_PATH / "ut" +SWAGGER_PATH = UT_PATH / "files/api/" +UT_PY_PATH = UT_PATH / "files/ut/" +API_QUESTIONS_PATH = UT_PATH / "files/question/" + +SERDESER_PATH = DEFAULT_WORKSPACE_ROOT / "storage" # TODO to store `storage` under the individual generated project + +TMP = metagpt_yusin_ROOT / "tmp" + +SOURCE_ROOT = metagpt_yusin_ROOT / "metagpt_yusin" +PROMPT_PATH = SOURCE_ROOT / "prompts" +SKILL_DIRECTORY = SOURCE_ROOT / "skills" +TOOL_SCHEMA_PATH = metagpt_yusin_ROOT / "metagpt_yusin/tools/schemas" +TOOL_LIBS_PATH = metagpt_yusin_ROOT / "metagpt_yusin/tools/libs" + +# REAL CONSTS + +MEM_TTL = 24 * 30 * 3600 + +MESSAGE_ROUTE_FROM = "sent_from" +MESSAGE_ROUTE_TO = "send_to" +MESSAGE_ROUTE_CAUSE_BY = "cause_by" +MESSAGE_META_ROLE = "role" +MESSAGE_ROUTE_TO_ALL = "" +MESSAGE_ROUTE_TO_NONE = "" + +REQUIREMENT_FILENAME = "requirement.txt" +BUGFIX_FILENAME = "bugfix.txt" +PACKAGE_REQUIREMENTS_FILENAME = "requirements.txt" + +DOCS_FILE_REPO = "docs" +PRDS_FILE_REPO = "docs/prd" +SYSTEM_DESIGN_FILE_REPO = "docs/system_design" +TASK_FILE_REPO = "docs/task" +CODE_PLAN_AND_CHANGE_FILE_REPO = "docs/code_plan_and_change" +COMPETITIVE_ANALYSIS_FILE_REPO = "resources/competitive_analysis" +DATA_API_DESIGN_FILE_REPO = "resources/data_api_design" +SEQ_FLOW_FILE_REPO = "resources/seq_flow" +SYSTEM_DESIGN_PDF_FILE_REPO = "resources/system_design" +PRD_PDF_FILE_REPO = "resources/prd" +TASK_PDF_FILE_REPO = "resources/api_spec_and_task" +CODE_PLAN_AND_CHANGE_PDF_FILE_REPO = "resources/code_plan_and_change" +TEST_CODES_FILE_REPO = "tests" +TEST_OUTPUTS_FILE_REPO = "test_outputs" +CODE_SUMMARIES_FILE_REPO = "docs/code_summary" +CODE_SUMMARIES_PDF_FILE_REPO = "resources/code_summary" +RESOURCES_FILE_REPO = "resources" +SD_OUTPUT_FILE_REPO = "resources/sd_output" +GRAPH_REPO_FILE_REPO = "docs/graph_repo" +VISUAL_GRAPH_REPO_FILE_REPO = "resources/graph_db" +CLASS_VIEW_FILE_REPO = "docs/class_view" + +YAPI_URL = "http://yapi.deepwisdomai.com/" + +DEFAULT_LANGUAGE = 
"English" +DEFAULT_MAX_TOKENS = 1500 +COMMAND_TOKENS = 500 +BRAIN_MEMORY = "BRAIN_MEMORY" +SKILL_PATH = "SKILL_PATH" +SERPER_API_KEY = "SERPER_API_KEY" +DEFAULT_TOKEN_SIZE = 500 + +# format +BASE64_FORMAT = "base64" + +# REDIS +REDIS_KEY = "REDIS_KEY" + +# Message id +IGNORED_MESSAGE_ID = "0" + +# Class Relationship +GENERALIZATION = "Generalize" +COMPOSITION = "Composite" +AGGREGATION = "Aggregate" + +# Timeout +USE_CONFIG_TIMEOUT = 0 # Using llm.timeout configuration. +LLM_API_TIMEOUT = 300 diff --git a/notebook_dir/metagpt_yusin/context.py b/notebook_dir/metagpt_yusin/context.py new file mode 100644 index 0000000000000000000000000000000000000000..03efb831109ab01811a2393eb8bc504c1ba573b6 --- /dev/null +++ b/notebook_dir/metagpt_yusin/context.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/1/4 16:32 +@Author : alexanderwu +@File : context.py +""" +import os +from pathlib import Path +from typing import Any, Dict, Optional + +from pydantic import BaseModel, ConfigDict + +from metagpt_yusin.config2 import Config +from metagpt_yusin.configs.llm_config import LLMConfig, LLMType +from metagpt_yusin.provider.base_llm import BaseLLM +from metagpt_yusin.provider.llm_provider_registry import create_llm_instance +from metagpt_yusin.utils.cost_manager import ( + CostManager, + FireworksCostManager, + TokenCostManager, +) +from metagpt_yusin.utils.git_repository import GitRepository +from metagpt_yusin.utils.project_repo import ProjectRepo + + +class AttrDict(BaseModel): + """A dict-like object that allows access to keys as attributes, compatible with Pydantic.""" + + model_config = ConfigDict(extra="allow") + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.__dict__.update(kwargs) + + def __getattr__(self, key): + return self.__dict__.get(key, None) + + def __setattr__(self, key, value): + self.__dict__[key] = value + + def __delattr__(self, key): + if key in self.__dict__: + del self.__dict__[key] + else: + raise 
def new_environ(self):
    """Return a plain ``dict`` copy of the current process environment.

    Changes made to the returned mapping do not affect ``os.environ``.
    """
    return dict(os.environ)
def deserialize(self, serialized_data: Dict[str, Any]):
    """Deserialize the given serialized data and update the object's attributes accordingly.

    Recognised keys are ``workdir``, ``kwargs`` and ``cost_manager``; each one is
    applied only when present and truthy.

    Args:
        serialized_data (Dict[str, Any]): A dictionary containing serialized data.
            Falsy input (``None`` or an empty dict) leaves the object untouched.
    """
    if not serialized_data:
        return
    workdir = serialized_data.get("workdir")
    if workdir:
        self.git_repo = GitRepository(local_path=workdir, auto_init=True)
        self.repo = ProjectRepo(self.git_repo)
        # The source workspace is the sub-directory named like the repository itself.
        src_workspace = self.git_repo.workdir / self.git_repo.workdir.name
        if src_workspace.exists():
            self.src_workspace = src_workspace
    kwargs = serialized_data.get("kwargs")
    if kwargs:
        for k, v in kwargs.items():
            self.kwargs.set(k, v)
    cost_manager = serialized_data.get("cost_manager")
    if cost_manager:
        # model_validate_json builds and returns a new model instead of updating
        # the receiver, so the result must be assigned back to take effect.
        self.cost_manager = self.cost_manager.model_validate_json(cost_manager)
def set(self, k, v, override=False):
    """Store ``v`` under attribute ``k`` unless a truthy value is already there.

    With ``override=True`` the value is always written. A missing or falsy
    existing value (``None``, ``0``, ``""`` ...) is always replaced.
    """
    attrs = self.__dict__
    if not override and attrs.get(k):
        return
    attrs[k] = v
@context.setter + def context(self, context: Context) -> None: + """Set context""" + self.set_context(context) + + @property + def llm(self) -> BaseLLM: + """Role llm: if not existed, init from role.config""" + # print(f"class:{self.__class__.__name__}({self.name}), llm: {self._llm}, llm_config: {self._llm_config}") + if not self.private_llm: + self.private_llm = self.context.llm_with_cost_manager_from_llm_config(self.config.llm) + return self.private_llm + + @llm.setter + def llm(self, llm: BaseLLM) -> None: + """Set llm""" + self.private_llm = llm diff --git a/notebook_dir/metagpt_yusin/document.py b/notebook_dir/metagpt_yusin/document.py new file mode 100644 index 0000000000000000000000000000000000000000..d64694ae8f6a5cd535de70468f589a674af75441 --- /dev/null +++ b/notebook_dir/metagpt_yusin/document.py @@ -0,0 +1,235 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/6/8 14:03 +@Author : alexanderwu +@File : document.py +@Desc : Classes and Operations Related to Files in the File System. 
# The return annotation is quoted so it is not evaluated at definition time.
def read_data(data_path: Path) -> "Union[pd.DataFrame, list[Document]]":
    """Load a data file with a reader chosen from its file extension.

    The extension is matched case-insensitively, so ``report.CSV`` is handled
    like ``report.csv``.

    Args:
        data_path: File to load.

    Returns:
        A ``DataFrame`` for ``.xlsx``/``.csv``/``.json``; the reader's documents
        for ``.docx``/``.doc``/``.pdf``; parsed nodes for ``.txt``.

    Raises:
        NotImplementedError: If the extension is not one of the supported ones.
    """
    suffix = data_path.suffix.lower()
    if suffix == ".xlsx":
        return pd.read_excel(data_path)
    if suffix == ".csv":
        return pd.read_csv(data_path)
    if suffix == ".json":
        return pd.read_json(data_path)
    if suffix in (".docx", ".doc"):
        return SimpleDirectoryReader(input_files=[str(data_path)]).load_data()
    if suffix == ".txt":
        documents = SimpleDirectoryReader(input_files=[str(data_path)]).load_data()
        # Plain text is split on newlines into chunks of at most 256 tokens.
        node_parser = SimpleNodeParser.from_defaults(separator="\n", chunk_size=256, chunk_overlap=0)
        return node_parser.get_nodes_from_documents(documents)
    if suffix == ".pdf":
        # load_data is an instance method: calling it on the class would bind
        # the path to ``self`` and leave the ``file`` argument missing.
        return PDFReader().load_data(file=data_path)
    raise NotImplementedError("File format not supported.")
@classmethod
def from_path(cls, path: Path):
    """
    Create a Document instance from a file path.

    The file is decoded as UTF-8, the encoding ``to_path`` writes, so a saved
    document reads back unchanged whatever the platform's default encoding is.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    if not path.exists():
        raise FileNotFoundError(f"File {path} not found.")
    content = path.read_text(encoding="utf-8")
    return cls(content=content, path=path)
+ """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + data: Union[pd.DataFrame, list] + content_col: Optional[str] = Field(default="") + meta_col: Optional[str] = Field(default="") + + @classmethod + def from_path(cls, data_path: Path, content_col="content", meta_col="metadata"): + if not data_path.exists(): + raise FileNotFoundError(f"File {data_path} not found.") + data = read_data(data_path) + if isinstance(data, pd.DataFrame): + validate_cols(content_col, data) + return cls(data=data, content=str(data), content_col=content_col, meta_col=meta_col) + try: + content = data_path.read_text() + except Exception as e: + logger.debug(f"Load {str(data_path)} error: {e}") + content = "" + return cls(data=data, content=content, content_col=content_col, meta_col=meta_col) + + def _get_docs_and_metadatas_by_df(self) -> (list, list): + df = self.data + docs = [] + metadatas = [] + for i in tqdm(range(len(df))): + docs.append(df[self.content_col].iloc[i]) + if self.meta_col: + metadatas.append({self.meta_col: df[self.meta_col].iloc[i]}) + else: + metadatas.append({}) + return docs, metadatas + + def _get_docs_and_metadatas_by_llamaindex(self) -> (list, list): + data = self.data + docs = [i.text for i in data] + metadatas = [i.metadata for i in data] + return docs, metadatas + + def get_docs_and_metadatas(self) -> (list, list): + if isinstance(self.data, pd.DataFrame): + return self._get_docs_and_metadatas_by_df() + elif isinstance(self.data, list): + return self._get_docs_and_metadatas_by_llamaindex() + else: + raise NotImplementedError("Data type not supported for metadata extraction.") + + +class RepoMetadata(BaseModel): + name: str = Field(default="") + n_docs: int = Field(default=0) + n_chars: int = Field(default=0) + symbols: list = Field(default_factory=list) + + +class Repo(BaseModel): + # Name of this repo. 
@classmethod
def from_path(cls, path: Path):
    """Load documents, code, and assets from a repository path.

    The directory is created if it does not exist. Every file below it whose
    extension (compared case-insensitively, matching ``_set``) is one of the
    supported text types is read as UTF-8 and registered through ``_set``.

    Args:
        path: Root directory of the repository.

    Returns:
        An instance of ``cls`` (so subclasses build their own type) named after
        the directory.
    """
    path.mkdir(parents=True, exist_ok=True)
    repo = cls(path=path, name=path.name)
    # FIXME: These judgments are difficult to support multiple programming languages and need to be more general
    supported = {".json", ".txt", ".md", ".py", ".js", ".css", ".html"}
    for file_path in path.rglob("*"):
        if file_path.is_file() and file_path.suffix.lower() in supported:
            # UTF-8 is what Document.to_path writes, so files saved by this
            # module read back correctly on any platform.
            repo._set(file_path.read_text(encoding="utf-8"), file_path)
    return repo
class BaseStore(ABC):
    """FIXME: consider add_index, set_index and think about granularity.

    Abstract interface of a document store. Concrete stores must implement
    ``search``, ``write`` and ``add``; the argument lists are left open here, so
    each implementation defines its own.
    """

    @abstractmethod
    def search(self, *args, **kwargs):
        """Query the store; must be implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def write(self, *args, **kwargs):
        """Write data into the store; must be implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def add(self, *args, **kwargs):
        """Add entries to the store; must be implemented by subclasses."""
        raise NotImplementedError
__init__(self, raw_data_path: Path, cache_dir: Path = None): + if not raw_data_path: + raise FileNotFoundError + self.raw_data_path = raw_data_path + self.fname = self.raw_data_path.stem + if not cache_dir: + cache_dir = raw_data_path.parent + self.cache_dir = cache_dir + self.store = self._load() + if not self.store: + self.store = self.write() + + def _get_index_and_store_fname(self, index_ext=".json", docstore_ext=".json"): + index_file = self.cache_dir / "default__vector_store" / index_ext + store_file = self.cache_dir / "docstore" / docstore_ext + return index_file, store_file + + @abstractmethod + def _load(self): + raise NotImplementedError + + @abstractmethod + def _write(self, docs, metadatas): + raise NotImplementedError diff --git a/notebook_dir/metagpt_yusin/document_store/chromadb_store.py b/notebook_dir/metagpt_yusin/document_store/chromadb_store.py new file mode 100644 index 0000000000000000000000000000000000000000..1d3a014ee63cf4a1def6dd2de22dc30313ff8b03 --- /dev/null +++ b/notebook_dir/metagpt_yusin/document_store/chromadb_store.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/29 14:46 +@Author : alexanderwu +@File : chromadb_store.py +""" +import chromadb + + +class ChromaStore: + """If inherited from BaseStore, or importing other modules from metagpt, a Python exception occurs, which is strange.""" + + def __init__(self, name: str, get_or_create: bool = False): + client = chromadb.Client() + collection = client.create_collection(name, get_or_create=get_or_create) + self.client = client + self.collection = collection + + def search(self, query, n_results=2, metadata_filter=None, document_filter=None): + # kwargs can be used for optional filtering + results = self.collection.query( + query_texts=[query], + n_results=n_results, + where=metadata_filter, # optional filter + where_document=document_filter, # optional filter + ) + return results + + def persist(self): + """Chroma recommends using server mode and 
not persisting locally.""" + raise NotImplementedError + + def write(self, documents, metadatas, ids): + # This function is similar to add(), but it's for more generalized updates + # It assumes you're passing in lists of docs, metadatas, and ids + return self.collection.add( + documents=documents, + metadatas=metadatas, + ids=ids, + ) + + def add(self, document, metadata, _id): + # This function is for adding individual documents + # It assumes you're passing in a single doc, metadata, and id + return self.collection.add( + documents=[document], + metadatas=[metadata], + ids=[_id], + ) + + def delete(self, _id): + return self.collection.delete([_id]) diff --git a/notebook_dir/metagpt_yusin/document_store/faiss_store.py b/notebook_dir/metagpt_yusin/document_store/faiss_store.py new file mode 100644 index 0000000000000000000000000000000000000000..b5422ccee94cfc57257c3d1628490b7c3e62b1cf --- /dev/null +++ b/notebook_dir/metagpt_yusin/document_store/faiss_store.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/25 10:20 +@Author : alexanderwu +@File : faiss_store.py +""" +import asyncio +from pathlib import Path +from typing import Any, Optional + +import faiss +from llama_index.core import VectorStoreIndex, load_index_from_storage +from llama_index.core.embeddings import BaseEmbedding +from llama_index.core.schema import Document, QueryBundle, TextNode +from llama_index.core.storage import StorageContext +from llama_index.vector_stores.faiss import FaissVectorStore + +from metagpt_yusin.document import IndexableDocument +from metagpt_yusin.document_store.base_store import LocalStore +from metagpt_yusin.logs import logger +from metagpt_yusin.utils.embedding import get_embedding + + +class FaissStore(LocalStore): + def __init__( + self, raw_data: Path, cache_dir=None, meta_col="source", content_col="output", embedding: BaseEmbedding = None + ): + self.meta_col = meta_col + self.content_col = content_col + self.embedding = embedding 
def _write(self, docs: list[str], metadatas: list[dict[str, Any]]) -> VectorStoreIndex:
    """Build a FAISS-backed vector index from texts and their metadata.

    Args:
        docs: Text of each document.
        metadatas: Metadata of each document, aligned with ``docs`` by position.

    Returns:
        The index built from the documents with ``self.embedding``.

    Raises:
        ValueError: If ``docs`` and ``metadatas`` differ in length.
    """
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if len(docs) != len(metadatas):
        raise ValueError(f"docs and metadatas must have the same length, got {len(docs)} and {len(metadatas)}")
    documents = [Document(text=doc, metadata=metadata) for doc, metadata in zip(docs, metadatas)]

    # NOTE(review): 1536 is the vector dimension of the flat L2 index; it presumably
    # has to match the output size of the embedding model in use — confirm.
    vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(1536))
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    index = VectorStoreIndex.from_documents(
        documents=documents, storage_context=storage_context, embed_model=self.embedding
    )

    return index
class LanceStore:
    """Thin wrapper around a single LanceDB table used as a vector store.

    The table is created lazily by the first ``write``/``add`` call; until then
    ``self.table`` is ``None`` and ``search``/``delete`` raise.
    """

    def __init__(self, name):
        """Connect to the local LanceDB directory and remember the table name."""
        db = lancedb.connect("./data/lancedb")
        self.db = db
        self.name = name
        self.table = None

    def search(self, query, n_results=2, metric="L2", nprobes=20, **kwargs):
        """Run a vector search and return the result as a dataframe.

        Args:
            query: The query, assumed to already be a vector embedding.
            n_results: Maximum number of rows to return.
            metric: Distance metric to use.
            nprobes: Higher values give better recall at the expense of latency.
            **kwargs: Optional filtering.
                ``select`` - only return the specified columns.
                ``where`` - SQL-syntax metadata filter (e.g. ``"price > 100"``).

        Raises:
            Exception: If no data has been added yet (the table does not exist).
        """
        if self.table is None:
            raise Exception("Table not created yet, please add data first.")

        builder = self.table.search(query).limit(n_results)

        # Only apply the optional clauses when the caller supplied them; passing
        # ``None`` through to the query builder is not a valid column list/filter.
        select = kwargs.get("select")
        if select is not None:
            builder = builder.select(select)
        where = kwargs.get("where")
        if where is not None:
            builder = builder.where(where)

        return builder.metric(metric).nprobes(nprobes).to_df()

    def persist(self):
        """Not supported by this store."""
        raise NotImplementedError

    def write(self, data, metadatas, ids):
        """Insert a batch of embeddings, creating the table on first use.

        Args:
            data: List of vector embeddings.
            metadatas: One metadata dict per embedding; its keys become columns.
            ids: One id per embedding.
        """
        # Each row looks like {'vector', 'id', <metadata keys>...}; metadata keys
        # are applied last, exactly as ``dict.update`` did in the original loop.
        documents = [
            {"vector": vector, "id": ids[idx], **metadatas[idx]} for idx, vector in enumerate(data)
        ]

        if self.table is not None:
            self.table.add(documents)
        else:
            self.table = self.db.create_table(self.name, documents)

    def add(self, data, metadata, _id):
        """Insert a single embedding with its metadata and id."""
        row = {"vector": data, "id": _id}
        row.update(metadata)

        if self.table is not None:
            self.table.add([row])
        else:
            self.table = self.db.create_table(self.name, [row])

    def delete(self, _id):
        """Delete the row(s) whose ``id`` column equals ``_id``.

        String ids are emitted as a quoted SQL literal; other ids are emitted bare.

        Raises:
            Exception: If no data has been added yet (the table does not exist).
        """
        if self.table is None:
            raise Exception("Table not created yet, please add data first")

        if isinstance(_id, str):
            # Double embedded single quotes so an id containing ``'`` cannot
            # terminate the literal and change the meaning of the filter.
            escaped = _id.replace("'", "''")
            return self.table.delete(f"id = '{escaped}'")
        return self.table.delete(f"id = {_id}")

    def drop(self, name):
        """Remove the on-disk ``<name>.lance`` directory of a table, if it exists."""
        path = os.path.join(self.db.uri, name + ".lance")
        if os.path.exists(path):
            shutil.rmtree(path)
def has_collection(self, collection_name: str):
    """Return True if ``collection_name`` can be fetched from Qdrant, else False.

    Any ordinary error raised by ``client.get_collection`` is treated as
    "collection not available". The concrete exception type raised for a missing
    collection is not visible from this file, so ``Exception`` is caught rather
    than something narrower.
    """
    try:
        self.client.get_collection(collection_name)
    except Exception:
        # Unlike a bare ``except:``, this lets KeyboardInterrupt/SystemExit propagate.
        return False
    return True
def mark_as_readable(func):
    """Decorator: register ``func`` as a readable ExtEnv API (one that observes the env).

    The function's schema is stored in the module-level read registry under the
    function's own name; the function itself is returned unchanged.
    """
    schema = get_function_schema(func)
    env_read_api_registry[func.__name__] = schema
    return func
async def read_from_api(self, env_action: Union[str, EnvAPIAbstract]):
    """Get an observation from a particular registered read API of ExtEnv.

    Args:
        env_action: Either the API name (called with no extra arguments) or an
            ``EnvAPIAbstract`` carrying the API name plus ``args``/``kwargs``.

    Returns:
        Whatever the registered API returns (awaited when it is a coroutine
        function), or ``None`` when ``env_action`` is of an unsupported type --
        the same fallback ``write_thru_api`` uses.
    """
    if isinstance(env_action, str):
        api_name, args, kwargs = env_action, (), {}
    elif isinstance(env_action, EnvAPIAbstract):
        api_name, args, kwargs = env_action.api_name, env_action.args, env_action.kwargs
    else:
        # Previously this path fell through to ``return res`` with ``res`` unbound.
        return None

    env_read_api = env_read_api_registry.get(api_name=api_name)["func"]
    self._check_api_exist(env_read_api)
    if is_coroutine_func(env_read_api):
        return await env_read_api(self, *args, **kwargs)
    return env_read_api(self, *args, **kwargs)
def add_roles(self, roles: Iterable["Role"]):
    """Add a batch of roles to the current environment.

    All roles are registered under their ``profile`` first; only then is each
    role given the shared context and attached to this environment, so every
    role is already registered by the time any ``set_env`` call runs.

    Args:
        roles: Any iterable of roles, including one-shot iterators/generators.
    """
    # Snapshot once: the two passes below would otherwise exhaust a generator on
    # the first pass, and iterating ``self.roles.values()`` directly while
    # assigning into ``self.roles`` could change the dict during iteration.
    roles = list(roles)

    for role in roles:
        self.roles[role.profile] = role

    for role in roles:  # setup system message with roles
        role.context = self.context
        role.set_env(self)
@property
def is_idle(self):
    """True when every role in the environment reports itself idle."""
    return all(role.is_idle for role in self.roles.values())
def model_rebuild(cls, **kwargs): + from metagpt_yusin.roles.role import Role # noqa: F401 + + super().model_rebuild(**kwargs) + + +Environment.model_rebuild() diff --git a/notebook_dir/metagpt_yusin/environment/README.md b/notebook_dir/metagpt_yusin/environment/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bb7d50d5013966513258603f788364cd3cf7145e --- /dev/null +++ b/notebook_dir/metagpt_yusin/environment/README.md @@ -0,0 +1,38 @@ +Here is a environment description of MetaGPT env for different situation. +For now, the code only define the environment and still some todos like migrate roles/actions to current version. + +## Function +- Define `ExtEnv`(Base Class) which help users to integrate with external environment like games through apis or construct the game logics. +- Define `Environment`(Base Class) which is the env that MetaGPT directly used. And it includes roles and so on. +- Define the `EnvAPIRegistry` to mark the read/write apis that `ExtEnv` provide observe/step ability. And then, users can call the particular one to get observation from env or feedback to env. 
+ +## Usage + +init environment +``` +android_env = env.create(EnvType.ANDROID) + +assistant = Role(name="Bob", profile="android assistant") +team = Team(investment=10.0, env=android_env, roles=[assistant]) +``` + +observe & step inside role's actions +``` +from metagpt.environment.api.env_api import EnvAPIAbstract + +# get screenshot from ExtEnv +screenshot_path: Path = await env.observe( + EnvAPIAbstract( + api_name="get_screenshot", kwargs={"ss_name": f"{round_count}_before", "local_save_dir": task_dir} + ) + ) + +# do a `tap` action on the screen +res = env.step(EnvAPIAbstract("system_tap", kwargs={"x": x, "y": y})) +``` + +## TODO +- add android app operation assistant under `examples/android_assistant` +- migrate roles/actions of werewolf game from old version into current version +- migrate roles/actions of minecraft game from old version into current version +- migrate roles/actions of stanford_town game from old version into current version diff --git a/notebook_dir/metagpt_yusin/environment/__init__.py b/notebook_dir/metagpt_yusin/environment/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f0cbf38d7c6d4802d6bd1d1b55c33311b5a69a5c --- /dev/null +++ b/notebook_dir/metagpt_yusin/environment/__init__.py @@ -0,0 +1,12 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +from metagpt_yusin.environment.base_env import Environment +from metagpt_yusin.environment.android.android_env import AndroidEnv +from metagpt_yusin.environment.werewolf.werewolf_env import WerewolfEnv +from metagpt_yusin.environment.stanford_town.stanford_town_env import StanfordTownEnv +from metagpt_yusin.environment.software.software_env import SoftwareEnv + + +__all__ = ["AndroidEnv", "WerewolfEnv", "StanfordTownEnv", "SoftwareEnv", "Environment"] diff --git a/notebook_dir/metagpt_yusin/environment/__pycache__/__init__.cpython-39.pyc b/notebook_dir/metagpt_yusin/environment/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..7c0139bd66c47acc45f226d91bf116ff8f6df18e Binary files /dev/null and b/notebook_dir/metagpt_yusin/environment/__pycache__/__init__.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/environment/__pycache__/base_env.cpython-39.pyc b/notebook_dir/metagpt_yusin/environment/__pycache__/base_env.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d312c6de1c0204da95966b3976b0d1bedfb8b99 Binary files /dev/null and b/notebook_dir/metagpt_yusin/environment/__pycache__/base_env.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/environment/__pycache__/base_env_space.cpython-39.pyc b/notebook_dir/metagpt_yusin/environment/__pycache__/base_env_space.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee6a2cff3ac1e4b4a36adc3a49dd3ec867052154 Binary files /dev/null and b/notebook_dir/metagpt_yusin/environment/__pycache__/base_env_space.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/environment/android/__init__.py b/notebook_dir/metagpt_yusin/environment/android/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bcf8efd09712339308e72659e84450d3fa829fd --- /dev/null +++ b/notebook_dir/metagpt_yusin/environment/android/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/notebook_dir/metagpt_yusin/environment/android/__pycache__/__init__.cpython-39.pyc b/notebook_dir/metagpt_yusin/environment/android/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42d54b36049eeb6fc530188f43f648a357edceb2 Binary files /dev/null and b/notebook_dir/metagpt_yusin/environment/android/__pycache__/__init__.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/environment/android/__pycache__/android_env.cpython-39.pyc b/notebook_dir/metagpt_yusin/environment/android/__pycache__/android_env.cpython-39.pyc new file mode 100644 
index 0000000000000000000000000000000000000000..5332902c8c9915e23cbcbd80a16c509c8d9b1e86 Binary files /dev/null and b/notebook_dir/metagpt_yusin/environment/android/__pycache__/android_env.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/environment/android/__pycache__/android_ext_env.cpython-39.pyc b/notebook_dir/metagpt_yusin/environment/android/__pycache__/android_ext_env.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79cd8b2747a863d27219e0ac430ad46589ff82ec Binary files /dev/null and b/notebook_dir/metagpt_yusin/environment/android/__pycache__/android_ext_env.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/environment/android/__pycache__/const.cpython-39.pyc b/notebook_dir/metagpt_yusin/environment/android/__pycache__/const.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b4e2cac20c9b6c77c76348a5fd1953bdbb0e3d6 Binary files /dev/null and b/notebook_dir/metagpt_yusin/environment/android/__pycache__/const.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/environment/android/__pycache__/env_space.cpython-39.pyc b/notebook_dir/metagpt_yusin/environment/android/__pycache__/env_space.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b82762dc8049fcde167827da3c0656e28cb42c6f Binary files /dev/null and b/notebook_dir/metagpt_yusin/environment/android/__pycache__/env_space.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/environment/android/android_env.py b/notebook_dir/metagpt_yusin/environment/android/android_env.py new file mode 100644 index 0000000000000000000000000000000000000000..d1d6bd9f1108b3e2d52cd031f5790f7caaab8a58 --- /dev/null +++ b/notebook_dir/metagpt_yusin/environment/android/android_env.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : MG Android Env + +from pydantic import Field + +from metagpt_yusin.environment.android.android_ext_env import AndroidExtEnv +from 
class AndroidEnv(AndroidExtEnv, Environment):
    """MetaGPT environment that combines the Android external env with the role-hosting Environment.

    In order to use the actual `reset` & `observe`, the inheritance order is
    AndroidExtEnv, Environment: AndroidExtEnv comes first so its implementations
    are picked ahead of Environment's.
    """

    # Grid dimensions on the screenshot; both default to 0.
    rows: int = Field(default=0, description="rows of a grid on the screenshot")
    cols: int = Field(default=0, description="cols of a grid on the screenshot")
def observe(self, obs_params: Optional[EnvObsParams] = None) -> Any:
    """Return an observation of the device selected by ``obs_params.obs_type``.

    Args:
        obs_params: Observation request. ``None`` is treated as ``EnvObsType.NONE``.

    Returns:
        * ``GET_SCREENSHOT`` - the result of ``get_screenshot``.
        * ``GET_XML`` - the result of ``get_xml``.
        * ``NONE`` - the whole observation from ``_get_obs()``, as ``reset`` does.
        * any other value - ``None``.
    """
    obs_type = obs_params.obs_type if obs_params else EnvObsType.NONE
    if obs_type == EnvObsType.GET_SCREENSHOT:
        return self.get_screenshot(ss_name=obs_params.ss_name, local_save_dir=obs_params.local_save_dir)
    if obs_type == EnvObsType.GET_XML:
        return self.get_xml(xml_name=obs_params.xml_name, local_save_dir=obs_params.local_save_dir)
    if obs_type == EnvObsType.NONE:
        # Previously this branch left ``obs`` unbound, so the default call
        # ``observe()`` always failed with UnboundLocalError.
        return self._get_obs()
    return None
@property
def device_shape(self) -> tuple[int, int]:
    """Screen size of the device as ``(width, height)``, queried with ``wm size``.

    Returns ``(0, 0)`` when the adb command fails or no ``<label>: <W>x<H>`` line
    can be parsed. When the output has several such lines, the last one is used.
    NOTE(review): presumably the last line is the override (effective) size when
    one is set -- confirm against the `wm size` output on a real device.
    """
    adb_cmd = f"{self.adb_prefix_shell} wm size"
    shape = (0, 0)
    shape_res = self.execute_adb_with_cmd(adb_cmd)
    if shape_res == ADB_EXEC_FAIL:
        return shape

    for line in shape_res.splitlines():
        _, sep, dims = line.partition(": ")
        if not sep:
            continue
        width, _, height = dims.strip().partition("x")
        # Ignore anything that is not a plain "<digits>x<digits>" pair instead of
        # letting int() raise on unexpected output.
        if width.isdigit() and height.isdigit():
            shape = (int(width), int(height))
    return shape
@mark_as_readable
def get_xml(self, xml_name: str, local_save_dir: Path) -> Path:
    """Dump the UI hierarchy on the device and pull it to a local directory.

    xml_name: xml file name (without the ``.xml`` suffix)
    local_save_dir: local dir to store the xml pulled from the device

    Returns the local path of the pulled file; when the dump or the pull fails,
    the returned path is ``Path(ADB_EXEC_FAIL)``.

    Raises:
        ValueError: If ``xml_dir`` has not been configured.
    """
    # Mirror the precondition get_screenshot places on screenshot_dir, instead of
    # failing later with an opaque TypeError from ``Path(None)``.
    if not self.xml_dir:
        raise ValueError("xml_dir must be set before calling get_xml")

    xml_remote_path = Path(self.xml_dir).joinpath(f"{xml_name}.xml")
    dump_cmd = f"{self.adb_prefix_shell} uiautomator dump {xml_remote_path}"
    xml_res = self.execute_adb_with_cmd(dump_cmd)

    res = ADB_EXEC_FAIL
    if xml_res != ADB_EXEC_FAIL:
        xml_local_path = Path(local_save_dir).joinpath(f"{xml_name}.xml")
        pull_cmd = f"{self.adb_prefix} pull {xml_remote_path} {xml_local_path}"
        pull_res = self.execute_adb_with_cmd(pull_cmd)
        if pull_res != ADB_EXEC_FAIL:
            res = xml_local_path
    return Path(res)
class EnvAction(BaseEnvAction):
    """One action to execute against the Android environment.

    Which of the optional fields are read depends on ``action_type``.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    action_type: int = Field(default=EnvActionType.NONE, description="action type")
    coord: npt.NDArray[np.int64] = Field(
        default_factory=lambda: np.zeros(2, dtype=np.int64), description="operation coordinate"
    )
    tgt_coord: npt.NDArray[np.int64] = Field(
        default_factory=lambda: np.zeros(2, dtype=np.int64), description="target operation coordinate"
    )
    input_txt: str = Field(default="", description="user input text")
    orient: str = Field(default="up", description="swipe orient")
    dist: str = Field(default="medium", description="swipe dist")

    @field_validator("coord", "tgt_coord", mode="before")
    @classmethod
    def check_coord(cls, coord) -> npt.NDArray[np.int64]:
        """Coerce a coordinate given as a list/tuple into a numpy array.

        An existing ndarray is returned as-is. Previously that case fell off the
        end of the function and the validator returned ``None``, discarding the
        coordinate.
        """
        if isinstance(coord, np.ndarray):
            return coord
        return np.array(coord)
class EnvAPIAbstract(BaseModel):
    """Description of one environment api call: which api to run and with which arguments."""

    api_name: str = Field(default="", description="the api function name or id")
    # The default used to be `{}`, which is an empty dict, not the declared `set`.
    # NOTE(review): callers unpack this with `*args`; a set has no defined order, so
    # multi-argument positional calls may arrive shuffled - consider a tuple/list type.
    args: set = Field(default_factory=set, description="the api function `args` params")
    kwargs: dict = Field(default_factory=dict, description="the api function `kwargs` params")
EnvAPIRegistry(BaseModel): + """the registry to store environment w&r api/interface""" + + registry: dict[str, Callable] = Field(default=dict(), exclude=True) + + def get(self, api_name: str): + if api_name not in self.registry: + raise KeyError(f"api_name: {api_name} not found") + return self.registry.get(api_name) + + def __getitem__(self, api_name: str) -> Callable: + return self.get(api_name) + + def __setitem__(self, api_name: str, func: Callable): + self.registry[api_name] = func + + def __len__(self): + return len(self.registry) + + def get_apis(self, as_str=True) -> dict[str, dict[str, Union[dict, Any, str]]]: + """return func schema without func instance""" + apis = dict() + for func_name, func_schema in self.registry.items(): + new_func_schema = dict() + for key, value in func_schema.items(): + if key == "func": + continue + new_func_schema[key] = str(value) if as_str else value + new_func_schema = new_func_schema + apis[func_name] = new_func_schema + return apis + + +class WriteAPIRegistry(EnvAPIRegistry): + """just as a explicit class name""" + + pass + + +class ReadAPIRegistry(EnvAPIRegistry): + """just as a explicit class name""" + + pass diff --git a/notebook_dir/metagpt_yusin/environment/base_env.py b/notebook_dir/metagpt_yusin/environment/base_env.py new file mode 100644 index 0000000000000000000000000000000000000000..cd1a11d8e72d3797ec2cc72f44f458e997df0480 --- /dev/null +++ b/notebook_dir/metagpt_yusin/environment/base_env.py @@ -0,0 +1,252 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : base env of executing environment + +import asyncio +from abc import abstractmethod +from enum import Enum +from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Set, Union + +from gymnasium import spaces +from gymnasium.core import ActType, ObsType +from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator + +from metagpt_yusin.context import Context +from metagpt_yusin.environment.api.env_api import ( + 
class ExtEnv(BaseModel):
    """External Env to integrate actual game environment.

    Read/write apis are looked up by name in the module-level
    `env_read_api_registry` / `env_write_api_registry` and are invoked with this
    instance as their first argument.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    action_space: spaces.Space[ActType] = Field(default_factory=spaces.Space, exclude=True)
    observation_space: spaces.Space[ObsType] = Field(default_factory=spaces.Space, exclude=True)

    def _check_api_exist(self, rw_api: Optional[str] = None):
        """Raise ValueError when the api resolved from a registry entry is missing/falsy."""
        if not rw_api:
            raise ValueError(f"{rw_api} not exists")

    def get_all_available_apis(self, mode: str = "read") -> dict[str, Any]:
        """Get the definitions of all registered read or write apis.

        Args:
            mode: "read" or "write"; anything else trips the assertion below.

        Returns:
            The result of the matching registry's `get_apis()` (a dict keyed by api name).
        """
        assert mode in ["read", "write"]
        if mode == "read":
            return env_read_api_registry.get_apis()
        else:
            return env_write_api_registry.get_apis()

    async def read_from_api(self, env_action: Union[str, EnvAPIAbstract]):
        """Get an observation from a particular read api of ExtEnv.

        Args:
            env_action: either the api name (called with no extra arguments) or an
                `EnvAPIAbstract` carrying the name plus `args` / `kwargs`.

        Returns:
            The api's result, awaited when the api is a coroutine function; None when
            `env_action` is of an unsupported type.

        Raises:
            KeyError: the api name is not registered.
            ValueError: the registry entry has no callable.
        """
        # Initialised up front, like write_thru_api does, so an unsupported input type
        # yields None instead of an UnboundLocalError at the return statement.
        res = None
        if isinstance(env_action, str):
            env_read_api = env_read_api_registry.get(api_name=env_action)["func"]
            self._check_api_exist(env_read_api)
            if is_coroutine_func(env_read_api):
                res = await env_read_api(self)
            else:
                res = env_read_api(self)
        elif isinstance(env_action, EnvAPIAbstract):
            env_read_api = env_read_api_registry.get(api_name=env_action.api_name)["func"]
            self._check_api_exist(env_read_api)
            if is_coroutine_func(env_read_api):
                res = await env_read_api(self, *env_action.args, **env_action.kwargs)
            else:
                res = env_read_api(self, *env_action.args, **env_action.kwargs)
        return res

    async def write_thru_api(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]):
        """Execute through a particular write api of ExtEnv.

        A `Message` is handed to `self.publish_message`; an `EnvAPIAbstract` is resolved
        in the write registry and called with its `args` / `kwargs`.

        Returns:
            The api's result for an `EnvAPIAbstract`; None otherwise.
        """
        res = None
        if isinstance(env_action, Message):
            # NOTE(review): publish_message is not defined on ExtEnv itself; it is
            # expected to come from a subclass.
            self.publish_message(env_action)
        elif isinstance(env_action, EnvAPIAbstract):
            env_write_api = env_write_api_registry.get(env_action.api_name)["func"]
            self._check_api_exist(env_write_api)
            if is_coroutine_func(env_write_api):
                res = await env_write_api(self, *env_action.args, **env_action.kwargs)
            else:
                res = env_write_api(self, *env_action.args, **env_action.kwargs)
        # NOTE(review): `str` and `list[EnvAPIAbstract]` are accepted by the annotation
        # but are not handled by any branch above, so they return None.

        return res

    @abstractmethod
    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[dict[str, Any]] = None,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Implement this to get init observation"""

    @abstractmethod
    def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any:
        """Implement this if you want to get partial observation from the env"""

    @abstractmethod
    def step(self, action: BaseEnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
        """Implement this to feed a action and then get new observation from the env"""
roles: dict[str, SerializeAsAny["Role"]] = Field(default_factory=dict, validate_default=True) + member_addrs: Dict["Role", Set] = Field(default_factory=dict, exclude=True) + history: str = "" # For debug + context: Context = Field(default_factory=Context, exclude=True) + + def reset( + self, + *, + seed: Optional[int] = None, + options: Optional[dict[str, Any]] = None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + pass + + def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any: + pass + + def step(self, action: BaseEnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]: + pass + + @model_validator(mode="after") + def init_roles(self): + self.add_roles(self.roles.values()) + return self + + def add_role(self, role: "Role"): + """增加一个在当前环境的角色 + Add a role in the current environment + """ + self.roles[role.profile] = role + role.set_env(self) + role.context = self.context + + def add_roles(self, roles: Iterable["Role"]): + """增加一批在当前环境的角色 + Add a batch of characters in the current environment + """ + for role in roles: + self.roles[role.profile] = role + + for role in roles: # setup system message with roles + role.context = self.context + role.set_env(self) + + def publish_message(self, message: Message, peekable: bool = True) -> bool: + """ + Distribute the message to the recipients. + In accordance with the Message routing structure design in Chapter 2.2.1 of RFC 116, as already planned + in RFC 113 for the entire system, the routing information in the Message is only responsible for + specifying the message recipient, without concern for where the message recipient is located. How to + route the message to the message recipient is a problem addressed by the transport framework designed + in RFC 113. 
+ """ + logger.debug(f"publish_message: {message.dump()}") + found = False + # According to the routing feature plan in Chapter 2.2.3.2 of RFC 113 + for role, addrs in self.member_addrs.items(): + if is_send_to(message, addrs): + role.put_message(message) + found = True + if not found: + logger.warning(f"Message no recipients: {message.dump()}") + self.history += f"\n{message}" # For debug + + return True + + async def run(self, k=1): + """处理一次所有信息的运行 + Process all Role runs at once + """ + for _ in range(k): + futures = [] + for role in self.roles.values(): + future = role.run() + futures.append(future) + + await asyncio.gather(*futures) + logger.debug(f"is idle: {self.is_idle}") + + def get_roles(self) -> dict[str, "Role"]: + """获得环境内的所有角色 + Process all Role runs at once + """ + return self.roles + + def get_role(self, name: str) -> "Role": + """获得环境内的指定角色 + get all the environment roles + """ + return self.roles.get(name, None) + + def role_names(self) -> list[str]: + return [i.name for i in self.roles.values()] + + @property + def is_idle(self): + """If true, all actions have been executed.""" + for r in self.roles.values(): + if not r.is_idle: + return False + return True + + def get_addresses(self, obj): + """Get the addresses of the object.""" + return self.member_addrs.get(obj, {}) + + def set_addresses(self, obj, addresses): + """Set the addresses of the object""" + self.member_addrs[obj] = addresses + + def archive(self, auto_archive=True): + if auto_archive and self.context.git_repo: + self.context.git_repo.archive() + + @classmethod + def model_rebuild(cls, **kwargs): + from metagpt_yusin.roles.role import Role # noqa: F401 + + super().model_rebuild(**kwargs) + + +Environment.model_rebuild() diff --git a/notebook_dir/metagpt_yusin/environment/base_env_space.py b/notebook_dir/metagpt_yusin/environment/base_env_space.py new file mode 100644 index 0000000000000000000000000000000000000000..fd0cfa399f00298d904d88982ea56c1008b1d1b2 --- /dev/null +++ 
class BaseEnvActionType(IntEnum):
    """Member-less base enum for environment action types.

    It declares no values itself; environment-specific enums derive from it and
    define the actual members.
    """

    # # NONE = 0 # no action to run, just get observation
    pass
"nearby_entities": 5, + "health": 15, + "hunger": 15, + "position": 0, + "equipment": 0, + "inventory": 0, + "optional_inventory_items": 7, + "chests": 0, + "completed_tasks": 0, + "failed_tasks": 0, +} +MC_CURRICULUM_OB = [ + "context", + "biome", + "time", + "nearby_blocks", + "other_blocks", + "nearby_entities", + "health", + "hunger", + "position", + "equipment", + "inventory", + "chests", + "completed_tasks", + "failed_tasks", +] +MC_CORE_INVENTORY_ITEMS = r".*_log|.*_planks|stick|crafting_table|furnace" +r"|cobblestone|dirt|coal|.*_pickaxe|.*_sword|.*_axe", # curriculum_agent: only show these items in inventory before optional_inventory_items reached in warm up diff --git a/notebook_dir/metagpt_yusin/environment/minecraft/minecraft_env.py b/notebook_dir/metagpt_yusin/environment/minecraft/minecraft_env.py new file mode 100644 index 0000000000000000000000000000000000000000..68ca6729bf3070a7e3833c128c61cd032680eb4a --- /dev/null +++ b/notebook_dir/metagpt_yusin/environment/minecraft/minecraft_env.py @@ -0,0 +1,388 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : MG Minecraft Env +# refs to `voyager voyager.py` + +import json +import re +import time +from typing import Any, Iterable + +from llama_index.vector_stores.chroma import ChromaVectorStore +from pydantic import ConfigDict, Field + +from metagpt_yusin.config2 import config as CONFIG +from metagpt_yusin.environment.base_env import Environment +from metagpt_yusin.environment.minecraft.const import MC_CKPT_DIR +from metagpt_yusin.environment.minecraft.minecraft_ext_env import MinecraftExtEnv +from metagpt_yusin.logs import logger +from metagpt_yusin.utils.common import load_mc_skills_code, read_json_file, write_json_file + + +class MinecraftEnv(Environment, MinecraftExtEnv): + """MinecraftEnv, including shared memory of cache and information between roles""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + event: dict[str, Any] = Field(default_factory=dict) + current_task: str = 
Field(default="Mine 1 wood log") + task_execution_time: float = Field(default=float) + context: str = Field(default="You can mine one of oak, birch, spruce, jungle, acacia, dark oak, or mangrove logs.") + code: str = Field(default="") + program_code: str = Field(default="") # write in skill/code/*.js + program_name: str = Field(default="") + critique: str = Field(default="") + skills: dict = Field(default_factory=dict) # for skills.json + retrieve_skills: list[str] = Field(default_factory=list) + event_summary: str = Field(default="") + + qa_cache: dict[str, str] = Field(default_factory=dict) + completed_tasks: list[str] = Field(default_factory=list) # Critique things + failed_tasks: list[str] = Field(default_factory=list) + + skill_desp: str = Field(default="") + + chest_memory: dict[str, Any] = Field(default_factory=dict) # eg: {'(1344, 64, 1381)': 'Unknown'} + chest_observation: str = Field(default="") # eg: "Chests: None\n\n" + + runtime_status: bool = False # equal to action execution status: success or failed + + vectordb: ChromaVectorStore = Field(default_factory=ChromaVectorStore) + + qa_cache_questions_vectordb: ChromaVectorStore = Field(default_factory=ChromaVectorStore) + + @property + def progress(self): + # return len(self.completed_tasks) + 10 # Test only + return len(self.completed_tasks) + + @property + def programs(self): + programs = "" + if self.code == "": + return programs # TODO: maybe fix 10054 now, a better way is isolating env.step() like voyager + for skill_name, entry in self.skills.items(): + programs += f"{entry['code']}\n\n" + for primitives in load_mc_skills_code(): # TODO add skills_dir + programs += f"{primitives}\n\n" + return programs + + def set_mc_port(self, mc_port): + super().set_mc_port(mc_port) + self.set_mc_resume() + + def set_mc_resume(self): + self.qa_cache_questions_vectordb = ChromaVectorStore( + collection_name="qa_cache_questions_vectordb", + persist_dir=f"{MC_CKPT_DIR}/curriculum/vectordb", + ) + + self.vectordb = 
ChromaVectorStore( + collection_name="skill_vectordb", + persist_dir=f"{MC_CKPT_DIR}/skill/vectordb", + ) + + if CONFIG.resume: + logger.info(f"Loading Action Developer from {MC_CKPT_DIR}/action") + self.chest_memory = read_json_file(f"{MC_CKPT_DIR}/action/chest_memory.json") + + logger.info(f"Loading Curriculum Agent from {MC_CKPT_DIR}/curriculum") + self.completed_tasks = read_json_file(f"{MC_CKPT_DIR}/curriculum/completed_tasks.json") + self.failed_tasks = read_json_file(f"{MC_CKPT_DIR}/curriculum/failed_tasks.json") + + logger.info(f"Loading Skill Manager from {MC_CKPT_DIR}/skill\033[0m") + self.skills = read_json_file(f"{MC_CKPT_DIR}/skill/skills.json") + + logger.info(f"Loading Qa Cache from {MC_CKPT_DIR}/curriculum\033[0m") + self.qa_cache = read_json_file(f"{MC_CKPT_DIR}/curriculum/qa_cache.json") + + if self.vectordb._collection.count() == 0: + logger.info(self.vectordb._collection.count()) + # Set vdvs for skills & qa_cache + skill_desps = [skill["description"] for program_name, skill in self.skills.items()] + program_names = [program_name for program_name, skill in self.skills.items()] + metadatas = [{"name": program_name} for program_name in program_names] + # add vectordb from file + self.vectordb.add_texts( + texts=skill_desps, + ids=program_names, + metadatas=metadatas, + ) + self.vectordb.persist() + + logger.info(self.qa_cache_questions_vectordb._collection.count()) + if self.qa_cache_questions_vectordb._collection.count() == 0: + questions = [question for question, answer in self.qa_cache.items()] + + self.qa_cache_questions_vectordb.add_texts(texts=questions) + + self.qa_cache_questions_vectordb.persist() + + logger.info( + f"INIT_CHECK: There are {self.vectordb._collection.count()} skills in vectordb and {len(self.skills)} skills in skills.json." 
    def update_event(self, event: dict):
        """Store the latest event payload and refresh the chest state derived from it.

        Does nothing when `event` equals the payload that is already stored.

        NOTE(review): the parameter is annotated `dict`, but `update_chest_memory`
        indexes it as `events[-1][1]`, which suggests a sequence of pairs - confirm
        the real shape.
        """
        if self.event == event:
            return
        self.event = event
        # Chest memory must be updated first: the observation text is rendered from it.
        self.update_chest_memory(event)
        self.update_chest_observation()
        # self.event_summary = self.summarize_chatlog(event)
append_skill(self, skill: dict): + self.skills[self.program_name] = skill # skill_manager.retrieve_skills to HERE + + def update_retrieve_skills(self, retrieve_skills: list): + self.retrieve_skills = retrieve_skills + + def update_skill_desp(self, skill_desp: str): + self.skill_desp = skill_desp + + async def update_qa_cache(self, qa_cache: dict): + self.qa_cache = qa_cache + + def update_chest_memory(self, events: dict): + """ + Input: events: Dict + Result: self.chest_memory update & save to json + """ + nearbyChests = events[-1][1]["nearbyChests"] + for position, chest in nearbyChests.items(): + if position in self.chest_memory: + if isinstance(chest, dict): + self.chest_memory[position] = chest + if chest == "Invalid": + logger.info(f"Action Developer removing chest {position}: {chest}") + self.chest_memory.pop(position) + else: + if chest != "Invalid": + logger.info(f"Action Developer saving chest {position}: {chest}") + self.chest_memory[position] = chest + + write_json_file(f"{MC_CKPT_DIR}/action/chest_memory.json", self.chest_memory) + + def update_chest_observation(self): + """ + update chest_memory to chest_observation. 
+ Refer to @ https://github.com/MineDojo/Voyager/blob/main/voyager/agents/action.py + """ + + chests = [] + for chest_position, chest in self.chest_memory.items(): + if isinstance(chest, dict) and len(chest) > 0: + chests.append(f"{chest_position}: {chest}") + for chest_position, chest in self.chest_memory.items(): + if isinstance(chest, dict) and len(chest) == 0: + chests.append(f"{chest_position}: Empty") + for chest_position, chest in self.chest_memory.items(): + if isinstance(chest, str): + assert chest == "Unknown" + chests.append(f"{chest_position}: Unknown items inside") + assert len(chests) == len(self.chest_memory) + if chests: + chests = "\n".join(chests) + self.chest_observation = f"Chests:\n{chests}\n\n" + else: + self.chest_observation = "Chests: None\n\n" + + def summarize_chatlog(self, events): + def filter_item(message: str): + craft_pattern = r"I cannot make \w+ because I need: (.*)" + craft_pattern2 = r"I cannot make \w+ because there is no crafting table nearby" + mine_pattern = r"I need at least a (.*) to mine \w+!" + if re.match(craft_pattern, message): + self.event_summary = re.match(craft_pattern, message).groups()[0] + elif re.match(craft_pattern2, message): + self.event_summary = "a nearby crafting table" + elif re.match(mine_pattern, message): + self.event_summary = re.match(mine_pattern, message).groups()[0] + else: + self.event_summary = "" + return self.event_summary + + chatlog = set() + for event_type, event in events: + if event_type == "onChat": + item = filter_item(event["onChat"]) + if item: + chatlog.add(item) + self.event_summary = "I also need " + ", ".join(chatlog) + "." 
if chatlog else "" + + def reset_block_info(self): + # revert all the placing event in the last step + pass + + def update_exploration_progress(self, success: bool): + """ + Split task into completed_tasks or failed_tasks + Args: info = { + "task": self.task, + "success": success, + "conversations": self.conversations, + } + """ + self.runtime_status = success + task = self.current_task + if task.startswith("Deposit useless items into the chest at"): + return + if success: + logger.info(f"Completed task {task}.") + self.completed_tasks.append(task) + else: + logger.info(f"Failed to complete task {task}. Skipping to next task.") + self.failed_tasks.append(task) + # when not success, below to update event! + # revert all the placing event in the last step + blocks = [] + positions = [] + for event_type, event in self.event: + if event_type == "onSave" and event["onSave"].endswith("_placed"): + block = event["onSave"].split("_placed")[0] + position = event["status"]["position"] + blocks.append(block) + positions.append(position) + new_events = self._step( + f"await givePlacedItemBack(bot, {json.dumps(blocks)}, {json.dumps(positions)})", + programs=self.programs, + ) + self.event[-1][1]["inventory"] = new_events[-1][1]["inventory"] + self.event[-1][1]["voxels"] = new_events[-1][1]["voxels"] + + self.save_sorted_tasks() + + def save_sorted_tasks(self): + updated_completed_tasks = [] + # record repeated failed tasks + updated_failed_tasks = self.failed_tasks + # dedup but keep order + for task in self.completed_tasks: + if task not in updated_completed_tasks: + updated_completed_tasks.append(task) + + # remove completed tasks from failed tasks + for task in updated_completed_tasks: + while task in updated_failed_tasks: + updated_failed_tasks.remove(task) + + self.completed_tasks = updated_completed_tasks + self.failed_tasks = updated_failed_tasks + + # dump to json + write_json_file(f"{MC_CKPT_DIR}/curriculum/completed_tasks.json", self.completed_tasks) + 
write_json_file(f"{MC_CKPT_DIR}/curriculum/failed_tasks.json", self.failed_tasks) + + async def on_event_retrieve(self, *args): + """ + Retrieve Minecraft events. + + Returns: + list: A list of Minecraft events. + + Raises: + Exception: If there is an issue retrieving events. + """ + try: + self._reset( + options={ + "mode": "soft", + "wait_ticks": 20, + } + ) + # difficulty = "easy" if len(self.completed_tasks) > 15 else "peaceful" + difficulty = "peaceful" + + events = self._step("bot.chat(`/time set ${getNextTime()}`);\n" + f"bot.chat('/difficulty {difficulty}');") + self.update_event(events) + return events + except Exception as e: + time.sleep(3) # wait for mineflayer to exit + # reset bot status here + events = self._reset( + options={ + "mode": "hard", + "wait_ticks": 20, + "inventory": self.event[-1][1]["inventory"], + "equipment": self.event[-1][1]["status"]["equipment"], + "position": self.event[-1][1]["status"]["position"], + } + ) + self.update_event(events) + logger.error(f"Failed to retrieve Minecraft events: {str(e)}") + return events + + async def on_event_execute(self, *args): + """ + Execute Minecraft events. + + This function is used to obtain events from the Minecraft environment. Check the implementation in + the 'voyager/env/bridge.py step()' function to capture events generated within the game. + + Returns: + list: A list of Minecraft events. + + Raises: + Exception: If there is an issue retrieving events. 
+ """ + try: + events = self._step( + code=self.code, + programs=self.programs, + ) + self.update_event(events) + return events + except Exception as e: + time.sleep(3) # wait for mineflayer to exit + # reset bot status here + events = self._reset( + options={ + "mode": "hard", + "wait_ticks": 20, + "inventory": self.event[-1][1]["inventory"], + "equipment": self.event[-1][1]["status"]["equipment"], + "position": self.event[-1][1]["status"]["position"], + } + ) + self.update_event(events) + logger.error(f"Failed to execute Minecraft events: {str(e)}") + return events diff --git a/notebook_dir/metagpt_yusin/environment/minecraft/minecraft_ext_env.py b/notebook_dir/metagpt_yusin/environment/minecraft/minecraft_ext_env.py new file mode 100644 index 0000000000000000000000000000000000000000..851d3450b385d7da0ba25362a1698e8f2ebc56e0 --- /dev/null +++ b/notebook_dir/metagpt_yusin/environment/minecraft/minecraft_ext_env.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : The Minecraft external environment to integrate with Minecraft game +# refs to `voyager bridge.py` + +import json +import time +from typing import Any, Optional + +import requests +from pydantic import ConfigDict, Field, model_validator + +from metagpt_yusin.environment.base_env import ExtEnv, mark_as_writeable +from metagpt_yusin.environment.base_env_space import BaseEnvAction, BaseEnvObsParams +from metagpt_yusin.environment.minecraft.const import ( + MC_CKPT_DIR, + MC_CORE_INVENTORY_ITEMS, + MC_CURRICULUM_OB, + MC_DEFAULT_WARMUP, + metagpt_yusin_ROOT, +) +from metagpt_yusin.environment.minecraft.process_monitor import SubprocessMonitor +from metagpt_yusin.logs import logger + + +class MinecraftExtEnv(ExtEnv): + model_config = ConfigDict(arbitrary_types_allowed=True) + + mc_port: Optional[int] = Field(default=None) + server_host: str = Field(default="http://127.0.0.1") + server_port: str = Field(default=3000) + request_timeout: int = Field(default=600) + + mineflayer: 
Optional[SubprocessMonitor] = Field(default=None, validate_default=True) + + has_reset: bool = Field(default=False) + reset_options: Optional[dict] = Field(default=None) + connected: bool = Field(default=False) + server_paused: bool = Field(default=False) + warm_up: dict = Field(default=dict()) + + def reset( + self, + *, + seed: Optional[int] = None, + options: Optional[dict[str, Any]] = None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + pass + + def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any: + pass + + def step(self, action: BaseEnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]: + pass + + @property + def server(self) -> str: + return f"{self.server_host}:{self.server_port}" + + @model_validator(mode="after") + def _post_init_ext_env(self): + if not self.mineflayer: + self.mineflayer = SubprocessMonitor( + commands=[ + "node", + metagpt_yusin_ROOT.joinpath("metagpt_yusin", "environment", "minecraft", "mineflayer", "index.js"), + str(self.server_port), + ], + name="mineflayer", + ready_match=r"Server started on port (\d+)", + ) + if not self.warm_up: + warm_up = MC_DEFAULT_WARMUP + if "optional_inventory_items" in warm_up: + assert MC_CORE_INVENTORY_ITEMS is not None + # self.core_inv_items_regex = re.compile(MC_CORE_INVENTORY_ITEMS) + self.warm_up["optional_inventory_items"] = warm_up["optional_inventory_items"] + else: + self.warm_up["optional_inventory_items"] = 0 + for key in MC_CURRICULUM_OB: + self.warm_up[key] = warm_up.get(key, MC_DEFAULT_WARMUP[key]) + self.warm_up["nearby_blocks"] = 0 + self.warm_up["inventory"] = 0 + self.warm_up["completed_tasks"] = 0 + self.warm_up["failed_tasks"] = 0 + + # init ckpt sub-forders + MC_CKPT_DIR.joinpath("curriculum/vectordb").mkdir(parents=True, exist_ok=True) + MC_CKPT_DIR.joinpath("action").mkdir(exist_ok=True) + MC_CKPT_DIR.joinpath("skill/code").mkdir(parents=True, exist_ok=True) + MC_CKPT_DIR.joinpath("skill/description").mkdir(exist_ok=True) + 
# Tail of the preceding definition, which starts before this chunk; kept as-is.
MC_CKPT_DIR.joinpath("skill/vectordb").mkdir(exist_ok=True)


# The definitions below take ``self``: they are methods of a class whose
# header lies outside this chunk.


def set_mc_port(self, mc_port: int):
    """Remember the Minecraft server port that ``_reset`` sends to the bridge."""
    self.mc_port = mc_port


@mark_as_writeable
def close(self) -> bool:
    """Stop the bot and the mineflayer process.

    Returns:
        True when the environment is no longer marked as connected.
    """
    self.unpause()
    if self.connected:
        res = requests.post(f"{self.server}/stop")
        if res.status_code == 200:
            self.connected = False
    self.mineflayer.stop()
    return not self.connected


@mark_as_writeable
def check_process(self) -> dict:
    """Make sure the mineflayer process is running, restarting the bot if not.

    While the process is down it is relaunched and the bot is started through
    the bridge's ``/start`` endpoint with ``self.reset_options``.

    Returns:
        The decoded ``/start`` response when a restart happened. When the
        process is already running nothing is posted and ``None`` is
        returned.

    Raises:
        RuntimeError: if the process keeps failing to start, or if ``/start``
            does not answer with HTTP 200.
    """
    retry = 0
    while not self.mineflayer.is_running:
        logger.info("Mineflayer process has exited, restarting")
        self.mineflayer.run()
        if not self.mineflayer.is_running:
            if retry > 3:
                logger.error("Mineflayer process failed to start")
                # The original ``raise {}`` is itself a TypeError (a dict is
                # not an exception) and hid the actual failure.
                raise RuntimeError("Mineflayer process failed to start")
            retry += 1
            continue
        logger.info(self.mineflayer.ready_line)
        res = requests.post(
            f"{self.server}/start",
            json=self.reset_options,
            timeout=self.request_timeout,
        )
        if res.status_code != 200:
            self.mineflayer.stop()
            logger.error(f"Minecraft server reply with code {res.status_code}")
            raise RuntimeError(f"Minecraft server reply with code {res.status_code}")
        return res.json()


@mark_as_writeable
def _reset(self, *, seed=None, options=None) -> dict:
    """Restart the bot with the given reset options and return its observation.

    Args:
        seed: Accepted for interface compatibility; not used here.
        options: Optional mapping. Keys read are ``mode``, ``inventory``,
            ``equipment``, ``spread``, ``wait_ticks`` and ``position``.

    Returns:
        The observation decoded from the ``/start`` response.

    Raises:
        ValueError: if an inventory is requested with a mode other than
            ``"hard"``.
        RuntimeError: propagated from ``check_process``.
    """
    if options is None:
        options = {}
    if options.get("inventory", {}) and options.get("mode", "hard") != "hard":
        logger.error("inventory can only be set when options is hard")
        raise ValueError("inventory can only be set when options is hard")

    self.reset_options = {
        "port": self.mc_port,
        "reset": options.get("mode", "hard"),
        "inventory": options.get("inventory", {}),
        "equipment": options.get("equipment", []),
        "spread": options.get("spread", False),
        "waitTicks": options.get("wait_ticks", 5),
        "position": options.get("position", None),
    }

    self.unpause()
    self.mineflayer.stop()
    time.sleep(1)  # wait for mineflayer to exit

    returned_data = self.check_process()
    self.has_reset = True
    self.connected = True
    # Any restart triggered later from a step must be a soft reset.
    self.reset_options["reset"] = "soft"
    self.pause()
    # The bridge answers with a JSON-encoded string, hence the second decode.
    return json.loads(returned_data)


@mark_as_writeable
def _step(self, code: str, programs: str = "") -> dict:
    """Run ``code`` (with ``programs`` prepended) in the bot and return the result.

    Args:
        code: JavaScript source to execute.
        programs: JavaScript source evaluated before ``code``.

    Returns:
        The observation decoded from the ``/step`` response.

    Raises:
        RuntimeError: if the environment was never reset, or if ``/step``
            does not answer with HTTP 200.
    """
    if not self.has_reset:
        raise RuntimeError("Environment has not been reset yet")
    self.check_process()
    self.unpause()
    data = {
        "code": code,
        "programs": programs,
    }
    res = requests.post(f"{self.server}/step", json=data, timeout=self.request_timeout)
    if res.status_code != 200:
        raise RuntimeError("Failed to step Minecraft server")
    returned_data = res.json()
    self.pause()
    return json.loads(returned_data)


@mark_as_writeable
def pause(self) -> bool:
    """Pause the server if the bot process is running and it is not paused yet.

    Returns:
        The resulting ``self.server_paused`` flag.
    """
    if self.mineflayer.is_running and not self.server_paused:
        res = requests.post(f"{self.server}/pause")
        if res.status_code == 200:
            self.server_paused = True
    return self.server_paused


@mark_as_writeable
def unpause(self) -> bool:
    """Resume the server if the bot process is running and it is paused.

    The same ``/pause`` endpoint is used as in ``pause``; only the local flag
    records which state the server is in.

    Returns:
        The resulting ``self.server_paused`` flag.
    """
    if self.mineflayer.is_running and self.server_paused:
        res = requests.post(f"{self.server}/pause")
        if res.status_code == 200:
            self.server_paused = False
        else:
            logger.info(f"mineflayer pause result: {res.json()}")
    return self.server_paused
// File: notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/index.js
//
// HTTP bridge between the Python environment and a mineflayer bot.
// Endpoints: POST /start, /step, /stop and /pause.

const fs = require("fs");
const express = require("express");
const bodyParser = require("body-parser");
const mineflayer = require("mineflayer");

const skills = require("./lib/skillLoader");
const { initCounter, getNextTime } = require("./lib/utils");
const obs = require("./lib/observation/base");
const OnChat = require("./lib/observation/onChat");
const OnError = require("./lib/observation/onError");
const { Voxels, BlockRecords } = require("./lib/observation/voxels");
const Status = require("./lib/observation/status");
const Inventory = require("./lib/observation/inventory");
const OnSave = require("./lib/observation/onSave");
const Chests = require("./lib/observation/chests");
const { plugin: tool } = require("mineflayer-tool");

// The single bot instance shared by all endpoints; null while no bot exists.
let bot = null;

const app = express();

app.use(bodyParser.json({ limit: "50mb" }));
app.use(bodyParser.urlencoded({ limit: "50mb", extended: false }));

// POST /start
// Body: { port, reset, inventory, equipment, spread, waitTicks, position }.
// Replaces any existing bot, applies the reset options once the new bot has
// spawned and answers with the bot's first observation. Answers 400 when the
// connection fails before the bot spawns.
app.post("/start", (req, res) => {
    if (bot) onDisconnect("Restarting bot");
    bot = null;
    console.log(req.body);
    bot = mineflayer.createBot({
        host: "localhost", // minecraft server ip
        port: req.body.port, // minecraft server port
        username: "bot",
        disableChatSigning: true,
        checkTimeoutInterval: 60 * 60 * 1000,
    });
    bot.once("error", onConnectionFailed);

    // Per-bot bookkeeping used by /step and the observations.
    bot.waitTicks = req.body.waitTicks;
    bot.globalTickCounter = 0;
    bot.stuckTickCounter = 0;
    bot.stuckPosList = [];
    bot.iron_pickaxe = false;

    bot.on("kicked", onDisconnect);

    // mounting will cause physicsTick to stop
    bot.on("mount", () => {
        bot.dismount();
    });

    bot.once("spawn", async () => {
        bot.removeListener("error", onConnectionFailed);
        // Multiplier for the final wait: one base tick period plus one per
        // item or equipment command issued below.
        let itemTicks = 1;
        if (req.body.reset === "hard") {
            bot.chat("/clear @s");
            bot.chat("/kill @s");
            const inventory = req.body.inventory ? req.body.inventory : {};
            const equipment = req.body.equipment
                ? req.body.equipment
                : [null, null, null, null, null, null];
            for (let key in inventory) {
                bot.chat(`/give @s minecraft:${key} ${inventory[key]}`);
                itemTicks += 1;
            }
            const equipmentNames = [
                "armor.head",
                "armor.chest",
                "armor.legs",
                "armor.feet",
                "weapon.mainhand",
                "weapon.offhand",
            ];
            for (let i = 0; i < 6; i++) {
                // Index 4 (weapon.mainhand) is deliberately not equipped here.
                if (i === 4) continue;
                if (equipment[i]) {
                    bot.chat(
                        `/item replace entity @s ${equipmentNames[i]} with minecraft:${equipment[i]}`
                    );
                    itemTicks += 1;
                }
            }
        }

        if (req.body.position) {
            bot.chat(
                `/tp @s ${req.body.position.x} ${req.body.position.y} ${req.body.position.z}`
            );
        }

        // Remember whether the bot owns an iron pickaxe so that /step can
        // hand one back if it gets lost (see returnItems).
        if (
            bot.inventory.items().find((item) => item.name === "iron_pickaxe")
        ) {
            bot.iron_pickaxe = true;
        }

        const { pathfinder } = require("mineflayer-pathfinder");
        const tool = require("mineflayer-tool").plugin;
        const collectBlock = require("mineflayer-collectblock").plugin;
        const pvp = require("mineflayer-pvp").plugin;
        const minecraftHawkEye = require("minecrafthawkeye");
        bot.loadPlugin(pathfinder);
        bot.loadPlugin(tool);
        bot.loadPlugin(collectBlock);
        bot.loadPlugin(pvp);
        bot.loadPlugin(minecraftHawkEye);

        obs.inject(bot, [
            OnChat,
            OnError,
            Voxels,
            Status,
            Inventory,
            OnSave,
            Chests,
            BlockRecords,
        ]);
        skills.inject(bot);

        if (req.body.spread) {
            bot.chat(`/spreadplayers ~ ~ 0 300 under 80 false @s`);
            await bot.waitForTicks(bot.waitTicks);
        }

        await bot.waitForTicks(bot.waitTicks * itemTicks);
        res.json(bot.observe());

        initCounter(bot);
        bot.chat("/gamerule keepInventory true");
        bot.chat("/gamerule doDaylightCycle false");
    });

    // Connection error before the first spawn: drop the bot and report it.
    function onConnectionFailed(e) {
        console.log(e);
        bot = null;
        res.status(400).json({ error: e });
    }

    // Tear down the current bot (kicked, or replaced by a new /start).
    function onDisconnect(message) {
        if (bot.viewer) {
            bot.viewer.close();
        }
        bot.end();
        console.log(message);
        bot = null;
    }
});

// POST /step
// Body: { code, programs }. Evaluates `programs` followed by `code` inside an
// async function, then answers with the observations gathered meanwhile.
app.post("/step", async (req, res) => {
    // Same guard and reply as /pause; without it the handler throws on
    // `bot.version` and the client never receives an answer.
    if (!bot) {
        res.status(400).json({ error: "Bot not spawned" });
        return;
    }

    let response_sent = false;
    // Errors escaping the evaluated code asynchronously end up here: they are
    // reported as an "error" event and the step is answered (once).
    function otherError(err) {
        console.log("Uncaught Error");
        bot.emit("error", handleError(err));
        bot.waitForTicks(bot.waitTicks).then(() => {
            if (!response_sent) {
                response_sent = true;
                res.json(bot.observe());
            }
        });
    }

    process.on("uncaughtException", otherError);

    // NOTE: the bindings declared below (mcData, the pathfinder names, Vec3,
    // the *_FailCount counters) look unused but must stay: evaluateCode uses a
    // direct eval, which runs in this function's scope and can see them.
    const mcData = require("minecraft-data")(bot.version);
    // Aliases for alternative item and block names.
    // (The original self-assignment for "leather_boots" was a no-op and is dropped.)
    mcData.itemsByName["leather_cap"] = mcData.itemsByName["leather_helmet"];
    mcData.itemsByName["leather_tunic"] =
        mcData.itemsByName["leather_chestplate"];
    mcData.itemsByName["leather_pants"] =
        mcData.itemsByName["leather_leggings"];
    mcData.itemsByName["lapis_lazuli_ore"] = mcData.itemsByName["lapis_ore"];
    mcData.blocksByName["lapis_lazuli_ore"] = mcData.blocksByName["lapis_ore"];
    const {
        Movements,
        goals: {
            Goal,
            GoalBlock,
            GoalNear,
            GoalXZ,
            GoalNearXZ,
            GoalY,
            GoalGetToBlock,
            GoalLookAtBlock,
            GoalBreakBlock,
            GoalCompositeAny,
            GoalCompositeAll,
            GoalInvert,
            GoalFollow,
            GoalPlaceBlock,
        },
        pathfinder,
        Move,
        ComputedPath,
        PartiallyComputedPath,
        XZCoordinates,
        XYZCoordinates,
        SafeBlock,
        GoalPlaceBlockOptions,
    } = require("mineflayer-pathfinder");
    const { Vec3 } = require("vec3");

    // Set up pathfinder
    const movements = new Movements(bot, mcData);
    bot.pathfinder.setMovements(movements);

    bot.globalTickCounter = 0;
    bot.stuckTickCounter = 0;
    bot.stuckPosList = [];

    // Counts ticks and, every 100 ticks spent path-finding, checks whether
    // the bot is stuck.
    function onTick() {
        bot.globalTickCounter++;
        if (bot.pathfinder.isMoving()) {
            bot.stuckTickCounter++;
            if (bot.stuckTickCounter >= 100) {
                onStuck(1.5);
                bot.stuckTickCounter = 0;
            }
        }
    }

    // NOTE(review): lib/observation/voxels.js listens for "physicsTick" while
    // this file uses "physicTick" — confirm which event name the installed
    // mineflayer version emits and align the two.
    bot.on("physicTick", onTick);

    // initialize fail count
    let _craftItemFailCount = 0;
    let _killMobFailCount = 0;
    let _mineBlockFailCount = 0;
    let _placeItemFailCount = 0;
    let _smeltItemFailCount = 0;

    // Retrieve the sources from the POST body.
    const code = req.body.code;
    const programs = req.body.programs;
    bot.cumulativeObs = [];
    await bot.waitForTicks(bot.waitTicks);
    const r = await evaluateCode(code, programs);
    process.off("uncaughtException", otherError);
    if (r !== "success") {
        bot.emit("error", handleError(r));
    }
    await returnItems();
    // wait for last message
    await bot.waitForTicks(bot.waitTicks);
    if (!response_sent) {
        response_sent = true;
        res.json(bot.observe());
    }
    bot.removeListener("physicTick", onTick);

    // Runs `programs` + `code` as the body of an async function.
    // Returns "success", or the thrown error.
    // SECURITY: this evaluates whatever source the client posts; the endpoint
    // must only be reachable by the trusted controlling process.
    async function evaluateCode(code, programs) {
        try {
            await eval("(async () => {" + programs + "\n" + code + "})()");
            return "success";
        } catch (err) {
            return err;
        }
    }

    // Records the current position; once five samples exist, teleports the
    // bot if it moved less than `posThreshold` since the oldest sample.
    function onStuck(posThreshold) {
        const currentPos = bot.entity.position;
        bot.stuckPosList.push(currentPos);

        // Check if the list is full
        if (bot.stuckPosList.length === 5) {
            const oldestPos = bot.stuckPosList[0];
            const posDifference = currentPos.distanceTo(oldestPos);

            if (posDifference < posThreshold) {
                teleportBot();
            }

            // Remove the oldest position from the list
            bot.stuckPosList.shift();
        }
    }

    // Moves the bot to a random adjacent block of type 0, or slightly upwards
    // when none is found.
    function teleportBot() {
        const blocks = bot.findBlocks({
            matching: (block) => {
                return block.type === 0;
            },
            maxDistance: 1,
            count: 27,
        });

        // An empty array is truthy, so the length has to be tested; the
        // original `if (blocks)` indexed into an empty result.
        if (blocks && blocks.length > 0) {
            const randomIndex = Math.floor(Math.random() * blocks.length);
            const block = blocks[randomIndex];
            bot.chat(`/tp @s ${block.x} ${block.y} ${block.z}`);
        } else {
            bot.chat("/tp @s ~ ~1.25 ~");
        }
    }

    // After a step: removes a nearby placed crafting table / furnace and
    // gives it back, hands out a chest when the inventory is nearly full and
    // restores a lost iron pickaxe.
    function returnItems() {
        bot.chat("/gamerule doTileDrops false");
        const crafting_table = bot.findBlock({
            matching: mcData.blocksByName.crafting_table.id,
            maxDistance: 128,
        });
        if (crafting_table) {
            bot.chat(
                `/setblock ${crafting_table.position.x} ${crafting_table.position.y} ${crafting_table.position.z} air destroy`
            );
            bot.chat("/give @s crafting_table");
        }
        const furnace = bot.findBlock({
            matching: mcData.blocksByName.furnace.id,
            maxDistance: 128,
        });
        if (furnace) {
            bot.chat(
                `/setblock ${furnace.position.x} ${furnace.position.y} ${furnace.position.z} air destroy`
            );
            bot.chat("/give @s furnace");
        }
        if (bot.inventoryUsed() >= 32) {
            // if chest is not in bot's inventory
            if (!bot.inventory.items().find((item) => item.name === "chest")) {
                bot.chat("/give @s chest");
            }
        }
        // if iron_pickaxe not in bot's inventory and bot.iron_pickaxe
        if (
            bot.iron_pickaxe &&
            !bot.inventory.items().find((item) => item.name === "iron_pickaxe")
        ) {
            bot.chat("/give @s iron_pickaxe");
        }
        bot.chat("/gamerule doTileDrops true");
    }

    // Turns an error into a message that points at the offending line of the
    // submitted code. Values without a stack are returned unchanged.
    function handleError(err) {
        let stack = err.stack;
        if (!stack) {
            return err;
        }
        console.log(stack);
        const final_line = stack.split("\n")[1];
        const regex = /:(\d+):\d+\)/;

        // Find the first stack frame whose line number lies past `programs`
        // and translate it into a line number relative to `code`.
        const programs_length = programs.split("\n").length;
        let match_line = null;
        for (const line of stack.split("\n")) {
            const match = regex.exec(line);
            if (match) {
                const line_num = parseInt(match[1]);
                if (line_num >= programs_length) {
                    match_line = line_num - programs_length;
                    break;
                }
            }
        }
        if (!match_line) {
            return err.message;
        }
        // Named groups restored: the destructuring below reads
        // `file`, `line` and `pos` from `groups`.
        let f_line = final_line
            ? final_line.match(/\((?<file>.*):(?<line>\d+):(?<pos>\d+)\)/)
            : null;
        if (f_line && f_line.groups && fs.existsSync(f_line.groups.file)) {
            // The top frame is a real file on disk: quote its source line.
            const { file, line, pos } = f_line.groups;
            const f = fs.readFileSync(file, "utf8").split("\n");
            let source = file + `:${line}\n${f[line - 1].trim()}\n `;

            const code_source =
                "at " +
                code.split("\n")[match_line - 1].trim() +
                " in your code";
            return source + err.message + "\n" + code_source;
        } else if (
            f_line &&
            f_line.groups &&
            f_line.groups.file.includes("<anonymous>")
        ) {
            // The top frame is inside the evaluated source itself.
            const { file, line, pos } = f_line.groups;
            let source =
                "Your code" +
                `:${match_line}\n${code.split("\n")[match_line - 1].trim()}\n `;
            let code_source = "";
            if (line < programs_length) {
                source =
                    "In your program code: " +
                    programs.split("\n")[line - 1].trim() +
                    "\n";
                code_source = `at line ${match_line}:${code
                    .split("\n")
                    [match_line - 1].trim()} in your code`;
            }
            return source + err.message + "\n" + code_source;
        }
        return err.message;
    }
});

// POST /stop — ends the bot connection, if there is one.
app.post("/stop", (req, res) => {
    // The bot may already be gone (failed connection, kick).
    if (bot) {
        bot.end();
    }
    res.json({
        message: "Bot stopped",
    });
});

// POST /pause — sends the "/pause" chat command. The Python client calls this
// same endpoint both to pause and to unpause.
app.post("/pause", (req, res) => {
    if (!bot) {
        res.status(400).json({ error: "Bot not spawned" });
        return;
    }
    bot.chat("/pause");
    bot.waitForTicks(bot.waitTicks).then(() => {
        res.json({ message: "Success" });
    });
});

// The port comes from the first command-line argument, 3000 by default.
const DEFAULT_PORT = 3000;
const PORT = process.argv[2] || DEFAULT_PORT;
app.listen(PORT, () => {
    console.log(`Server started on port ${PORT}`);
});

// File: notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/lib/observation/base.js

// Abstract base class of all observations. Subclasses set `name` and
// implement `observe()`.
class Observation {
    constructor(bot) {
        if (new.target === Observation) {
            throw new TypeError(
                "Cannot instantiate abstract class Observation"
            );
        }

        this.bot = bot;
        this.name = "Observation";
    }

    observe() {
        throw new TypeError("Method 'observe()' must be implemented.");
    }

    reset() {}
}

// Instantiates every class in `obs_list` for `bot` and adds two methods:
//   bot.event(name) — snapshots the observations and appends
//                     [name, snapshot] to bot.cumulativeObs;
//   bot.observe()   — records a final "observe" event, then returns all
//                     accumulated events as a JSON string and clears them.
function inject(bot, obs_list) {
    bot.obsList = [];
    bot.cumulativeObs = [];
    bot.eventMemory = {};
    obs_list.forEach((obs) => {
        bot.obsList.push(new obs(bot));
    });
    bot.event = function (event_name) {
        let result = {};
        bot.obsList.forEach((obs) => {
            // Observations named "on…" are event-specific: they are only
            // included in the snapshot of their own event.
            if (obs.name.startsWith("on") && obs.name !== event_name) {
                return;
            }
            result[obs.name] = obs.observe();
        });
        bot.cumulativeObs.push([event_name, result]);
    };
    bot.observe = function () {
        bot.event("observe");
        const result = bot.cumulativeObs;
        bot.cumulativeObs = [];
        return JSON.stringify(result);
    };
}

module.exports = { Observation, inject };
// File: notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/lib/observation/chests.js

// Observation "nearbyChests": what is known about the chests around the bot.
class Chests extends require("./base").Observation {
    constructor(bot) {
        super(bot);
        this.name = "nearbyChests";
        // Keyed by chest position; the value is the recorded items,
        // "Unknown" (seen but never recorded) or "Invalid" (removed).
        this.chestsItems = {};

        const rememberContents = (chestItems, position) => {
            this.chestsItems[position] = chestItems;
        };
        const markRemoved = (chestPosition) => {
            this.chestsItems[chestPosition] = "Invalid";
        };
        bot.on("closeChest", rememberContents);
        bot.on("removeChest", markRemoved);
    }

    // Adds every chest within 16 blocks that has no entry yet as "Unknown"
    // and returns the whole record.
    observe() {
        const nearby = this.bot.findBlocks({
            matching: this.bot.registry.blocksByName.chest.id,
            maxDistance: 16,
            count: 999,
        });
        for (const position of nearby) {
            if (this.chestsItems.hasOwnProperty(position)) continue;
            this.chestsItems[position] = "Unknown";
        }
        return this.chestsItems;
    }
}

module.exports = Chests;

// File: notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/lib/observation/inventory.js

// Observation "inventory": item name -> total count.
class Inventory extends require("./base").Observation {
    constructor(bot) {
        super(bot);
        this.name = "inventory";
    }

    observe() {
        return listItems(this.bot);
    }
}

// Builds the name -> count mapping for the bot's current items.
function listItems(bot) {
    const totals = {};
    for (const item of getInventoryItems(bot)) {
        itemToDict(totals, item);
    }
    return totals;
}

// Items of the currently open window if there is one, else of the inventory.
function getInventoryItems(bot) {
    return (bot.currentWindow || bot.inventory).items();
}

// Adds `cur` to the running totals in `acc`; entries lacking a name or a
// count are ignored. Returns `acc`.
function itemToDict(acc, cur) {
    if (!cur.name || !cur.count) {
        return acc;
    }
    acc[cur.name] = acc[cur.name] ? acc[cur.name] + cur.count : cur.count;
    return acc;
}

module.exports = Inventory;

// File: notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/lib/observation/onChat.js

// Observation "onChat": text of the bot's chat messages since the last read.
class onChat extends require("./base.js").Observation {
    constructor(bot) {
        super(bot);
        this.name = "onChat";
        this.obs = "";
        bot.on("chatEvent", (username, message) => {
            // Messages starting with "/" are commands, not chat; skip them.
            if (!message.startsWith("/")) {
                this.obs += message;
                this.bot.event(this.name);
            }
        });
    }

    // Returns the accumulated text and clears it.
    observe() {
        const pending = this.obs;
        this.obs = "";
        return pending;
    }
}

module.exports = onChat;

// File: notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/lib/observation/onError.js

// Observation "onError": the most recent value emitted with an "error" event.
class onError extends require("./base.js").Observation {
    constructor(bot) {
        super(bot);
        this.name = "onError";
        this.obs = null;
        bot.on("error", (err) => {
            this.obs = err;
            this.bot.event(this.name);
        });
    }

    // Returns the stored error (or null) and clears it.
    observe() {
        const pending = this.obs;
        this.obs = null;
        return pending;
    }
}

module.exports = onError;
// File: notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/lib/observation/onSave.js

// Observation "onSave": the name passed with the most recent "save" event.
class onSave extends require("./base.js").Observation {
    constructor(bot) {
        super(bot);
        this.name = "onSave";
        this.obs = null;
        bot.on("save", (eventName) => {
            this.obs = eventName;
            this.bot.event(this.name);
        });
    }

    // Returns the stored name (or null) and clears it.
    observe() {
        const pending = this.obs;
        this.obs = null;
        return pending;
    }
}

module.exports = onSave;

// File: notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/lib/observation/status.js

// Observation "status": a snapshot of the bot's vital values, pose,
// equipment, surroundings and elapsed ticks.
class Status extends require("./base.js").Observation {
    constructor(bot) {
        super(bot);
        this.name = "status";
    }

    observe() {
        const bot = this.bot;
        const self = bot.entity;
        const blockHere = bot.blockAt(self.position);
        return {
            health: bot.health,
            food: bot.food,
            saturation: bot.foodSaturation,
            oxygen: bot.oxygenLevel,
            position: self.position,
            velocity: self.velocity,
            yaw: self.yaw,
            pitch: self.pitch,
            onGround: self.onGround,
            equipment: this.getEquipment(),
            name: self.username,
            timeSinceOnGround: self.timeSinceOnGround,
            isInWater: self.isInWater,
            isInLava: self.isInLava,
            isInWeb: self.isInWeb,
            isCollidedHorizontally: self.isCollidedHorizontally,
            isCollidedVertically: self.isCollidedVertically,
            biome: blockHere ? blockHere.biome.name : "None",
            entities: this.getEntities(),
            timeOfDay: this.getTime(),
            inventoryUsed: bot.inventoryUsed(),
            elapsedTime: bot.globalTickCounter,
        };
    }

    // Name of an item, or null for an empty slot.
    itemToObs(item) {
        return item ? item.name : null;
    }

    // Maps the bot's time of day onto a coarse label. Each entry is the
    // exclusive upper bound of its label; anything past the last bound is
    // "sunrise" again.
    getTime() {
        const timeOfDay = this.bot.time.timeOfDay;
        const upperBounds = [
            [1000, "sunrise"],
            [6000, "day"],
            [12000, "noon"],
            [13000, "sunset"],
            [18000, "night"],
            [22000, "midnight"],
        ];
        for (const [bound, label] of upperBounds) {
            if (timeOfDay < bound) {
                return label;
            }
        }
        return "sunrise";
    }

    // Names of the items in inventory slots 5-8, the held item and slot 45,
    // in that order; null where a slot is empty.
    getEquipment() {
        const slots = this.bot.inventory.slots;
        const worn = [...slots.slice(5, 9), this.bot.heldItem, slots[45]];
        return worn.map((item) => this.itemToObs(item));
    }

    // For each kind of entity within 32 blocks (players and dropped items
    // excluded), the distance to the nearest one.
    getEntities() {
        const entities = this.bot.entities;
        if (!entities) return {};
        const nearest = {};
        for (const id in entities) {
            const entity = entities[id];
            if (!entity.displayName) continue;
            if (entity.name === "player" || entity.name === "item") continue;
            const distance = entity.position.distanceTo(
                this.bot.entity.position
            );
            if (distance >= 32) continue;
            if (!nearest[entity.name] || nearest[entity.name] > distance) {
                nearest[entity.name] = distance;
            }
        }
        return nearest;
    }
}

module.exports = Status;
// File: notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/lib/observation/voxels.js

// Observation "voxels": names of the non-air blocks in the box around the
// bot (8 blocks in x and z, 2 in y).
class Voxels extends require("./base").Observation {
    constructor(bot) {
        super(bot);
        this.name = "voxels";
    }

    observe() {
        return [...getSurroundingBlocks(this.bot, 8, 2, 8)];
    }
}

// Observation "blockRecords": names of blocks seen around the bot over time,
// sampled every 100 "physicsTick" events. A block name is only recorded if
// the bot holds no item of that name at sampling time.
class BlockRecords extends require("./base").Observation {
    constructor(bot) {
        super(bot);
        this.name = "blockRecords";
        this.records = new Set();
        this.tick = 0;
        bot.on("physicsTick", () => {
            this.tick += 1;
            if (this.tick < 100) {
                return;
            }
            const held = getInventoryItems(this.bot);
            for (const name of getSurroundingBlocks(this.bot, 8, 2, 8)) {
                if (!held.has(name)) {
                    this.records.add(name);
                }
            }
            this.tick = 0;
        });
    }

    observe() {
        return [...this.records];
    }

    reset() {
        this.records = new Set();
    }
}

// Set of names of all blocks of non-zero type within the given per-axis
// distances of the bot's position.
function getSurroundingBlocks(bot, x_distance, y_distance, z_distance) {
    const names = new Set();
    for (let dx = -x_distance; dx <= x_distance; dx++) {
        for (let dy = -y_distance; dy <= y_distance; dy++) {
            for (let dz = -z_distance; dz <= z_distance; dz++) {
                const block = bot.blockAt(
                    bot.entity.position.offset(dx, dy, dz)
                );
                if (!block || block.type === 0) continue;
                names.add(block.name);
            }
        }
    }
    return names;
}

// Set of the names of the items in the bot's inventory.
function getInventoryItems(bot) {
    const names = new Set();
    for (const item of bot.inventory.items()) {
        if (item) names.add(item.name);
    }
    return names;
}

module.exports = { Voxels, BlockRecords };
mode 100644 index 0000000000000000000000000000000000000000..d78cf782093b213b35d3d4c719490e3a86a7878b --- /dev/null +++ b/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/lib/skillLoader.js @@ -0,0 +1,79 @@ +function inject(bot) { + bot._sleep = bot.sleep; + bot.sleep = async (bedBlock) => { + await bot.waitForTicks(20); + await bot._sleep(bedBlock); + await bot.waitForTicks(135); + }; + + bot._fish = bot.fish; + bot.fish = async () => { + if (bot.heldItem?.name !== "fishing_rod") { + bot.chat("I'm not holding a fishing rod!"); + return; + } + let timeout = null; + await Promise.race([ + bot._fish(), + new Promise( + (resolve, reject) => + (timeout = setTimeout(() => { + bot.activateItem(); + reject( + new Error( + "Finishing timeout, make sure you get to and look at a water block!" + ) + ); + }, 60000)) + ), + ]); + clearTimeout(timeout); + await bot.waitForTicks(20); + }; + + bot._consume = bot.consume; + bot.consume = async () => { + // action_count.activateItem++; + await bot._consume(); + await bot.waitForTicks(20); + }; + + bot._useOn = bot.useOn; + bot.useOn = async (entity) => { + if (entity.position.distanceTo(bot.entity.position) > 6) { + bot.chat("Please goto a place near the entity first!"); + return; + } + await bot._useOn(entity); + await bot.waitForTicks(20); + }; + + bot._activateBlock = bot.activateBlock; + bot.activateBlock = async (block) => { + if (block.position.distanceTo(bot.entity.position) > 6) { + bot.chat("Please goto a place near the block first!"); + return; + } + // action_count.activateBlock++; + await bot._activateBlock(block); + }; + + bot._chat = bot.chat; + bot.chat = (message) => { + // action_count.chat++; + bot.emit("chatEvent", "bot", message); + bot._chat(message); + }; + + bot.inventoryUsed = () => { + return bot.inventory.slots.slice(9, 45).filter((item) => item !== null) + .length; + }; + + bot.save = function (eventName) { + bot.emit("save", eventName); + }; +} + +// export all control_primitives +module.exports 
= { inject }; diff --git a/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/lib/utils.js b/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/lib/utils.js new file mode 100644 index 0000000000000000000000000000000000000000..68af3079602ab8d88059a9c8c4055140dda32f1d --- /dev/null +++ b/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/lib/utils.js @@ -0,0 +1,31 @@ +let gameTimeCounter = 0; +let gameTimeList = []; +const initCounter = (bot) => { + gameTimeList = []; + for (let i = 0; i < 13000; i += 1000) { + gameTimeList.push(i); + } + for (let i = 13000; i < 24000; i += 2000) { + gameTimeList.push(i); + } + const timeOfDay = bot.time.timeOfDay; + for (let i = 0; i < gameTimeList.length; i++) { + if (gameTimeList[i] > timeOfDay) { + gameTimeCounter = i - 1; + break; + } + } +}; + +const getNextTime = () => { + gameTimeCounter++; + if (gameTimeCounter >= gameTimeList.length) { + gameTimeCounter = 0; + } + return gameTimeList[gameTimeCounter]; +}; + +module.exports = { + initCounter, + getNextTime, +}; diff --git a/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/mineflayer-collectblock/.gitignore b/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/mineflayer-collectblock/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..0578fdca3844dbbdfdabfa5c927de3a1144d7d5a --- /dev/null +++ b/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/mineflayer-collectblock/.gitignore @@ -0,0 +1,107 @@ +# Logs +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +lerna-debug.log* + +# Diagnostic reports (https://nodejs.org/api/report.html) +report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json + +# Runtime data +pids +*.pid +*.seed +*.pid.lock + +# Directory for instrumented libs generated by jscoverage/JSCover +lib-cov + +# Coverage directory used by tools like istanbul +coverage +*.lcov + +# nyc test coverage +.nyc_output + +# Grunt intermediate storage 
(https://gruntjs.com/creating-plugins#storing-task-files) +.grunt + +# Bower dependency directory (https://bower.io/) +bower_components + +# node-waf configuration +.lock-wscript + +# Compiled binary addons (https://nodejs.org/api/addons.html) +build/Release + +# Dependency directories +node_modules/ +jspm_packages/ + +# TypeScript v1 declaration files +typings/ + +# TypeScript cache +*.tsbuildinfo + +# Optional npm cache directory +.npm + +# Optional eslint cache +.eslintcache + +# Microbundle cache +.rpt2_cache/ +.rts2_cache_cjs/ +.rts2_cache_es/ +.rts2_cache_umd/ + +# Optional REPL history +.node_repl_history + +# Output of 'npm pack' +*.tgz + +# Yarn Integrity file +.yarn-integrity + +# dotenv environment variables file +.env +.env.test + +# parcel-bundler cache (https://parceljs.org/) +.cache + +# Next.js build output +.next + +# Nuxt.js build / generate output +.nuxt +dist + +# Gatsby files +.cache/ +# Comment in the public line in if your project uses Gatsby and *not* Next.js +# https://nextjs.org/blog/next-9-1#public-directory-support +# public + +# vuepress build output +.vuepress/dist + +# Serverless directories +.serverless/ + +# FuseBox cache +.fusebox/ + +# DynamoDB Local files +.dynamodb/ + +# TernJS port file +.tern-port + +lib/ +package-lock.json diff --git a/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/mineflayer-collectblock/LICENSE b/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/mineflayer-collectblock/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f2896b56e45adc3d54cd6f98764d4b155b571217 --- /dev/null +++ b/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/mineflayer-collectblock/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 TheDudeFromCI + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the 
rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/mineflayer-collectblock/README.md b/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/mineflayer-collectblock/README.md new file mode 100644 index 0000000000000000000000000000000000000000..555acb761e51efff08f372cb2525c8da2a230e57 --- /dev/null +++ b/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/mineflayer-collectblock/README.md @@ -0,0 +1,89 @@ +

mineflayer-collectblock

+

A small utility plugin that lets users collect blocks through a higher-level API.

+ +

+ + + + + + +

+ +--- +## This is a modified version to better support Voyager + +## Showcase + +You can see a video of the plugin in action, [here.](https://youtu.be/5T_rcCnNnf4) +The source code of the bot in the video can be seen in the examples folder, [here.](https://github.com/TheDudeFromCI/mineflayer-collectblock/blob/master/examples/collector.js) + +### Description + +This plugin is a wrapper for mineflayer that allows for easier API usage when collecting blocks or item drops. This plugin is designed to reduce some of the boilerplate code based around the act of pathfinding to a block _(handled by_ ***mineflayer-pathfinder***_)_, selecting the best tool to mine that block _(handled by_ ***mineflayer-tool***_)_, actually mining it, then moving to collect the item drops from that block. This plugin allows for all of that basic concept to be wrapped up into a single API function. + +In addition to the usage above, some additional quality of life features are available in this plugin. These include the ability to automatically deposit items into a chest when the bot's inventory is full, collecting new tools from a chest if the bot doesn't currently have a required tool _(also handled by_ ***mineflayer-tool***_)_, and allowing for queueing of multiple blocks or item drops to the collection task, so they can be processed later. + +### Getting Started + +This plugin is built using Node and can be installed using: +```bash +npm install --save mineflayer-collectblock +``` + +### Simple Bot + +The brief description goes here. 
+ +```js +// Create your bot +const mineflayer = require("mineflayer") +const bot = mineflayer.createBot({ + host: 'localhost', + username: 'Player', +}) +let mcData + +// Load collect block +bot.loadPlugin(require('mineflayer-collectblock').plugin) + +async function collectGrass() { + // Find a nearby grass block + const grass = bot.findBlock({ + matching: mcData.blocksByName.grass_block.id, + maxDistance: 64 + }) + + if (grass) { + // If we found one, collect it. + try { + await bot.collectBlock.collect(grass) + collectGrass() // Collect another grass block + } catch (err) { + console.log(err) // Handle errors, if any + } + } +} + +// On spawn, start collecting all nearby grass +bot.once('spawn', () => { + mcData = require('minecraft-data')(bot.version) + collectGrass() +}) +``` + +### Documentation + +[API](https://github.com/TheDudeFromCI/mineflayer-collectblock/blob/master/docs/api.md) + +[Examples](https://github.com/TheDudeFromCI/mineflayer-collectblock/tree/master/examples) + +### License + +This project uses the [MIT](https://github.com/TheDudeFromCI/mineflayer-collectblock/blob/master/LICENSE) license. + +### Contributions + +This project is accepting PRs and Issues. See something you think can be improved? Go for it! Any and all help is highly appreciated! + +For larger changes, it is recommended to discuss these changes in the issues tab before writing any code. It's also preferred to make many smaller PRs than one large one, where applicable. 
diff --git a/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/mineflayer-collectblock/_config.yml b/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/mineflayer-collectblock/_config.yml new file mode 100644 index 0000000000000000000000000000000000000000..c4192631f25b34d77a7f159aa0da0e3ae99c4ef4 --- /dev/null +++ b/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/mineflayer-collectblock/_config.yml @@ -0,0 +1 @@ +theme: jekyll-theme-cayman \ No newline at end of file diff --git a/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/mineflayer-collectblock/docs/api.md b/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/mineflayer-collectblock/docs/api.md new file mode 100644 index 0000000000000000000000000000000000000000..66d8a3ecc4a441ff3e989412fc1520e5ffdc1e17 --- /dev/null +++ b/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/mineflayer-collectblock/docs/api.md @@ -0,0 +1,52 @@ +# API + +Welcome to the *mineflayer-collectblock* API documentation page. + +## Table of Contents + +- [1. Summary](#1-summary) +- [Properties](#properties) + - [`bot.collectblock.movements: Movements`](#botcollectblockmovements-movements) +- [Functions](#functions) + - [collect](#collect) + - [Options:](#options) + +## 1. Summary + +The collect block plugin is a utility plugin that can be used to help make collecting blocks and item drops very easy, using only a single API call. No need to worry about pathfinding to the block, selecting the right tool, or moving to pick up the item drop after mining. + +## Properties + +### `bot.collectblock.movements: Movements` + +The movements object used by the pathfinder plugin to define the movement configuration. This object is passed to the pathfinder plugin when any API from this plugin is called in order to control how pathfinding should work when collecting the given blocks or item. + +If set to null, the pathfinder plugin movements is not updated. 
+ +Defaults to a new movements object instance. + +## Functions + +### collect + +Usage: `bot.collectBlock.collect(target: Collectable | Collectable[], options?: CollectOptions, cb?: (err?: Error) => void): Promise<void>` + +Causes the bot to collect the given block, item drop, or list of those. If the target is a block, the bot will move to the block, mine it, and pick up the item drop. If the target is an item drop, the bot will move to the item drop and pick it up. If the target is a list of collectables, the bot will move from target to target in order of closest to furthest and collect each target in turn. + +#### Options: + + * `append: boolean` + + If true, the target(s) will be appended to the existing target list instead of starting a new task. Defaults to false. + + * `ignoreNoPath: boolean` + + If true, errors will not be thrown when a path to the target block cannot be found. The bot will attempt to choose the best available position it can find, instead. Errors are still thrown if the bot cannot interact with the block from its final location. Defaults to false. + + * `chestLocations: Vec3[]` + + Gets the list of chest locations to use when storing items after the bot's inventory becomes full. If undefined, it defaults to the chest location list on the bot.collectBlock plugin. + + * `itemFilter: ItemFilter` + + When transferring items to a chest, this filter is used to determine what items are allowed to be moved, and what items aren't allowed to be moved. Defaults to the item filter specified on the bot.collectBlock plugin.
/**
 * This bot example shows how to direct a bot to collect a specific block type
 * or a group of nearby blocks of that type.
 *
 * Chat commands understood by the bot:
 *   collect <type>          - collect one block of the given type
 *   collect <count> <type>  - collect up to <count> blocks of the given type
 */

const mineflayer = require('mineflayer')
const collectBlock = require('mineflayer-collectblock').plugin

if (process.argv.length < 4 || process.argv.length > 6) {
  console.log('Usage : node collector.js <host> <port> [<name>] [<password>]')
  process.exit(1)
}

const bot = mineflayer.createBot({
  host: process.argv[2],
  // Parse the port explicitly, as the storageBot example does.
  port: parseInt(process.argv[3], 10),
  username: process.argv[4] || 'collector',
  password: process.argv[5]
})

bot.loadPlugin(collectBlock)

// Version-specific game data; only available once the bot has spawned.
let mcData
bot.once('spawn', () => {
  mcData = require('minecraft-data')(bot.version)
})

bot.on('chat', async (username, message) => {
  const args = message.split(' ')
  if (args[0] !== 'collect') return

  // A chat message can arrive before the first spawn; without game data the
  // lookup below would throw inside this async handler.
  if (mcData == null) return

  // "collect <count> <type>" when three words are given, else "collect <type>".
  let count = 1
  if (args.length === 3) count = parseInt(args[1], 10)

  let type = args[1]
  if (args.length === 3) type = args[2]

  const blockType = mcData.blocksByName[type]
  if (!blockType) {
    // Report the problem instead of silently ignoring the command.
    bot.chat(`I don't know any blocks named ${type}.`)
    return
  }

  const blocks = bot.findBlocks({
    matching: blockType.id,
    maxDistance: 64,
    count: count
  })

  if (blocks.length === 0) {
    bot.chat("I don't see that block nearby.")
    return
  }

  // Convert the found positions into block objects for the plugin.
  const targets = []
  for (let i = 0; i < Math.min(blocks.length, count); i++) {
    targets.push(bot.blockAt(blocks[i]))
  }

  bot.chat(`Found ${targets.length} ${type}(s)`)

  try {
    await bot.collectBlock.collect(targets)
    // All blocks have been collected.
    bot.chat('Done')
  } catch (err) {
    // An error occurred, report it.
    bot.chat(err.message)
    console.log(err)
  }
})
/**
 * This bot example shows how to use the chest filling mechanic of the plugin.
 * Simply provide a given storage chest, and the bot will automatically try and
 * store its inventory in that chest when the bot's inventory becomes full.
 */

if (process.argv.length < 4 || process.argv.length > 6) {
  console.log('Usage : node storageBot.js <host> <port> [<name>] [<password>]')
  process.exit(1)
}

// Load your libraries
const mineflayer = require('mineflayer')
const collectBlock = require('mineflayer-collectblock').plugin

// Create your bot
const bot = mineflayer.createBot({
  host: process.argv[2],
  port: parseInt(process.argv[3], 10),
  username: process.argv[4] ? process.argv[4] : 'storageBot',
  password: process.argv[5]
})

// Load the collect block plugin
bot.loadPlugin(collectBlock)

// Load mcData on login
let mcData
bot.once('login', () => {
  mcData = require('minecraft-data')(bot.version)
})

// On spawn, try to find any nearby chests and save those as storage locations.
// When the bot's inventory becomes too full, it will empty its inventory into
// these chests before collecting more resources. If a chest gets full, it moves
// to the next one in order until its inventory is empty or it runs out of chests.
bot.once('spawn', () => {
  bot.collectBlock.chestLocations = bot.findBlocks({
    matching: mcData.blocksByName.chest.id,
    maxDistance: 16,
    count: 999999 // Get as many chests as we can
  })

  if (bot.collectBlock.chestLocations.length === 0) {
    bot.chat("I don't see any chests nearby.")
  } else {
    for (const chestPos of bot.collectBlock.chestLocations) {
      bot.chat(`I found a chest at ${chestPos}`)
    }
  }
})

// Wait for someone to say something
bot.on('chat', async (username, message) => {
  // Only react to messages that start with "collect"; otherwise, do nothing.
  const args = message.split(' ')
  if (args[0] !== 'collect') return

  // If the player specifies a number, collect that many. Otherwise, default to 1.
  let count = 1
  if (args.length === 3) count = parseInt(args[1], 10)

  // If a number was given, the block name is the 3rd word, not the 2nd.
  let type = args[1]
  if (args.length === 3) type = args[2]

  // Get the id of that block type for this version of Minecraft.
  const blockType = mcData.blocksByName[type]
  if (!blockType) {
    bot.chat(`I don't know any blocks named ${type}.`)
    return
  }

  // Find all nearby blocks of that type, up to the given count, within 64 blocks.
  const blocks = bot.findBlocks({
    matching: blockType.id,
    maxDistance: 64,
    count: count
  })

  // Complain if we can't find any nearby blocks of that type.
  if (blocks.length === 0) {
    bot.chat("I don't see that block nearby.")
    return
  }

  // Convert the block position array into a block array to pass to collect block.
  const targets = []
  for (let i = 0; i < Math.min(blocks.length, count); i++) {
    targets.push(bot.blockAt(blocks[i]))
  }

  // Announce what we found.
  bot.chat(`Found ${targets.length} ${type}(s)`)

  // Tell the bot to collect all of the given blocks in the block list.
  try {
    await bot.collectBlock.collect(targets)
    // All blocks have been collected.
    bot.chat('Done')
  } catch (err) {
    // An error occurred, report it.
    bot.chat(err.message)
    console.log(err)
  }
})
import { Bot } from 'mineflayer'
import { Block } from 'prismarine-block'

/**
 * Flood-fills outward from a starting block and gathers connected blocks of
 * the same type (for example an ore vein or a tree trunk).
 *
 * @param bot - The bot whose world is queried through `bot.blockAt`.
 * @param block - The starting block; its type and position seed the search.
 * @param maxBlocks - Upper bound on the number of blocks returned.
 * @param maxDistance - Maximum Manhattan distance from the starting block.
 * @param floodRadius - Per-axis offset within which two blocks count as touching.
 *
 * @returns The blocks found, starting block first, each position at most once.
 */
export function findFromVein (bot: Bot, block: Block, maxBlocks: number, maxDistance: number, floodRadius: number): Block[] {
  const targets: Block[] = []
  const open: Block[] = [block]
  const type = block.type
  const center = block.position

  // Visited positions are tracked by coordinate, not by object identity:
  // `bot.blockAt` is not guaranteed to return the same Block instance for a
  // given position, so an identity check (`Array.includes`) can let the same
  // block be queued and returned several times.
  const seen = new Set<string>([`${center.x},${center.y},${center.z}`])

  for (let i = 0; i < maxBlocks; i++) {
    const next = open.pop()
    if (next == null) break

    targets.push(next)

    for (let x = -floodRadius; x <= floodRadius; x++) {
      for (let y = -floodRadius; y <= floodRadius; y++) {
        for (let z = -floodRadius; z <= floodRadius; z++) {
          const neighborPos = next.position.offset(x, y, z)
          if (neighborPos.manhattanDistanceTo(center) > maxDistance) continue

          // Skip positions already queued or collected before touching the world.
          const key = `${neighborPos.x},${neighborPos.y},${neighborPos.z}`
          if (seen.has(key)) continue

          const neighbor = bot.blockAt(neighborPos)
          if (neighbor == null || neighbor.type !== type) continue

          seen.add(key)
          open.push(neighbor)
        }
      }
    }
  }

  return targets
}
/**
 * Works through `options.targets` until it is empty, always handling the
 * target closest to the bot next.
 *
 * - Block targets: equip a tool, walk to the block and mine it. At most
 *   `options.count` blocks are mined; once that many have succeeded the
 *   remaining block targets are simply removed from the list.
 * - Entity targets (item drops): follow the entity until it disappears or
 *   10 seconds pass.
 *
 * When `options.ignoreNoPath` is set, failures are logged/chatted and the
 * target is skipped; otherwise the error is re-thrown. A "NoItem" failure
 * (no suitable tool) ends the whole run early.
 *
 * @param bot - The bot performing the collection.
 * @param options - Fully resolved collect options, including the target list.
 */
async function collectAll(
    bot: Bot,
    options: CollectOptionsFull
): Promise<void> {
    let success_count = 0;
    while (!options.targets.empty) {
        await emptyInventoryIfFull(
            bot,
            options.chestLocations,
            options.itemFilter
        );
        const closest = options.targets.getClosest();
        if (closest == null) break;
        switch (closest.constructor.name) {
            case "Block": {
                try {
                    if (success_count >= options.count) {
                        // Leaves the switch only: the target is still removed
                        // below and item drops keep being collected.
                        break;
                    }
                    await bot.tool.equipForBlock(
                        closest as Block,
                        equipToolOptions
                    );
                    const goal = new goals.GoalLookAtBlock(
                        closest.position,
                        bot.world
                    );
                    await bot.pathfinder.goto(goal);
                    await mineBlock(bot, closest as Block, options);
                    success_count++;
                    // TODO: options.ignoreNoPath
                } catch (err) {
                    // @ts-ignore
                    // console.log(err.stack)
                    // bot.pathfinder.stop()
                    // bot.waitForTicks(10)
                    try {
                        bot.pathfinder.setGoal(null);
                    } catch (err) {}
                    if (options.ignoreNoPath) {
                        // @ts-ignore
                        if (err.name === "Invalid block") {
                            console.log(
                                `Block ${closest.name} at ${closest.position} is not valid! Skip it!`
                            );
                        } // @ts-ignore
                        else if (err.name === "Unsafe block") {
                            console.log(
                                `${closest.name} at ${closest.position} is not safe to break! Skip it!`
                            );
                            // @ts-ignore
                        } else if (err.name === "NoItem") {
                            const properties =
                                bot.registry.blocksByName[closest.name];
                            const leastTool = Object.keys(
                                properties.harvestTools
                            )[0];
                            const item = bot.registry.items[leastTool];
                            bot.chat(
                                `I need at least a ${item.name} to mine ${closest.name}! Skip it!`
                            );
                            return;
                        } else if (
                            // @ts-ignore
                            err.name === "NoPath" ||
                            // @ts-ignore
                            err.name === "Timeout"
                        ) {
                            // Pathfinding failed, but the bot may already be
                            // standing on the block: try mining it anyway.
                            if (
                                bot.entity.position.distanceTo(
                                    closest.position
                                ) < 0.5
                            ) {
                                await mineBlock(bot, closest as Block, options);
                                break;
                            }
                            console.log(
                                `No path to ${closest.name} at ${closest.position}! Skip it!`
                            );
                            // @ts-ignore
                        } else if (err.message === "Digging aborted") {
                            console.log(`Digging aborted! Skip it!`);
                        } else {
                            // @ts-ignore
                            bot.chat(`Error: ${err.message}`);
                        }
                        break;
                    }
                    throw err;
                }
                break;
            }
            case "Entity": {
                // Don't collect any entities that are marked as 'invalid'
                if (!(closest as Entity).isValid) break;
                try {
                    const tempEvents = new TemporarySubscriber(bot);
                    const waitForPickup = new Promise<void>(
                        (resolve, reject) => {
                            const timeout = setTimeout(() => {
                                // After 10 seconds, reject the promise
                                clearTimeout(timeout);
                                tempEvents.cleanup();
                                reject(new Error("Failed to pickup item"));
                            }, 10000);
                            tempEvents.subscribeTo(
                                "entityGone",
                                (entity: Entity) => {
                                    if (entity === closest) {
                                        clearTimeout(timeout);
                                        tempEvents.cleanup();
                                        resolve();
                                    }
                                }
                            );
                        }
                    );
                    bot.pathfinder.setGoal(
                        new goals.GoalFollow(closest as Entity, 0)
                    );
                    // await bot.pathfinder.goto(new goals.GoalBlock(closest.position.x, closest.position.y, closest.position.z))
                    await waitForPickup;
                } catch (err) {
                    // @ts-ignore
                    console.log(err.stack);
                    try {
                        bot.pathfinder.setGoal(null);
                    } catch (err) {}
                    if (options.ignoreNoPath) {
                        // @ts-ignore
                        if (err.message === "Failed to pickup item") {
                            bot.chat(`Failed to pickup item! Skip it!`);
                        }
                        break;
                    }
                    throw err;
                }
                break;
            }
            default: {
                throw error(
                    "UnknownType",
                    `Target ${closest.constructor.name} is not a Block or Entity!`
                );
            }
        }
        options.targets.removeTarget(closest);
    }
    bot.chat(`Collect finish!`);
}

// Options handed to `bot.tool.equipForBlock` whenever a tool is equipped for mining.
const equipToolOptions = {
    requireHarvest: true,
    getFromChest: false,
    maxTools: 2,
};

/**
 * Mines a single block and waits briefly for its item drops to appear.
 *
 * Item drops that spawn within 0.5 blocks of the block's centre while digging
 * are appended to `options.targets` so that they get picked up afterwards.
 *
 * @param bot - The bot doing the digging.
 * @param block - The block to mine.
 * @param options - Resolved collect options; its target list is updated.
 *
 * @throws "Invalid block" if the world no longer holds this block type at the
 *         position (or holds type 0), "Unsafe block" if the pathfinder
 *         movements consider it unsafe to break, and "NoItem" if the held
 *         item cannot harvest it. The block is removed from the target list
 *         in each of these cases.
 */
async function mineBlock(
    bot: Bot,
    block: Block,
    options: CollectOptionsFull
): Promise<void> {
    if (
        bot.blockAt(block.position)?.type !== block.type ||
        bot.blockAt(block.position)?.type === 0
    ) {
        options.targets.removeTarget(block);
        throw error("Invalid block", "Block is not valid!");
        // @ts-expect-error
    } else if (!bot.pathfinder.movements.safeToBreak(block)) {
        options.targets.removeTarget(block);
        throw error("Unsafe block", "Block is not safe to break!");
    }

    await bot.tool.equipForBlock(block, equipToolOptions);

    if (!block.canHarvest(bot.heldItem ? bot.heldItem.type : bot.heldItem)) {
        options.targets.removeTarget(block);
        throw error("NoItem", "Bot does not have a harvestable tool!");
    }

    const tempEvents = new TemporarySubscriber(bot);
    tempEvents.subscribeTo("itemDrop", (entity: Entity) => {
        if (
            entity.position.distanceTo(block.position.offset(0.5, 0.5, 0.5)) <=
            0.5
        ) {
            options.targets.appendTarget(entity);
        }
    });
    try {
        await bot.dig(block);
        // Waiting for items to drop
        await new Promise<void>((resolve) => {
            let remainingTicks = 10;
            tempEvents.subscribeTo("physicTick", () => {
                remainingTicks--;
                if (remainingTicks <= 0) {
                    tempEvents.cleanup();
                    resolve();
                }
            });
        });
    } finally {
        tempEvents.cleanup();
    }
}
+ */ + append?: boolean; + + /** + * If true, errors will not be thrown when a path to the target block cannot + * be found. The bot will attempt to choose the best available position it + * can find, instead. Errors are still thrown if the bot cannot interact with + * the block from it's final location. Defaults to false. + */ + ignoreNoPath?: boolean; + + /** + * Gets the list of chest locations to use when storing items after the bot's + * inventory becomes full. If undefined, it defaults to the chest location + * list on the bot.collectBlock plugin. + */ + chestLocations?: Vec3[]; + + /** + * When transferring items to a chest, this filter is used to determine what + * items are allowed to be moved, and what items aren't allowed to be moved. + * Defaults to the item filter specified on the bot.collectBlock plugin. + */ + itemFilter?: ItemFilter; + + /** + * The total number of items to collect + */ + count?: number; +} + +/** + * A version of collect options where all values are assigned. + */ +interface CollectOptionsFull { + append: boolean; + ignoreNoPath: boolean; + chestLocations: Vec3[]; + itemFilter: ItemFilter; + targets: Targets; + count: number; +} + +/** + * The collect block plugin. + */ +export class CollectBlock { + /** + * The bot. + */ + private readonly bot: Bot; + + /** + * The list of active targets being collected. + */ + private readonly targets: Targets; + + /** + * The movements configuration to be sent to the pathfinder plugin. + */ + movements?: Movements; + + /** + * A list of chest locations which the bot is allowed to empty their inventory into + * if it becomes full while the bot is collecting resources. + */ + chestLocations: Vec3[] = []; + + /** + * When collecting items, this filter is used to determine what items should be placed + * into a chest if the bot's inventory becomes full. By default, returns true for all + * items except for tools, weapons, and armor. 
+ * + * @param item - The item stack in the bot's inventory to check. + * + * @returns True if the item should be moved into the chest. False otherwise. + */ + itemFilter: ItemFilter = (item: Item) => { + if (item.name.includes("helmet")) return false; + if (item.name.includes("chestplate")) return false; + if (item.name.includes("leggings")) return false; + if (item.name.includes("boots")) return false; + if (item.name.includes("shield")) return false; + if (item.name.includes("sword")) return false; + if (item.name.includes("pickaxe")) return false; + if (item.name.includes("axe")) return false; + if (item.name.includes("shovel")) return false; + if (item.name.includes("hoe")) return false; + return true; + }; + + /** + * Creates a new instance of the create block plugin. + * + * @param bot - The bot this plugin is acting on. + */ + constructor(bot: Bot) { + this.bot = bot; + this.targets = new Targets(bot); + // @ts-ignore + this.movements = new Movements(bot, mcDataLoader(bot.version)); + } + + /** + * If target is a block: + * Causes the bot to break and collect the target block. + * + * If target is an item drop: + * Causes the bot to collect the item drop. + * + * If target is an array containing items or blocks, preforms the correct action for + * all targets in that array sorting dynamically by distance. + * + * @param target - The block(s) or item(s) to collect. + * @param options - The set of options to use when handling these targets + * @param cb - The callback that is called finished. + */ + async collect( + target: Collectable | Collectable[], + options: CollectOptions | Callback = {}, + cb?: Callback + ): Promise { + if (typeof options === "function") { + cb = options; + options = {}; + } + // @ts-expect-error + if (cb != null) return callbackify(this.collect)(target, options, cb); + + const optionsFull: CollectOptionsFull = { + append: options.append ?? false, + ignoreNoPath: options.ignoreNoPath ?? 
false, + chestLocations: options.chestLocations ?? this.chestLocations, + itemFilter: options.itemFilter ?? this.itemFilter, + targets: this.targets, + count: options.count ?? Infinity, + }; + + if (this.bot.pathfinder == null) { + throw error( + "UnresolvedDependency", + "The mineflayer-collectblock plugin relies on the mineflayer-pathfinder plugin to run!" + ); + } + + if (this.bot.tool == null) { + throw error( + "UnresolvedDependency", + "The mineflayer-collectblock plugin relies on the mineflayer-tool plugin to run!" + ); + } + + if (this.movements != null) { + this.bot.pathfinder.setMovements(this.movements); + } + + if (!optionsFull.append) await this.cancelTask(); + if (Array.isArray(target)) { + this.targets.appendTargets(target); + } else { + this.targets.appendTarget(target); + } + + try { + await collectAll(this.bot, optionsFull); + this.targets.clear(); + } catch (err) { + this.targets.clear(); + // Ignore path stopped error for cancelTask to work properly (imo we shouldn't throw any pathing errors) + // @ts-expect-error + if (err.name !== "PathStopped") throw err; + } finally { + // @ts-expect-error + this.bot.emit("collectBlock_finished"); + } + } + + /** + * Loads all touching blocks of the same type to the given block and returns them as an array. + * This effectively acts as a flood fill algorithm to retrieve blocks in the same ore vein and similar. + * + * @param block - The starting block. + * @param maxBlocks - The maximum number of blocks to look for before stopping. + * @param maxDistance - The max distance from the starting block to look. + * @param floodRadius - The max distance distance from block A to block B to be considered "touching" + */ + findFromVein( + block: Block, + maxBlocks = 100, + maxDistance = 16, + floodRadius = 1 + ): Block[] { + return findFromVein( + this.bot, + block, + maxBlocks, + maxDistance, + floodRadius + ); + } + + /** + * Cancels the current collection task, if still active. 
import { Bot } from 'mineflayer'
import { Callback } from './CollectBlock'
import { Vec3 } from 'vec3'
import { error } from './Util'
import { Item } from 'prismarine-item'
import { goals } from 'mineflayer-pathfinder'
import { callbackify } from 'util'

/** Decides whether an inventory item may be moved into a chest. */
export type ItemFilter = (item: Item) => boolean

/**
 * Picks the chest location closest to the bot and removes it from the list.
 *
 * Note: `chestLocations` is mutated, so callers pass a copy when the original
 * list must survive.
 *
 * @returns The closest location, or null when the list is empty.
 */
function getClosestChest (bot: Bot, chestLocations: Vec3[]): Vec3 | null {
  let chest = null
  let distance = 0

  for (const c of chestLocations) {
    const dist = c.distanceTo(bot.entity.position)
    if (chest == null || dist < distance) {
      chest = c
      distance = dist
    }
  }

  if (chest != null) {
    chestLocations.splice(chestLocations.indexOf(chest), 1)
  }

  return chest
}

/**
 * Empties the bot's inventory into the given chests, but only when the
 * inventory has no empty slot left.
 *
 * @param bot - The bot whose inventory is checked.
 * @param chestLocations - Candidate chest positions.
 * @param itemFilter - Returns true for items that may be deposited.
 * @param cb - Optional callback; when given, the result is reported through it.
 */
export async function emptyInventoryIfFull (bot: Bot, chestLocations: Vec3[], itemFilter: ItemFilter, cb?: Callback): Promise<void> {
  // itemFilter must be forwarded: callbackify strips only the trailing callback.
  // @ts-expect-error
  if (cb != null) return callbackify(emptyInventoryIfFull)(bot, chestLocations, itemFilter, cb)
  if (bot.inventory.emptySlotCount() > 0) return
  return await emptyInventory(bot, chestLocations, itemFilter)
}

/**
 * Deposits all filter-approved items into the given chests, visiting them
 * from closest to furthest until nothing is left to deposit.
 *
 * @param bot - The bot whose inventory is emptied.
 * @param chestLocations - Candidate chest positions (not modified).
 * @param itemFilter - Returns true for items that may be deposited.
 * @param cb - Optional callback; when given, the result is reported through it.
 *
 * @throws "NoChests" when no chest locations are given or every chest is full.
 */
export async function emptyInventory (bot: Bot, chestLocations: Vec3[], itemFilter: ItemFilter, cb?: Callback): Promise<void> {
  // itemFilter must be forwarded: callbackify strips only the trailing callback.
  // @ts-expect-error
  if (cb != null) return callbackify(emptyInventory)(bot, chestLocations, itemFilter, cb)
  if (chestLocations.length === 0) {
    throw error('NoChests', 'There are no defined chest locations!')
  }

  // Shallow clone so we can safely remove chests from the list that are full.
  chestLocations = [...chestLocations]

  while (true) {
    const chest = getClosestChest(bot, chestLocations)
    if (chest == null) {
      throw error('NoChests', 'All chests are full.')
    }
    const hasRemaining = await tryEmptyInventory(bot, chest, itemFilter)
    if (!hasRemaining) return
  }
}

/**
 * Walks to one chest and deposits as much as fits.
 *
 * @returns True when some approved items did not fit into this chest.
 */
async function tryEmptyInventory (bot: Bot, chestLocation: Vec3, itemFilter: ItemFilter, cb?: (err: Error | undefined, hasRemaining: boolean) => void): Promise<boolean> {
  // @ts-expect-error
  if (cb != null) return callbackify(tryEmptyInventory)(bot, chestLocation, itemFilter, cb)
  await gotoChest(bot, chestLocation)
  return await placeItems(bot, chestLocation, itemFilter)
}

/** Moves the bot next to the chest at `location` using the pathfinder plugin. */
async function gotoChest (bot: Bot, location: Vec3, cb?: Callback): Promise<void> {
  // The callback must be passed along; without it callbackify rejects the
  // call because its last argument is not a function.
  // @ts-expect-error
  if (cb != null) return callbackify(gotoChest)(bot, location, cb)
  await bot.pathfinder.goto(new goals.GoalGetToBlock(location.x, location.y, location.z))
}

/**
 * Opens the chest at `chestPos` and deposits every approved inventory item
 * until the chest has no empty slot left.
 *
 * @returns True when approved items remained that did not fit.
 * @throws "UnloadedChunk" when no block is available at the chest position.
 */
async function placeItems (bot: Bot, chestPos: Vec3, itemFilter: ItemFilter, cb?: (err: Error | undefined, hasRemaining: boolean) => void): Promise<boolean> {
  // @ts-expect-error
  if (cb != null) return callbackify(placeItems)(bot, chestPos, itemFilter, cb)
  const chestBlock = bot.blockAt(chestPos)
  if (chestBlock == null) {
    throw error('UnloadedChunk', 'Chest is in an unloaded chunk!')
  }
  const chest = await bot.openChest(chestBlock)
  try {
    for (const item of bot.inventory.items()) {
      if (!itemFilter(item)) continue
      if (chest.firstEmptyContainerSlot() === null) {
        // We have items that didn't fit.
        return true
      }
      await chest.deposit(item.type, item.metadata, item.count)
    }
    return false
  } finally {
    // Always release the chest window, including on early return or error.
    chest.close()
  }
}
/** An async task: it must invoke `cb` exactly once when it has finished. */
export type Task = (cb: Callback) => void
/** A synchronous task: finishing means returning, failing means throwing. */
export type SyncTask = () => void

/**
 * A simple utility class for queuing up a series of async tasks to execute.
 */
export class TaskQueue {
  private tasks: Task[] = []

  /**
   * If true, the task list will stop executing if one of the tasks throws an error.
   */
  readonly stopOnError: boolean = true

  /**
   * Adds a new async task to this queue. The provided callback should be executed when
   * the async task is complete.
   *
   * @param task - The async task to add.
   */
  add (task: Task): void {
    this.tasks.push(task)
  }

  /**
   * Adds a synchronous task to this queue.
   *
   * @param task - The sync task to add.
   */
  addSync (task: SyncTask): void {
    this.add((cb) => {
      let failure: any
      let failed = false
      try {
        task()
      } catch (err: any) {
        failure = err
        failed = true
      }

      // Report outside the try block: `cb` synchronously drives the rest of
      // the queue, and an exception raised there must not be mistaken for a
      // failure of this task (which would also invoke `cb` a second time).
      if (failed) {
        cb(failure)
      } else {
        cb()
      }
    })
  }

  /**
   * Runs all tasks currently in this queue and empties the queue.
   *
   * Tasks run strictly one after another. When a task reports an error, `cb`
   * is invoked with it; the remaining tasks are then skipped if `stopOnError`
   * is set.
   *
   * @param cb - The optional callback to be executed when all tasks in this queue have
   * finished executing.
   */
  runAll (cb?: Callback): void {
    const taskList = this.tasks
    this.tasks = []

    let index = -1
    const runNext: () => void = () => {
      index++
      if (index >= taskList.length) {
        if (cb != null) cb()
        return
      }

      try {
        taskList[index]((err) => {
          // `!= null` on purpose: Node-style producers (e.g. util.callbackify)
          // signal success with `null`, which must not count as an error.
          if (err != null) {
            if (cb != null) cb(err)

            if (this.stopOnError) return
          }

          runNext()
        })
      } catch (err: any) {
        if (cb != null) cb(err)
      }
    }

    runNext()
  }
}
+ */ + cleanup (): void { + for (const sub of this.subscriptions) { + // @ts-expect-error + this.bot.removeListener(sub.eventName, sub.callback) + } + } +} diff --git a/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/mineflayer-collectblock/src/Util.ts b/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/mineflayer-collectblock/src/Util.ts new file mode 100644 index 0000000000000000000000000000000000000000..ee0f29e0cb1034e1dd96593b73119382050b722b --- /dev/null +++ b/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/mineflayer-collectblock/src/Util.ts @@ -0,0 +1,13 @@ +/** + * Creates a new error object with the given type and message. + * + * @param type - The error type. + * @param message - The error message. + * + * @returns The error object. + */ +export function error (type: string, message: string): Error { + const e = new Error(message) + e.name = type + return e +} diff --git a/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/mineflayer-collectblock/src/index.ts b/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/mineflayer-collectblock/src/index.ts new file mode 100644 index 0000000000000000000000000000000000000000..45c9a85087f56b5bd771477a6fe5b1a02d986b9f --- /dev/null +++ b/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/mineflayer-collectblock/src/index.ts @@ -0,0 +1,25 @@ +import { Bot } from 'mineflayer' +import { CollectBlock } from './CollectBlock' +import { pathfinder as pathfinderPlugin } from 'mineflayer-pathfinder' +import { plugin as toolPlugin } from 'mineflayer-tool' + +export function plugin (bot: Bot): void { + // @ts-expect-error + bot.collectBlock = new CollectBlock(bot) + + // Load plugins if not loaded manually. 
+ setTimeout(() => loadPathfinderPlugin(bot), 0) + setTimeout(() => loadToolPlugin(bot), 0) +} + +function loadPathfinderPlugin (bot: Bot): void { + if (bot.pathfinder != null) return + bot.loadPlugin(pathfinderPlugin) +} + +function loadToolPlugin (bot: Bot): void { + if (bot.tool != null) return + bot.loadPlugin(toolPlugin) +} + +export { CollectBlock, Callback, CollectOptions } from './CollectBlock' diff --git a/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/mineflayer-collectblock/tsconfig.json b/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/mineflayer-collectblock/tsconfig.json new file mode 100644 index 0000000000000000000000000000000000000000..a6076bc0c72a5ed65fd375450a97b2feefb28045 --- /dev/null +++ b/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/mineflayer-collectblock/tsconfig.json @@ -0,0 +1,69 @@ +{ + "compilerOptions": { + /* Visit https://aka.ms/tsconfig.json to read more about this file */ + /* Basic Options */ + // "incremental": true, /* Enable incremental compilation */ + "target": "ES2015", /* Specify ECMAScript target version: 'ES3' (default), 'ES5', 'ES2015', 'ES2016', 'ES2017', 'ES2018', 'ES2019', 'ES2020', or 'ESNEXT'. */ + "module": "commonjs", /* Specify module code generation: 'none', 'commonjs', 'amd', 'system', 'umd', 'es2015', 'es2020', or 'ESNext'. */ + // "lib": [], /* Specify library files to be included in the compilation. */ + "allowJs": true, /* Allow javascript files to be compiled. */ + "checkJs": true, /* Report errors in .js files. */ + // "jsx": "preserve", /* Specify JSX code generation: 'preserve', 'react-native', or 'react'. */ + "declaration": true, + // "declarationMap": true, /* Generates a sourcemap for each corresponding '.d.ts' file. */ + // "sourceMap": true, /* Generates corresponding '.map' file. */ + // "outFile": "./", /* Concatenate and emit output to single file. */ + "outDir": "./lib", + // "rootDir": "./", /* Specify the root directory of input files. 
Use to control the output directory structure with --outDir. */ + // "composite": true, /* Enable project compilation */ + // "tsBuildInfoFile": "./", /* Specify file to store incremental compilation information */ + // "removeComments": true, /* Do not emit comments to output. */ + // "noEmit": true, /* Do not emit outputs. */ + // "importHelpers": true, /* Import emit helpers from 'tslib'. */ + // "downlevelIteration": true, /* Provide full support for iterables in 'for-of', spread, and destructuring when targeting 'ES5' or 'ES3'. */ + // "isolatedModules": true, /* Transpile each file as a separate module (similar to 'ts.transpileModule'). */ + /* Strict Type-Checking Options */ + "strict": true, /* Enable all strict type-checking options. */ + // "noImplicitAny": true, /* Raise error on expressions and declarations with an implied 'any' type. */ + "strictNullChecks": true, /* Enable strict null checks. */ + // "strictFunctionTypes": true, /* Enable strict checking of function types. */ + // "strictBindCallApply": true, /* Enable strict 'bind', 'call', and 'apply' methods on functions. */ + // "strictPropertyInitialization": true, /* Enable strict checking of property initialization in classes. */ + // "noImplicitThis": true, /* Raise error on 'this' expressions with an implied 'any' type. */ + "alwaysStrict": true, /* Parse in strict mode and emit "use strict" for each source file. */ + /* Additional Checks */ + "noUnusedLocals": true, /* Report errors on unused locals. */ + // "noUnusedParameters": true, /* Report errors on unused parameters. */ + "noImplicitReturns": true, /* Report error when not all code paths in function return a value. */ + // "noFallthroughCasesInSwitch": true, /* Report errors for fallthrough cases in switch statement. */ + /* Module Resolution Options */ + // "moduleResolution": "node", /* Specify module resolution strategy: 'node' (Node.js) or 'classic' (TypeScript pre-1.6). 
*/ + // "baseUrl": "./", /* Base directory to resolve non-absolute module names. */ + // "paths": {}, /* A series of entries which re-map imports to lookup locations relative to the 'baseUrl'. */ + // "rootDirs": [], /* List of root folders whose combined content represents the structure of the project at runtime. */ + // "typeRoots": [], /* List of folders to include type definitions from. */ + // "types": [], /* Type declaration files to be included in compilation. */ + // "allowSyntheticDefaultImports": true, /* Allow default imports from modules with no default export. This does not affect code emit, just typechecking. */ + "esModuleInterop": true, /* Enables emit interoperability between CommonJS and ES Modules via creation of namespace objects for all imports. Implies 'allowSyntheticDefaultImports'. */ + // "preserveSymlinks": true, /* Do not resolve the real path of symlinks. */ + // "allowUmdGlobalAccess": true, /* Allow accessing UMD globals from modules. */ + /* Source Map Options */ + // "sourceRoot": "", /* Specify the location where debugger should locate TypeScript files instead of source locations. */ + // "mapRoot": "", /* Specify the location where debugger should locate map files instead of generated locations. */ + // "inlineSourceMap": true, /* Emit a single file with source maps instead of having a separate file. */ + // "inlineSources": true, /* Emit the source alongside the sourcemaps within a single file; requires '--inlineSourceMap' or '--sourceMap' to be set. */ + /* Experimental Options */ + // "experimentalDecorators": true, /* Enables experimental support for ES7 decorators. */ + // "emitDecoratorMetadata": true, /* Enables experimental support for emitting type metadata for decorators. */ + /* Advanced Options */ + "skipLibCheck": true, /* Skip type checking of declaration files. */ + "forceConsistentCasingInFileNames": true /* Disallow inconsistently-cased references to the same file. 
*/ + }, + "include": [ + "src" + ], + "exclude": [ + "node_modules", + "**/__tests__/*" + ] +} \ No newline at end of file diff --git a/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/package.json b/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/package.json new file mode 100644 index 0000000000000000000000000000000000000000..9e389d268c3e7e09d3ad36a9668fbac7ac587397 --- /dev/null +++ b/notebook_dir/metagpt_yusin/environment/minecraft/mineflayer/package.json @@ -0,0 +1,38 @@ +{ + "name": "voyager", + "version": "1.0.0", + "description": "", + "main": "index.js", + "scripts": { + "test": "echo \"Error: no test specified\" && exit 1" + }, + "keywords": [], + "author": "", + "license": "ISC", + "dependencies": { + "body-parser": "^1.20.2", + "express": "^4.18.2", + "magic-string": "^0.30.0", + "minecraft-data": "^3.31.0", + "minecrafthawkeye": "^1.3.6", + "mineflayer": "^4.8.1", + "mineflayer-collectblock": "file:mineflayer-collectblock", + "mineflayer-pathfinder": "^2.4.2", + "mineflayer-pvp": "^1.3.2", + "mineflayer-tool": "^1.2.0", + "mocha": "^10.2.0", + "prismarine-biome": "^1.3.0", + "prismarine-block": "=1.16.3", + "prismarine-entity": "^2.2.0", + "prismarine-item": "^1.12.1", + "prismarine-nbt": "^2.2.1", + "prismarine-recipe": "^1.3.1", + "prismarine-viewer": "^1.24.0", + "typescript": "^4.9.5", + "vec3": "^0.1.8", + "graceful-fs": "^4.2.11" + }, + "devDependencies": { + "prettier": "2.8.5" + } +} diff --git a/notebook_dir/metagpt_yusin/environment/minecraft/process_monitor.py b/notebook_dir/metagpt_yusin/environment/minecraft/process_monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..22d8ddcf2df18370a5d773c6e92d04539cbb9119 --- /dev/null +++ b/notebook_dir/metagpt_yusin/environment/minecraft/process_monitor.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# refs to `voyager process_monitor.py` + +import re +import subprocess +import threading +import warnings +from typing import List + 
class SubprocessMonitor:
    """Launch a subprocess in a background thread and watch its output.

    Every line of the child's combined stdout/stderr is logged. A line that
    matches ``ready_match`` marks the process as ready and releases ``run()``;
    a line that matches ``callback_match`` triggers ``callback``. When the
    output stream ends, ``finished_callback`` is invoked.
    """

    def __init__(
        self,
        commands: List[str],
        name: str,
        ready_match: str = r".*",
        callback_match: str = r"^(?!x)x$",  # regex that will never match
        callback: callable = None,
        finished_callback: callable = None,
    ):
        """Store the configuration; nothing is started until ``run()``.

        Args:
            commands: argv list handed to ``psutil.Popen``.
            name: label used for the logger and in messages.
            ready_match: regex; a matching output line signals readiness.
            callback_match: regex; a matching output line triggers ``callback``.
            callback: zero-argument callable fired on ``callback_match`` lines.
            finished_callback: zero-argument callable fired once the output
                loop has ended.
        """
        self.commands = commands
        self.name = name
        self.logger = define_log_level(name=name)
        self.process = None
        self.ready_match = ready_match
        self.ready_event = None
        self.ready_line = None
        self.callback_match = callback_match
        self.callback = callback
        self.finished_callback = finished_callback
        self.thread = None

    def _handle_line(self, line: str) -> None:
        """Log one output line and fire the ready/callback hooks it matches."""
        self.logger.info(line.strip())
        if re.search(self.ready_match, line):
            self.ready_line = line
            self.logger.info("Subprocess is ready.")
            self.ready_event.set()
        # `callback` defaults to None, so only dispatch when one was supplied.
        if self.callback is not None and re.search(self.callback_match, line):
            self.callback()

    def _start(self):
        """Thread body: spawn the process and pump its output until EOF."""
        self.logger.info(f"Starting subprocess with commands: {self.commands}")

        try:
            self.process = psutil.Popen(
                self.commands,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                universal_newlines=True,
            )
            self.logger.info(f"Subprocess {self.name} started with PID {self.process.pid}.")
            for line in iter(self.process.stdout.readline, ""):
                self._handle_line(line)
        finally:
            # Runs on normal EOF *and* when spawning or a hook raised: run()
            # is blocked on ready_event, so it must always be released.
            if not self.ready_event.is_set():
                self.ready_event.set()
                warnings.warn(f"Subprocess {self.name} failed to start.")
            if self.finished_callback:
                self.finished_callback()

    def run(self):
        """Start the monitor thread and block until the process is ready.

        Also returns when the process ended (or could not be started) before
        any line matched ``ready_match``; a warning is emitted in that case.
        """
        self.ready_event = threading.Event()
        self.ready_line = None
        self.thread = threading.Thread(target=self._start)
        self.thread.start()
        self.ready_event.wait()

    def stop(self):
        """Terminate the subprocess, if it is running, and wait for it to exit."""
        self.logger.info("Stopping subprocess.")
        if self.process and self.process.is_running():
            self.process.terminate()
            self.process.wait()

    @property
    def is_running(self):
        """Whether a subprocess has been started and is still alive."""
        if self.process is None:
            return False
        return self.process.is_running()
b/notebook_dir/metagpt_yusin/environment/software/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bcf8efd09712339308e72659e84450d3fa829fd --- /dev/null +++ b/notebook_dir/metagpt_yusin/environment/software/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/notebook_dir/metagpt_yusin/environment/software/__pycache__/__init__.cpython-39.pyc b/notebook_dir/metagpt_yusin/environment/software/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a062c89b0f78f29434d022a3196c43be5d039d3c Binary files /dev/null and b/notebook_dir/metagpt_yusin/environment/software/__pycache__/__init__.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/environment/software/__pycache__/software_env.cpython-39.pyc b/notebook_dir/metagpt_yusin/environment/software/__pycache__/software_env.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe50c63720f66c158631b78cd2f9108084007a38 Binary files /dev/null and b/notebook_dir/metagpt_yusin/environment/software/__pycache__/software_env.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/environment/software/software_env.py b/notebook_dir/metagpt_yusin/environment/software/software_env.py new file mode 100644 index 0000000000000000000000000000000000000000..15257c4ea848af1ee46dfae825a0a68dcfe75fb7 --- /dev/null +++ b/notebook_dir/metagpt_yusin/environment/software/software_env.py @@ -0,0 +1,12 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : MG Software Env + + +from metagpt_yusin.environment.base_env import Environment + + +class SoftwareEnv(Environment): + """a specific alias name""" + + pass diff --git a/notebook_dir/metagpt_yusin/environment/stanford_town/__init__.py b/notebook_dir/metagpt_yusin/environment/stanford_town/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bcf8efd09712339308e72659e84450d3fa829fd --- 
/dev/null +++ b/notebook_dir/metagpt_yusin/environment/stanford_town/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/notebook_dir/metagpt_yusin/environment/stanford_town/__pycache__/__init__.cpython-39.pyc b/notebook_dir/metagpt_yusin/environment/stanford_town/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a45f587842ea9f6ad7873fc5bb14d59ef6d2bd61 Binary files /dev/null and b/notebook_dir/metagpt_yusin/environment/stanford_town/__pycache__/__init__.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/environment/stanford_town/__pycache__/env_space.cpython-39.pyc b/notebook_dir/metagpt_yusin/environment/stanford_town/__pycache__/env_space.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c7b7718917b942406b3b3b5180fa94aea1b1c73 Binary files /dev/null and b/notebook_dir/metagpt_yusin/environment/stanford_town/__pycache__/env_space.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/environment/stanford_town/__pycache__/stanford_town_env.cpython-39.pyc b/notebook_dir/metagpt_yusin/environment/stanford_town/__pycache__/stanford_town_env.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9708c574467a7295b81ff4aa688c30942ae520c3 Binary files /dev/null and b/notebook_dir/metagpt_yusin/environment/stanford_town/__pycache__/stanford_town_env.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/environment/stanford_town/__pycache__/stanford_town_ext_env.cpython-39.pyc b/notebook_dir/metagpt_yusin/environment/stanford_town/__pycache__/stanford_town_ext_env.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0370fddc02a5c6fd894b8dcd88dd982cee89e255 Binary files /dev/null and b/notebook_dir/metagpt_yusin/environment/stanford_town/__pycache__/stanford_town_ext_env.cpython-39.pyc differ diff --git 
class EnvActionType(BaseEnvActionType):
    """Identifiers for the write operations the StanfordTown env accepts."""

    NONE = 0  # no action to run, just get observation

    ADD_TILE_EVENT = 1  # Add an event triple to a tile
    RM_TILE_EVENT = 2  # Remove an event triple from a tile
    TURN_TILE_EVENT_IDLE = 3  # Turn an event triple from a tile into idle
    RM_TITLE_SUB_EVENT = 4  # Remove an event triple that has the input subject from a tile


class EnvAction(BaseEnvAction):
    """env action type and its related params of action functions/apis"""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    action_type: int = Field(default=EnvActionType.NONE, description="action type")
    coord: npt.NDArray[np.int64] = Field(
        default_factory=lambda: np.zeros(2, dtype=np.int64), description="tile coordinate"
    )
    subject: str = Field(default="", description="subject name of first element in event")
    # The default is a tuple, matching the annotation: tile events are kept in
    # sets, so the value has to be hashable (a list default would not be).
    event: tuple[str, Optional[str], Optional[str], Optional[str]] = Field(
        default=("", None, None, None), description="tile event"
    )

    @field_validator("coord", mode="before")
    @classmethod
    def check_coord(cls, coord) -> npt.NDArray[np.int64]:
        """Coerce list/tuple coordinates to an ndarray; pass ndarrays through."""
        if not isinstance(coord, np.ndarray):
            return np.array(coord)
        # A "before" validator must return the value in every branch,
        # otherwise an ndarray input would be replaced by None.
        return coord


class EnvObsType(BaseEnvObsType):
    """get part observation with specific params"""

    NONE = 0  # get whole observation from env

    GET_TITLE = 1  # get the tile detail dictionary with given tile coord
    TILE_PATH = 2  # get the tile address with given tile coord
    TILE_NBR = 3  # get the neighbors of given tile coord and its vision radius


class EnvObsParams(BaseEnvObsParams):
    """observation params for different EnvObsType"""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    obs_type: int = Field(default=EnvObsType.NONE, description="observation type")
    coord: npt.NDArray[np.int64] = Field(
        default_factory=lambda: np.zeros(2, dtype=np.int64), description="tile coordinate"
    )
    level: str = Field(default="", description="different level of title")
    vision_radius: int = Field(default=0, description="the vision radius of current tile")

    @field_validator("coord", mode="before")
    @classmethod
    def check_coord(cls, coord) -> npt.NDArray[np.int64]:
        """Coerce list/tuple coordinates to an ndarray; pass ndarrays through."""
        if not isinstance(coord, np.ndarray):
            return np.array(coord)
        # See EnvAction.check_coord: falling off the end would yield None.
        return coord


# Union of the value shapes that can appear in an observation dict.
EnvObsValType = Union[list[list[str]], dict[str, set[tuple[int, int]]], list[list[dict[str, Any]]]]


def get_observation_space() -> spaces.Dict:
    """Return the observation space with one entry per observation key.

    The keys mirror the full observation (``collision_maze``, ``tiles``,
    ``address_tiles``); each is declared as ``Discrete(2)``.
    """
    # NOTE(review): Discrete(2) looks like a placeholder rather than a faithful
    # description of the nested observation values - confirm.
    space = spaces.Dict(
        {"collision_maze": spaces.Discrete(2), "tiles": spaces.Discrete(2), "address_tiles": spaces.Discrete(2)}
    )

    return space


def get_action_space(maze_shape: tuple[int, int]) -> spaces.Dict:
    """The fields defined by the space correspond to the input parameters of the action except `action_type`

    Args:
        maze_shape: two-element size of the maze; used as the upper bound of
            the ``coord`` box.
    """
    space = spaces.Dict(
        {
            "action_type": spaces.Discrete(len(EnvActionType)),
            "coord": spaces.Box(
                np.array([0, 0], dtype=np.int64), np.array([maze_shape[0], maze_shape[1]], dtype=np.int64)
            ),  # coord of the tile
            "subject": spaces.Text(256),  # the first element of an tile event
            "event": spaces.Tuple(
                (spaces.Text(256), spaces.Text(256), spaces.Text(256), spaces.Text(256))
            ),  # event is a tuple of four str
        }
    )
    return space
b/notebook_dir/metagpt_yusin/environment/stanford_town/stanford_town_env.py new file mode 100644 index 0000000000000000000000000000000000000000..13ff68bb64bd4d5ab1dbba33c4df266419ef2f27 --- /dev/null +++ b/notebook_dir/metagpt_yusin/environment/stanford_town/stanford_town_env.py @@ -0,0 +1,10 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : MG StanfordTown Env + +from metagpt_yusin.environment.base_env import Environment +from metagpt_yusin.environment.stanford_town.stanford_town_ext_env import StanfordTownExtEnv + + +class StanfordTownEnv(StanfordTownExtEnv, Environment): + pass diff --git a/notebook_dir/metagpt_yusin/environment/stanford_town/stanford_town_ext_env.py b/notebook_dir/metagpt_yusin/environment/stanford_town/stanford_town_ext_env.py new file mode 100644 index 0000000000000000000000000000000000000000..58d1d3200c65bb33997feb7ddafc4443003a3465 --- /dev/null +++ b/notebook_dir/metagpt_yusin/environment/stanford_town/stanford_town_ext_env.py @@ -0,0 +1,451 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : The StanfordTown external environment to interate with the web interface +# refs to `generative_agents maze.py` + +import math +from pathlib import Path +from typing import Any, Optional + +from pydantic import ConfigDict, Field, model_validator + +from metagpt_yusin.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable +from metagpt_yusin.environment.stanford_town.env_space import ( + EnvAction, + EnvActionType, + EnvObsParams, + EnvObsType, + EnvObsValType, + get_action_space, + get_observation_space, +) +from metagpt_yusin.utils.common import read_csv_to_list, read_json_file + + +class StanfordTownExtEnv(ExtEnv): + model_config = ConfigDict(arbitrary_types_allowed=True) + + maze_asset_path: Optional[Path] = Field(default=None, description="the path to store maze assets") + maze_width: int = Field(default=140, description="maze map width") + maze_height: int = Field(default=100, description="maze map height") 
+ sq_tile_size: int = Field(default=32, description="the pixel height/width of a tile") + special_constraint: str = Field( + default="", description="a string description of any relevant special constraints " "the world might have" + ) + tiles: list[list[dict]] = Field(default=[]) + address_tiles: dict[str, set] = Field(default=dict()) + collision_maze: list[list] = Field(default=[]) + + @model_validator(mode="before") + @classmethod + def _init_maze(cls, values): + maze_asset_path = values["maze_asset_path"] + assert maze_asset_path + maze_asset_path = Path(maze_asset_path) + + maze_matrix_path = maze_asset_path.joinpath("matrix") + meta_info = read_json_file(maze_matrix_path.joinpath("maze_meta_info.json")) + + maze_width = int(meta_info["maze_width"]) + maze_height = int(meta_info["maze_height"]) + values["maze_width"] = maze_width + values["maze_height"] = maze_height + values["sq_tile_size"] = int(meta_info["sq_tile_size"]) + values["special_constraint"] = meta_info["special_constraint"] + + # READING IN SPECIAL BLOCKS + # Special blocks are those that are colored in the Tiled map. 
+ # Here is an example row for the arena block file: + # e.g, "25331, Double Studio, Studio, Bedroom 2, Painting" + + blocks_folder = maze_matrix_path.joinpath("special_blocks") + + _wb = blocks_folder.joinpath("world_blocks.csv") + wb_rows = read_csv_to_list(_wb, header=False) + wb = wb_rows[0][-1] + + _sb = blocks_folder.joinpath("sector_blocks.csv") + sb_rows = read_csv_to_list(_sb, header=False) + sb_dict = dict() + for i in sb_rows: + sb_dict[i[0]] = i[-1] + + _ab = blocks_folder.joinpath("arena_blocks.csv") + ab_rows = read_csv_to_list(_ab, header=False) + ab_dict = dict() + for i in ab_rows: + ab_dict[i[0]] = i[-1] + + _gob = blocks_folder.joinpath("game_object_blocks.csv") + gob_rows = read_csv_to_list(_gob, header=False) + gob_dict = dict() + for i in gob_rows: + gob_dict[i[0]] = i[-1] + + _slb = blocks_folder.joinpath("spawning_location_blocks.csv") + slb_rows = read_csv_to_list(_slb, header=False) + slb_dict = dict() + for i in slb_rows: + slb_dict[i[0]] = i[-1] + + # [SECTION 3] Reading in the matrices + # This is your typical two dimensional matrices. It's made up of 0s and + # the number that represents the color block from the blocks folder. + maze_folder = maze_matrix_path.joinpath("maze") + + _cm = maze_folder.joinpath("collision_maze.csv") + collision_maze_raw = read_csv_to_list(_cm, header=False)[0] + _sm = maze_folder.joinpath("sector_maze.csv") + sector_maze_raw = read_csv_to_list(_sm, header=False)[0] + _am = maze_folder.joinpath("arena_maze.csv") + arena_maze_raw = read_csv_to_list(_am, header=False)[0] + _gom = maze_folder.joinpath("game_object_maze.csv") + game_object_maze_raw = read_csv_to_list(_gom, header=False)[0] + _slm = maze_folder.joinpath("spawning_location_maze.csv") + spawning_location_maze_raw = read_csv_to_list(_slm, header=False)[0] + + # Loading the maze. The mazes are taken directly from the json exports of + # Tiled maps. They should be in csv format. 
+ # Importantly, they are "not" in a 2-d matrix format -- they are single + # row matrices with the length of width x height of the maze. So we need + # to convert here. + # example format: [['0', '0', ... '25309', '0',...], ['0',...]...] + # 25309 is the collision bar number right now. + collision_maze = [] + sector_maze = [] + arena_maze = [] + game_object_maze = [] + spawning_location_maze = [] + for i in range(0, len(collision_maze_raw), maze_width): + tw = maze_width + collision_maze += [collision_maze_raw[i : i + tw]] + sector_maze += [sector_maze_raw[i : i + tw]] + arena_maze += [arena_maze_raw[i : i + tw]] + game_object_maze += [game_object_maze_raw[i : i + tw]] + spawning_location_maze += [spawning_location_maze_raw[i : i + tw]] + values["collision_maze"] = collision_maze + + tiles = [] + for i in range(maze_height): + row = [] + for j in range(maze_width): + tile_details = dict() + tile_details["world"] = wb + + tile_details["sector"] = "" + if sector_maze[i][j] in sb_dict: + tile_details["sector"] = sb_dict[sector_maze[i][j]] + + tile_details["arena"] = "" + if arena_maze[i][j] in ab_dict: + tile_details["arena"] = ab_dict[arena_maze[i][j]] + + tile_details["game_object"] = "" + if game_object_maze[i][j] in gob_dict: + tile_details["game_object"] = gob_dict[game_object_maze[i][j]] + + tile_details["spawning_location"] = "" + if spawning_location_maze[i][j] in slb_dict: + tile_details["spawning_location"] = slb_dict[spawning_location_maze[i][j]] + + tile_details["collision"] = False + if collision_maze[i][j] != "0": + tile_details["collision"] = True + + tile_details["events"] = set() + + row += [tile_details] + tiles += [row] + values["tiles"] = tiles + + # Each game object occupies an event in the tile. We are setting up the + # default event value here. 
def reset(
    self,
    *,
    seed: Optional[int] = None,
    options: Optional[dict[str, Any]] = None,
) -> tuple[dict[str, EnvObsValType], dict[str, Any]]:
    """Reset the environment and return the initial observation.

    Args:
        seed: Forwarded unchanged to the parent class ``reset``.
        options: Forwarded unchanged to the parent class ``reset``.

    Returns:
        An ``(observation, info)`` pair; ``info`` is always an empty dict.
    """
    super().reset(seed=seed, options=options)
    return self._get_obs(), {}

def _get_obs(self) -> dict[str, EnvObsValType]:
    """Return the full observation: collision maze, tiles and address index."""
    return {
        "collision_maze": self.get_collision_maze(),
        "tiles": self.tiles,
        "address_tiles": self.get_address_tiles(),
    }

def observe(self, obs_params: Optional[EnvObsParams] = None) -> Any:
    """Return a partial or full observation of the environment.

    Args:
        obs_params: Selects what to observe. When it is falsy the full
            observation (same as ``_get_obs``) is returned.

    Returns:
        The full observation dict, a single tile dict, a tile address string
        or a list of nearby tile coordinates, depending on
        ``obs_params.obs_type``.

    Raises:
        ValueError: If ``obs_params.obs_type`` is not one of the handled
            observation types. (Previously this surfaced as an
            ``UnboundLocalError`` on the ``return`` statement.)
    """
    obs_type = obs_params.obs_type if obs_params else EnvObsType.NONE
    if obs_type == EnvObsType.NONE:
        return self._get_obs()
    if obs_type == EnvObsType.GET_TITLE:
        return self.access_tile(tile=obs_params.coord)
    if obs_type == EnvObsType.TILE_PATH:
        return self.get_tile_path(tile=obs_params.coord, level=obs_params.level)
    if obs_type == EnvObsType.TILE_NBR:
        return self.get_nearby_tiles(tile=obs_params.coord, vision_r=obs_params.vision_radius)
    raise ValueError(f"Unsupported observation type: {obs_type!r}")

def step(self, action: EnvAction) -> tuple[dict[str, EnvObsValType], float, bool, bool, dict[str, Any]]:
    """Execute ``action`` and return the resulting observation.

    Returns:
        ``(observation, reward, terminated, truncated, info)``. The reward is
        the constant ``1.0``, ``truncated`` is always ``False`` and ``info``
        is always empty.
    """
    terminated = False
    try:
        self._execute_env_action(action)
    except Exception:
        # Best effort by design: a failing action is not propagated, it only
        # marks the episode as terminated.
        terminated = True

    return self._get_obs(), 1.0, terminated, False, {}

def _execute_env_action(self, action: EnvAction):
    """Dispatch ``action`` to the tile-event method matching its type.

    ``EnvActionType.NONE`` and any action type not listed below are ignored.
    """
    kind = action.action_type
    if kind == EnvActionType.NONE:
        return
    if kind == EnvActionType.ADD_TILE_EVENT:
        self.add_event_from_tile(curr_event=action.event, tile=action.coord)
    elif kind == EnvActionType.RM_TILE_EVENT:
        self.remove_event_from_tile(curr_event=action.event, tile=action.coord)
    elif kind == EnvActionType.TURN_TILE_EVENT_IDLE:
        self.turn_event_from_tile_idle(curr_event=action.event, tile=action.coord)
    elif kind == EnvActionType.RM_TITLE_SUB_EVENT:
        self.remove_subject_events_from_tile(subject=action.subject, tile=action.coord)
def turn_coordinate_to_tile(self, px_coordinate: tuple[int, int]) -> tuple[int, int]:
    """Convert a pixel coordinate into a tile coordinate.

    Each axis is divided by ``self.sq_tile_size`` and rounded up.
    """
    tile_size = self.sq_tile_size
    return math.ceil(px_coordinate[0] / tile_size), math.ceil(px_coordinate[1] / tile_size)

@mark_as_readable
def get_collision_maze(self) -> list:
    """Return the stored collision maze."""
    return self.collision_maze

@mark_as_readable
def get_address_tiles(self) -> dict:
    """Return the mapping from string address to the set of tile coordinates."""
    return self.address_tiles

@mark_as_readable
def access_tile(self, tile: tuple[int, int]) -> dict:
    """Return the detail dictionary stored for one tile.

    Args:
        tile: Tile coordinate in ``(x, y)`` form. ``self.tiles`` is indexed
            row first, i.e. ``self.tiles[y][x]``.

    Returns:
        The tile detail dictionary, for example for ``(58, 9)``::

            self.tiles[9][58] = {'world': 'double studio',
                'sector': 'double studio', 'arena': 'bedroom 2',
                'game_object': 'bed', 'spawning_location': 'bedroom-2-a',
                'collision': False,
                'events': {('double studio:double studio:bedroom 2:bed',
                            None, None)}}
    """
    col, row = tile[0], tile[1]
    return self.tiles[row][col]

@mark_as_readable
def get_tile_path(self, tile: tuple[int, int], level: str) -> str:
    """Return the colon-separated string address of a tile.

    Args:
        tile: Tile coordinate in ``(x, y)`` form.
        level: How deep the address goes: ``"world"``, ``"sector"`` or
            ``"arena"``. Any other value yields the full address down to the
            game object.

    Returns:
        The address string, e.g. for ``tile=(58, 9)`` and ``level="arena"``:
        ``"double studio:double studio:bedroom 2"``.
    """
    col, row = tile[0], tile[1]
    details = self.tiles[row][col]

    parts = []
    for depth in ("world", "sector", "arena"):
        parts.append(f"{details[depth]}")
        if level == depth:
            return ":".join(parts)

    parts.append(f"{details['game_object']}")
    return ":".join(parts)
@mark_as_readable
def get_nearby_tiles(self, tile: tuple[int, int], vision_r: int) -> list[tuple[int, int]]:
    """Return the tiles inside the square of "radius" ``vision_r`` around ``tile``.

    The neighbourhood is a square, not a circle, and is clipped to the maze.
    For ``vision_r = 2`` and persona ``P`` the x's are returned::

        x x x x x
        x x x x x
        x x P x x
        x x x x x
        x x x x x

    Args:
        tile: Tile coordinate of interest in ``(x, y)`` form.
        vision_r: Radius of the persona's vision, in tiles.

    Returns:
        A list of ``(x, y)`` coordinates, ordered by x and then by y.
    """
    x, y = tile[0], tile[1]

    left_end = max(x - vision_r, 0)
    top_end = max(y - vision_r, 0)
    # right_end / bottom_end are exclusive range bounds, so they must be
    # clipped to the maze size itself. Clipping to ``size - 1`` (as before)
    # silently dropped the last column/row, including the persona's own tile
    # when it stands on the border.
    right_end = min(x + vision_r + 1, self.maze_width)
    bottom_end = min(y + vision_r + 1, self.maze_height)

    return [(i, j) for i in range(left_end, right_end) for j in range(top_end, bottom_end)]

@mark_as_writeable
def add_event_from_tile(self, curr_event: tuple[str], tile: tuple[int, int]) -> None:
    """Add an event to a tile.

    Args:
        curr_event: The event tuple to add, e.g.
            ``('double studio:double studio:bedroom 2:bed', None, None)``.
        tile: Tile coordinate of interest in ``(x, y)`` form.
    """
    self.tiles[tile[1]][tile[0]]["events"].add(curr_event)
@mark_as_writeable
def remove_event_from_tile(self, curr_event: tuple[str], tile: tuple[int, int]) -> None:
    """Remove an event from a tile.

    Args:
        curr_event: The event tuple to remove, e.g.
            ``('double studio:double studio:bedroom 2:bed', None, None)``.
        tile: Tile coordinate of interest in ``(x, y)`` form.
    """
    events = self.tiles[tile[1]][tile[0]]["events"]
    # Collect the matches first so the collection is not mutated while it is
    # being iterated.
    for matched in [ev for ev in events if ev == curr_event]:
        events.remove(matched)

@mark_as_writeable
def turn_event_from_tile_idle(self, curr_event: tuple[str], tile: tuple[int, int]) -> None:
    """Replace an event on a tile with its idle form.

    The idle form keeps the first element of the event and sets the other
    three fields to ``None``.

    Args:
        curr_event: The event tuple to turn idle.
        tile: Tile coordinate of interest in ``(x, y)`` form.
    """
    events = self.tiles[tile[1]][tile[0]]["events"]
    for matched in [ev for ev in events if ev == curr_event]:
        events.remove(matched)
        events.add((matched[0], None, None, None))

@mark_as_writeable
def remove_subject_events_from_tile(self, subject: str, tile: tuple[int, int]) -> None:
    """Remove every event on a tile whose first element equals ``subject``.

    Args:
        subject: The subject to match, e.g. ``"Isabella Rodriguez"``.
        tile: Tile coordinate of interest in ``(x, y)`` form.
    """
    events = self.tiles[tile[1]][tile[0]]["events"]
    for matched in [ev for ev in events if ev[0] == subject]:
        events.remove(matched)
+ OUPUT: + None + """ + curr_tile_ev_cp = self.tiles[tile[1]][tile[0]]["events"].copy() + for event in curr_tile_ev_cp: + if event[0] == subject: + self.tiles[tile[1]][tile[0]]["events"].remove(event) diff --git a/notebook_dir/metagpt_yusin/environment/werewolf/__init__.py b/notebook_dir/metagpt_yusin/environment/werewolf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bcf8efd09712339308e72659e84450d3fa829fd --- /dev/null +++ b/notebook_dir/metagpt_yusin/environment/werewolf/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/notebook_dir/metagpt_yusin/environment/werewolf/__pycache__/__init__.cpython-39.pyc b/notebook_dir/metagpt_yusin/environment/werewolf/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ba43d4889db479190f2dce497b0d9002186af8c Binary files /dev/null and b/notebook_dir/metagpt_yusin/environment/werewolf/__pycache__/__init__.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/environment/werewolf/__pycache__/werewolf_env.cpython-39.pyc b/notebook_dir/metagpt_yusin/environment/werewolf/__pycache__/werewolf_env.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb1c3cffeed2a3937444f7bed47124a00934a26b Binary files /dev/null and b/notebook_dir/metagpt_yusin/environment/werewolf/__pycache__/werewolf_env.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/environment/werewolf/__pycache__/werewolf_ext_env.cpython-39.pyc b/notebook_dir/metagpt_yusin/environment/werewolf/__pycache__/werewolf_ext_env.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3cdfe58fd8597352c11051a68cb817c3eec61ca Binary files /dev/null and b/notebook_dir/metagpt_yusin/environment/werewolf/__pycache__/werewolf_ext_env.cpython-39.pyc differ diff --git a/notebook_dir/metagpt_yusin/environment/werewolf/werewolf_env.py 
class WerewolfEnv(Environment, WerewolfExtEnv):
    """Werewolf game environment: message bus plus the external game state."""

    # Counter incremented once per round in run(); used as a message prefix.
    timestamp: int = Field(default=0)

    def publish_message(self, message: Message, add_timestamp: bool = True):
        """Post a message to the environment memory and history.

        Args:
            message: The message to publish. Its ``content`` is modified in
                place when ``add_timestamp`` is true.
            add_timestamp: Prefix the content with the current timestamp.
        """
        logger.debug(f"publish_message: {message.dump()}")
        if add_timestamp:
            # Message contents can repeat (e.g. the same player is targeted on
            # two different nights). The timestamp prefix keeps such messages
            # distinct so they are not deduplicated when added to memory.
            prefix = f"{self.timestamp} | "
            message.content = prefix + message.content
        self.memory.add(message)
        self.history += f"\n{message}"

    async def run(self, k=1):
        """Run every role in order, ``k`` times, advancing the timestamp per pass."""
        for _ in range(k):
            for participant in self.roles.values():
                await participant.run()
            self.timestamp += 1
class RoleState(Enum):
    """Life-cycle state of a player in the werewolf game."""

    ALIVE = "alive"  # the role is alive
    KILLED = "killed"  # the role is killed by werewolf or voting
    POISONED = "poisoned"  # the role is killed by poison
    SAVED = "saved"  # the role is saved by antidote
For example: Protect ...""", + "send_to": "Guard", + "restricted_to": "Moderator,Guard", + }, + 3: {"content": "Guard, close your eyes", "send_to": "Moderator", "restricted_to": ""}, + 4: {"content": "Werewolves, please open your eyes!", "send_to": "Moderator", "restricted_to": ""}, + 5: { + "content": """Werewolves, I secretly tell you that {werewolf_players} are + all of the 2 werewolves! Keep in mind you are teammates. The rest players are not werewolves. + choose one from the following living options please: + {living_players}. For example: Kill ...""", + "send_to": "Werewolf", + "restricted_to": "Moderator,Werewolf", + }, + 6: {"content": "Werewolves, close your eyes", "send_to": "Moderator", "restricted_to": ""}, + 7: {"content": "Witch, please open your eyes!", "send_to": "Moderator", "restricted_to": ""}, + 8: { + "content": """Witch, tonight {player_hunted} has been killed by the werewolves. + You have a bottle of antidote, would you like to save him/her? If so, say "Save", else, say "Pass".""", + "send_to": "Witch", + "restricted_to": "Moderator,Witch", + }, # 要先判断女巫是否有解药,再去询问女巫是否使用解药救人 + 9: { + "content": """Witch, you also have a bottle of poison, would you like to use it to kill one of the living players? + Choose one from the following living options: {living_players}. + If so, say ONLY "Poison PlayerX", replace PlayerX with the actual player name, else, say "Pass".""", + "send_to": "Witch", + "restricted_to": "Moderator,Witch", + }, # + 10: {"content": "Witch, close your eyes", "send_to": "Moderator", "restricted_to": ""}, + 11: {"content": "Seer, please open your eyes!", "send_to": "Moderator", "restricted_to": ""}, + 12: { + "content": """Seer, you can check one player's identity. Who are you going to verify its identity tonight? 
+ Choose only one from the following living options:{living_players}.""", + "send_to": "Seer", + "restricted_to": "Moderator,Seer", + }, + 13: {"content": "Seer, close your eyes", "send_to": "Moderator", "restricted_to": ""}, + # The 1-st daytime + 14: { + "content": """It's daytime. Everyone woke up except those who had been killed.""", + "send_to": "Moderator", + "restricted_to": "", + }, + 15: {"content": "{player_current_dead} was killed last night!", "send_to": "Moderator", "restricted_to": ""}, + 16: { + "content": """Living players: {living_players}, now freely talk about the current situation based on your observation and + reflection with a few sentences. Decide whether to reveal your identity based on your reflection.""", + "send_to": "", # send to all to speak in daytime + "restricted_to": "", + }, + 17: { + "content": """Now vote and tell me who you think is the werewolf. Don’t mention your role. + You only choose one from the following living options please: + {living_players}. 
class WerewolfExtEnv(ExtEnv):
    """Shared state and rule bookkeeping for one werewolf game.

    Tracks every player's role type and ``RoleState``, the position inside
    ``STEP_INSTRUCTIONS``, and the per-round results of hunting, voting and
    the witch's potions.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    players_state: dict[str, tuple[str, RoleState]] = Field(
        default_factory=dict, description="the player's role type and state by player_name"
    )

    round_idx: int = Field(default=0)  # the current round
    step_idx: int = Field(default=0)  # the current step of current round
    # Absolute step indices already processed by update_game_states(). It is
    # used with ``in`` and ``.append``, so it must be a list (it used to be
    # declared as ``int``, which made update_game_states() raise TypeError).
    eval_step_idx: list[int] = Field(default_factory=list)
    per_round_steps: int = Field(default=len(STEP_INSTRUCTIONS))

    # game global states
    game_setup: str = Field(default="", description="game setup including role and its num")
    special_role_players: list[str] = Field(default_factory=list)
    winner: Optional[str] = Field(default=None)
    win_reason: Optional[str] = Field(default=None)
    witch_poison_left: int = Field(default=1)
    witch_antidote_left: int = Field(default=1)

    # game current round states, a round is from closing your eyes to the next time you close your eyes
    round_hunts: dict[str, str] = Field(default_factory=dict, description="nighttime wolf hunt result")
    round_votes: dict[str, str] = Field(
        default_factory=dict, description="daytime all players vote result, key=voteer, value=voted one"
    )
    player_hunted: Optional[str] = Field(default=None)
    player_protected: Optional[str] = Field(default=None)
    is_hunted_player_saved: bool = Field(default=False)
    player_poisoned: Optional[str] = Field(default=None)
    player_current_dead: list[str] = Field(default_factory=list)

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[dict[str, Any]] = None,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Not implemented; currently a no-op that returns ``None``."""
        pass

    def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any:
        """Not implemented; currently a no-op that returns ``None``."""
        pass

    def step(self, action: BaseEnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
        """Not implemented; currently a no-op that returns ``None``."""
        pass

    @property
    def living_players(self) -> list[str]:
        """Names of players whose state is ALIVE or SAVED, in registration order."""
        alive_states = (RoleState.ALIVE, RoleState.SAVED)
        return [name for name, (_, state) in self.players_state.items() if state in alive_states]

    def _role_type_players(self, role_type: str) -> list[str]:
        """Return the names of players whose role type contains ``role_type``."""
        return [name for name, (profile, _) in self.players_state.items() if role_type in profile]

    @property
    def werewolf_players(self) -> list[str]:
        """Names of all werewolf players, dead or alive."""
        return self._role_type_players(role_type="Werewolf")

    @property
    def villager_players(self) -> list[str]:
        """Names of all villager players, dead or alive."""
        return self._role_type_players(role_type="Villager")

    def _init_players_state(self, players: list["Role"]):
        """Register every player as ALIVE and derive the special-role players."""
        for play in players:
            self.players_state[play.name] = (play.profile, RoleState.ALIVE)

        ordinary = set(self.werewolf_players) | set(self.villager_players)
        self.special_role_players = [p for p in self.living_players if p not in ordinary]

    def init_game_setup(
        self,
        role_uniq_objs: list[object],
        num_villager: int = 2,
        num_werewolf: int = 2,
        shuffle=True,
        add_human=False,
        use_reflection=True,
        use_experience=False,
        use_memory_selection=False,
        new_experience_version="",
        prepare_human_player=Callable,
    ) -> tuple[str, list]:
        """Create the players for a game and initialise their state.

        Args:
            role_uniq_objs: One entry per distinct role; each entry is called
                like a constructor to create a player.
            num_villager: How many times the villager role is instantiated.
            num_werewolf: How many times the werewolf role is instantiated.
            shuffle: Randomise the seating order of the roles.
            add_human: Replace one randomly chosen role with the result of
                ``prepare_human_player(role)``.
            use_reflection, use_experience, use_memory_selection,
            new_experience_version: Forwarded to every player constructor.
            prepare_human_player: Callable used when ``add_human`` is true.
                NOTE(review): the default is the ``typing.Callable`` object
                itself, which cannot be called; callers that set
                ``add_human=True`` have to pass a real callable.

        Returns:
            ``(game_setup, players)``: the textual setup description and the
            list of created players.
        """
        role_objs = []
        for role_obj in role_uniq_objs:
            # NOTE(review): this is an exact match on str(role_obj) while the
            # other role checks in this class use substring matching; confirm
            # what str() of a role entry looks like.
            if str(role_obj) == "Villager":
                role_objs.extend([role_obj] * num_villager)
            elif str(role_obj) == "Werewolf":
                role_objs.extend([role_obj] * num_werewolf)
            else:
                role_objs.append(role_obj)
        if shuffle:
            # Shuffle the list in place; passing len(role_objs) raised
            # TypeError because an int is not a mutable sequence.
            random.shuffle(role_objs)
        if add_human:
            assigned_role_idx = random.randint(0, len(role_objs) - 1)
            assigned_role = role_objs[assigned_role_idx]
            role_objs[assigned_role_idx] = prepare_human_player(assigned_role)  # TODO

        players = [
            role(
                name=f"Player{i + 1}",
                use_reflection=use_reflection,
                use_experience=use_experience,
                use_memory_selection=use_memory_selection,
                new_experience_version=new_experience_version,
            )
            for i, role in enumerate(role_objs)
        ]

        if add_human:
            logger.info(f"You are assigned {players[assigned_role_idx].name}({players[assigned_role_idx].profile})")

        game_setup = ["Game setup:"] + [f"{player.name}: {player.profile}," for player in players]
        self.game_setup = "\n".join(game_setup)

        self._init_players_state(players)  # init players state

        return self.game_setup, players

    def _update_players_state(self, player_names: list[str], state: RoleState = RoleState.KILLED):
        """Set ``state`` for every known player in ``player_names``; unknown names are ignored."""
        for player_name in player_names:
            if player_name in self.players_state:
                role_type = self.players_state[player_name][0]
                self.players_state[player_name] = (role_type, state)

    def _check_valid_role(self, player: "Role", role_type: str) -> bool:
        """Return whether ``role_type`` occurs in ``str(player)``."""
        return role_type in str(player)

    def _check_player_continue(self, player_name: str, particular_step: int = -1) -> bool:
        """Return whether ``player_name`` may act now.

        Args:
            player_name: The player to check; must be among the living players.
            particular_step: When positive, the action is only allowed while
                the current step within the round equals this value.
        """
        step_idx = self.step_idx % self.per_round_steps
        if particular_step > 0 and step_idx != particular_step:  # step no
            # particular_step = 18, not daytime vote time, ignore
            # particular_step = 15, not nighttime hunt time, ignore
            return False
        return player_name in self.living_players

    @mark_as_readable
    def curr_step_instruction(self) -> dict:
        """Return the instruction for the current step and advance ``step_idx``."""
        step_idx = self.step_idx % len(STEP_INSTRUCTIONS)
        instruction = STEP_INSTRUCTIONS[step_idx]
        self.step_idx += 1
        return instruction

    @mark_as_readable
    def get_players_state(self, player_names: list[str]) -> dict[str, RoleState]:
        """Return the ``RoleState`` of each known player in ``player_names``."""
        return {
            player_name: self.players_state[player_name][1]  # only return role state
            for player_name in player_names
            if player_name in self.players_state
        }

    @mark_as_writeable
    def vote_kill_someone(self, voteer: "Role", player_name: str = None):
        """Record a daytime vote and resolve the vote once everybody has voted.

        Args:
            voteer: The voting player.
            player_name: The player voted against; ``None`` means abstaining.
        """
        if not self._check_player_continue(voteer.name, particular_step=18):  # 18=step no
            return

        self.round_votes[voteer.name] = player_name
        # check if all living players finish voting, then get the dead one
        # NOTE(review): this comparison is order-sensitive; it only matches
        # when the votes arrive in the same order as living_players.
        if list(self.round_votes.keys()) == self.living_players:
            # TODO in case of tie vote, check who was voted first
            valid_votes = [target for target in self.round_votes.values() if target]
            # player_current_dead is a list[str]; keep it a list (it used to
            # be overwritten with a bare str) and tolerate the case where
            # every player abstained (previously an IndexError).
            self.player_current_dead = [Counter(valid_votes).most_common(1)[0][0]] if valid_votes else []
            self._update_players_state(self.player_current_dead)

    @mark_as_writeable
    def wolf_kill_someone(self, wolf: "Role", player_name: str):
        """Record a werewolf's hunt choice and resolve it once all living wolves chose."""
        if not self._check_valid_role(wolf, "Werewolf"):
            return
        if not self._check_player_continue(wolf.name, particular_step=5):  # 5=step no
            return

        self.round_hunts[wolf.name] = player_name
        living_werewolf = [p for p in self.werewolf_players if p in self.living_players]
        # check if all living wolfs finish hunting, then get the hunted one
        if list(self.round_hunts.keys()) == living_werewolf:
            hunted_all = list(self.round_hunts.values())
            self.player_hunted = Counter(hunted_all).most_common()[0][0]

    @mark_as_writeable
    def witch_poison_someone(self, witch: "Role", player_name: str = None):
        """Mark ``player_name`` as POISONED if ``witch`` is a witch and the target is alive."""
        if not self._check_valid_role(witch, "Witch"):
            return
        if not self._check_player_continue(player_name):
            return

        self._update_players_state([player_name], RoleState.POISONED)
        self.player_poisoned = player_name

    @mark_as_writeable
    def witch_save_someone(self, witch: "Role", player_name: str = None):
        """Mark ``player_name`` as SAVED if ``witch`` is a witch and the target is alive."""
        if not self._check_valid_role(witch, "Witch"):
            return
        if not self._check_player_continue(player_name):
            return

        self._update_players_state([player_name], RoleState.SAVED)
        self.player_protected = player_name

    @mark_as_writeable
    def update_game_states(self, memories: list):
        """Settle the night (step 15) or the vote (step 18) and check for a winner.

        Each absolute step is evaluated at most once. ``memories`` is accepted
        for interface compatibility and is not used here.
        """
        step_idx = self.step_idx % self.per_round_steps
        if step_idx not in [15, 18] or self.step_idx in self.eval_step_idx:
            return
        # record evaluation, avoid repetitive evaluation at the same step
        self.eval_step_idx.append(self.step_idx)

        if step_idx == 15:  # step no
            # night ends: after all special roles acted, process the whole night
            dead_tonight = []
            # Guard against "nobody was hunted" so that None is never
            # reported as a dead player.
            hunted = self.player_hunted
            if hunted and hunted != self.player_protected and not self.is_hunted_player_saved:
                dead_tonight.append(hunted)
            if self.player_poisoned:
                dead_tonight.append(self.player_poisoned)
            self.player_current_dead = dead_tonight

            # Pass the list of names itself; wrapping it in another list made
            # the lookup key an unhashable list.
            self._update_players_state(self.player_current_dead)
            # reset
            self.player_hunted = None
            self.player_protected = None
            self.is_hunted_player_saved = False
            self.player_poisoned = None

        # game's termination condition
        living = self.living_players
        living_werewolf = [p for p in self.werewolf_players if p in living]
        living_villagers = [p for p in self.villager_players if p in living]
        living_special_roles = [p for p in self.special_role_players if p in living]
        if not living_werewolf:
            self.winner = "good guys"
            self.win_reason = "werewolves all dead"
        elif not living_villagers or not living_special_roles:
            self.winner = "werewolf"
            self.win_reason = "villagers all dead" if not living_villagers else "special roles all dead"
        if self.winner is not None:
            # NOTE(review): _record_all_experiences is not defined in this
            # class; presumably provided by a subclass or mixin - confirm.
            self._record_all_experiences()  # TODO
b/notebook_dir/metagpt_yusin/ext/android_assistant/README.md @@ -0,0 +1,118 @@ +# MetaGPT Android Assistant + +The MetaGPT Android Assistant is an intelligent assistance tool driven by a multi-modal large language model based on the advanced MetaGPT framework. It has the ability to self-learn, mastering users' daily usage patterns through learning, and can automatically complete various application operations according to user instructions, achieving comprehensive liberation of users' hands. +Next, we will introduce the functions of the MetaGPT Android Assistant and how to use it. + +## Features + +The operation of the MetaGPT Android Assistant mainly includes two stages: learning and automatic execution. Below, we introduce the specific features of the MetaGPT Android Assistant from these two stages. + +### Learning Stage + +By learning from human demonstrations or exploring apps based on human instructions, the MetaGPT Android Assistant can learn the functionality of apps, generate corresponding operation documents for use in the subsequent "automatic execution" stage. Approximately 20 rounds of exploration for any given task objective can significantly improve performance. + +By setting the `stage` to `learn`, you can ask the Android Assistant to enter the learning stage. By setting the `mode` to `auto`, you can instruct the Android Assistant to learn through automatic exploration; by setting the mode to manual, you can instruct the Android Assistant to learn through human manual demonstration. In the usage section, we provide detailed explanations of the script parameters. You can try experimenting with automatic exploration and manual demonstration modes on the "Messenger" app with the following commands: + +```bash +cd examples/android_assistant +python run_assistant.py "Send 'When will we release this feature?' 
to +86 8888888" --stage "learn" --mode "auto or manual" --app-name "Messenger" +``` + +#### Learning Based on Human Demonstration +When asking the Android Assistant to perform self-exploration during the learning stage, you can free your hands. However, when instructing it to learn according to your commands, you need to follow the instructions in the terminal for the Android Assistant to accurately learn your operation methods. +A possible example is as follows: + +```bash +cd examples/android_assistant +python run_assistant.py "Send 'When will we release this feature?' to +86 8888888" --stage "learn" --mode "manual" --app-name "Messenger" +``` + +After running this command, you will first see a screenshot of an Android screen that has been marked at various interactive locations, as shown in the figure below: + + + +After remembering the location where you want to operate, a request similar to the one below will be output in the terminal. Reply to it and thereby direct the Android assistant to learn your demonstration action: + +```bash +| INFO | examples.android_assistant.actions.manual_record:run:96 - Which element do you want to tap? Choose a numeric tag from 1 to 11: +user_input: 8 +| INFO | examples.android_assistant.actions.manual_record:run:81 - Choose one of the following actions you want to perform on the current screen: +tap, text, long_press, swipe, stop +user_input: tap +``` + +### Automatic Execution Stage +After the Android Assistant completes the learning stage, you can command it to complete tasks on the phone through text descriptions. By configuring the operation documents from the self-learning stage, the Android Assistant has richer prior knowledge, and its execution capabilities are further enhanced. +You can instruct the Android Assistant to send messages in the "Messenger" app with the following command: +```bash +python run_assistant.py "Send 'When will we release this feature?' 
to +86 8888888" --stage "act" --mode "auto or manual" --app-name "Messenger" +``` +Specifically, by selecting `auto` for `mode`, the Android assistant will employ the operational records compiled through self-exploration. Alternatively, if `manual` is chosen as the `mode`, the Android assistant will leverage the operation manuals accrued from learning via human demonstration. + +## Installation +To use the Android Assistant, you first need to meet the following conditions: +1. Complete the installation of the MetaGPT environment. +2. Install [Android Debug Bridge (ADB)](https://developer.android.com/tools/adb?hl=zh-cn) on your PC, which enables interaction between your PC and Android devices. +3. Install Android Studio and within it, install the Android emulator to provide an environment for the Android Assistant to learn and execute. For information on how to install the Android emulator, refer to [Quick Installation of Android Studio & Emulator](https://docs.expo.dev/workflow/android-studio-emulator/). +4. (Optional) Connect your Android device to the USB port of your PC, which can also provide an environment for the Android Assistant to learn and execute. + +Note ⚠️: When operating with the Android emulator, the emulator model we use is Medium Phone, which is recommended for first-time users to complete the operation. + +After completing these operations, you can enter the following command to check if ADB is installed successfully and if the Android device is connected: +```bash +adb devices +``` + +## Usage +The MetaGPT Android Assistant is designed within the MetaGPT framework as a collection of Roles and multiple Actions. You can run it by executing the `run_assistant.py` script. 
The specific parameter description of this script is as follows: +```text +Usage: run_assistant.py [OPTIONS] TASK_DESC + + Run a Android Assistant + +Arguments: + TASK_DESC the task description you want the android assistant to learn or + act [required] + +Options: + --n-round INTEGER The max round to do an app operation task. + [default: 20] + --stage TEXT stage: learn / act [default: learn] + --mode TEXT mode: auto / manual , when state=learn + [default: auto] + --app-name TEXT the name of app you want to run [default: + demo] + --investment FLOAT Dollar amount to invest in the AI company. + [default: 5.0] + --refine-doc / --no-refine-doc Refine existing operation docs based on the + latest observation if True. [default: no- + refine-doc] + --min-dist INTEGER The minimum distance between elements to + prevent overlapping during the labeling + process. [default: 30] + --android-screenshot-dir TEXT The path to store screenshots on android + device. Make sure it exists. [default: + /sdcard/Pictures/Screenshots] + --android-xml-dir TEXT The path to store xml files for determining + UI elements localtion. Make sure it exists. + [default: /sdcard] + --device-id TEXT The Android device_id [default: + emulator-5554] + --help Show this message and exit. +``` + +## Acknowledgements +The MetaGPT Android Assistant has referenced some ideas and code from the [AppAgent](https://github.com/mnotgod96/AppAgent) project. We thank the developers of the AppAgent project.
+ +### Citation + +```bib +@misc{yang2023appagent, + title={AppAgent: Multimodal Agents as Smartphone Users}, + author={Chi Zhang and Zhao Yang and Jiaxuan Liu and Yucheng Han and Xin Chen and Zebiao Huang and Bin Fu and Gang Yu}, + year={2023}, + eprint={2312.13771}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +``` \ No newline at end of file diff --git a/notebook_dir/metagpt_yusin/ext/android_assistant/README_CN.md b/notebook_dir/metagpt_yusin/ext/android_assistant/README_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..a1abbe3b0bfa3b61bd76e15a11d71f9c43281190 --- /dev/null +++ b/notebook_dir/metagpt_yusin/ext/android_assistant/README_CN.md @@ -0,0 +1,113 @@ +# MetaGPT 安卓助理 + +MetaGPT安卓助理是一款依托于先进的MetaGPT框架构建的多模态大语言模型驱动的智能辅助工具。 +它具备自我学习的能力,能够通过学习掌握用户的日常使用方式,同时能够根据用户的指令自动完成各类应用程序的操作任务,实现了用户双手的全面解放。 +接下来,我们将介绍MetaGPT安卓助理的功能以及如何使用它。 + +## 功能 + +MetaGPT 安卓助理的执行主要包含两个阶段,分别为自我学习与自动执行。下面,我们将从这两个阶段介绍MetaGPT 安卓助理的具体功能。 + +### 自我学习阶段 + +通过学习人类演示或基于人类指令对app进行探索,MetaGPT安卓助理可以对app的功能进行学习,生成相应的操作文档,为后续的“自动执行”阶段使用。对于任何给定的任务目标,进行约20轮的探索可以显著提高性能。 + +通过设定`stage`为`learn`可要求安卓助理进入自我学习阶段。通过设定`mode`为`auto`,可要求安卓助理通过自动探索学习,通过设定`mode`为`manual`,可要求安卓助理通过人类手动演示学习。在使用章节,我们对脚本的参数进行了详细的说明。 +您可以尝试对“Messenger”应用程序进行自动探索和手动演示模式的实验,具体命令如下: + +```bash +cd examples/android_assistant +python run_assistant.py "Send 'When will we release this feature? to +86 8888888'" --stage "learn" --mode "auto or manual" --app-name "Messenger" +``` + +#### 基于人类演示的学习 +在要求安卓助理在自我学习阶段执行自我探索时,您可以解放您的双手,但在要求他根据您的指令进行学习时,你需要根据终端中的指令进行输入,以便安卓助理能够准确地学习您的操作方式。 +一个可能的例子如下: + +```bash +cd examples/android_assistant +python run_assistant.py "Send 'When will we release this feature? to +86 8888888'" --stage "learn" --mode "manual" --app-name "Messenger" +``` + +在运行这一指令后,你将首先看到一个在各个可交互的位置进行了标记的安卓屏幕的截图,如下图: + + + +在记住你要操作的位置之后,终端中将会输出与下面类似的要求,回复它,进而指挥安卓助理学习你的演示行为: + +```bash +| INFO | examples.android_assistant.actions.manual_record:run:96 - Which element do you want to tap? 
Choose a numeric tag from 1 to 11: +user_input: 8 +| INFO | examples.android_assistant.actions.manual_record:run:81 - Choose one of the following actions you want to perform on the current screen: +tap, text, long_press, swipe, stop +user_input: tap +``` +### 自动执行阶段 +在安卓助理完成了自我学习阶段之后,您可以通过文本描述的方式,指挥安卓助理在手机中完成任务。通过为其配置自我学习阶段的操作文档,安卓助理具备了更丰富的前置知识,执行能力进一步得到提升。 +你可以通过以下指令,指挥安卓助理在“Messenger”应用中发送信息: +```bash +python run_assistant.py "Send 'When will we release this feature? to +86 8888888'" --stage "act" --mode "auto or manual" --app-name "Messenger" +``` +其中,`mode`选择`auto`,安卓助理将使用自我探索中积累的操作文档;`mode`选择`manual`,安卓助理将使用人类演示学习中积累的操作文档。 + +## 安装 +为了使用安卓助理,你首先需要满足以下条件: +1. 完成MetaGPT环境的安装 +2. 在你的PC上安装[Android Debug Bridge(ADB)](https://developer.android.com/tools/adb?hl=zh-cn),ADB可以使你的PC与安卓设备进行交互。 +3. 安装Android Studio,在其中安装Android模拟器,以为安卓助手提供学习与执行的环境。关于如何安装Android模拟器,可以参考[快速安装Android Studio & Emulator](https://dev.weixin.qq.com/docs/framework/dev/framework/env/android-simulator.html)。 +4. (可选) 将你的安卓设备连接到PC的USB端口上,这同样可以为安卓助手提供学习与执行的环境。 + +注意 ⚠️:在使用Android模拟器进行操作时,我们使用的模拟器型号为Medium Phone,建议第一次尝试此类应用的用户使用这一型号完成操作。 + +在完成这一系列操作之后,你可以输入以下命令检查ADB是否安装成功,以及安卓设备是否连接 +```bash +adb devices +``` +## 使用 +MetaGPT 安卓助理在MetaGPT框架中被设计为一个`Role`与多个`Action`的集合,你可以通过运行`run_assistant.py`脚本来运行它。这一脚本具体的参数说明如下: +```text +用法:run_assistant.py [选项] 任务描述 + + 运行一个安卓助手 + +参数: + TASK_DESC 你希望安卓助手学习或执行的任务描述 + [必需] + +选项: + --n-round 整数 执行应用程序操作任务的最大轮数。 + [默认值:20] + --stage 文本 阶段:learn/act [默认值:learn] + --mode 文本 模式:auto/manual,当stage=learn时 [默认值:auto] + --app-name 文本 你想要运行的应用程序名称 [默认值: + demo] + --investment 浮点数 投资于人工智能公司的美元金额。 + [默认值:5.0] + --refine-doc / --no-refine-doc 如果为真,则根据最新的观察结果优化现有操作文档。 + [默认值:--no-refine-doc] + --min-dist 整数 在标记过程中防止元素重叠的最小元素间距。 + [默认值:30] + --android-screenshot-dir 文本 在安卓设备上存储截图的路径。确保其存在。 + [默认值:/sdcard/Pictures/Screenshots] + --android-xml-dir 文本 存储用于确定UI元素位置的XML文件的路径。 + 确保其存在。[默认值:/sdcard] + --device-id 文本 安卓device_id [默认值: + emulator-5554] + --help 显示此信息并退出。 +``` + +## 致谢 +MetaGPT
安卓助理参考了 [AppAgent](https://github.com/mnotgod96/AppAgent) 项目的部分思路与代码,感谢 Appagent 项目的开发者们。 + +### 引用 + +```bib +@misc{yang2023appagent, + title={AppAgent: Multimodal Agents as Smartphone Users}, + author={Chi Zhang and Zhao Yang and Jiaxuan Liu and Yucheng Han and Xin Chen and Zebiao Huang and Bin Fu and Gang Yu}, + year={2023}, + eprint={2312.13771}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +``` \ No newline at end of file diff --git a/notebook_dir/metagpt_yusin/ext/android_assistant/__init__.py b/notebook_dir/metagpt_yusin/ext/android_assistant/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bcf8efd09712339308e72659e84450d3fa829fd --- /dev/null +++ b/notebook_dir/metagpt_yusin/ext/android_assistant/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/notebook_dir/metagpt_yusin/ext/android_assistant/actions/__init__.py b/notebook_dir/metagpt_yusin/ext/android_assistant/actions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bcf8efd09712339308e72659e84450d3fa829fd --- /dev/null +++ b/notebook_dir/metagpt_yusin/ext/android_assistant/actions/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/notebook_dir/metagpt_yusin/ext/android_assistant/actions/manual_record.py b/notebook_dir/metagpt_yusin/ext/android_assistant/actions/manual_record.py new file mode 100644 index 0000000000000000000000000000000000000000..bcfb2ed893ae259b3401c890912414461f3cff5e --- /dev/null +++ b/notebook_dir/metagpt_yusin/ext/android_assistant/actions/manual_record.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : manual record user interaction in stage=learn & mode=manual, LIKE scripts/step_recorder.py +import time +from pathlib import Path + +import cv2 + +from metagpt.actions.action import Action +from metagpt.config2 import config +from metagpt.environment.android.android_env import AndroidEnv 
class ManualRecord(Action):
    """Record a human demonstration on the device (stage=learn & mode=manual).

    Every loop iteration pulls a screenshot and the UI XML through ``env``, shows the
    screenshot with numbered element boxes in an OpenCV window, asks the operator on stdin
    which action to perform, executes it with ``env.step`` and appends one line describing it
    to ``record.txt`` under ``task_dir``.
    """

    name: str = "ManualRecord"

    useless_list: list[str] = []  # store useless elements uid
    record_path: Path = ""
    task_desc_path: Path = ""
    screenshot_before_path: Path = ""
    screenshot_after_path: Path = ""
    xml_path: Path = ""

    def _ask_element_tag(self, elem_count: int) -> int:
        """Prompt on stdin until a tag in ``1..elem_count`` is entered; return it as an int."""
        answer = ""
        # isdecimal() instead of isnumeric(): the latter also accepts characters such as
        # "²" or "½" which int() rejects, and that would crash the recording session.
        while not answer.isdecimal() or not 1 <= int(answer) <= elem_count:
            answer = input("user_input: ")
        return int(answer)

    def _bbox_center(self, bbox) -> tuple[int, int]:
        """Return the integer centre of a ``(top_left, bottom_right)`` bounding box."""
        tl, br = bbox
        return (tl[0] + br[0]) // 2, (tl[1] + br[1]) // 2

    async def run(self, task_desc: str, task_dir: Path, env: AndroidEnv):
        """Interactively record operator actions until "stop" is chosen or a step fails.

        Args:
            task_desc: Task description; written to ``task_desc.txt`` in ``task_dir``.
            task_dir: Directory receiving the record file, screenshots and XML dumps.
            env: Android environment used to observe the screen and execute actions.

        Returns:
            ``AndroidActionOutput`` with ``RunState.SUCCESS`` once the operator stops the
            recording, or ``RunState.FAIL`` when an observation file is missing or the
            environment reports ``ADB_EXEC_FAIL``.
        """
        self.record_path = Path(task_dir) / "record.txt"
        self.task_desc_path = Path(task_dir) / "task_desc.txt"
        self.screenshot_before_path = Path(task_dir) / "raw_screenshots"
        self.screenshot_after_path = Path(task_dir) / "labeled_screenshots"
        self.xml_path = Path(task_dir) / "xml"
        for path in [self.screenshot_before_path, self.screenshot_after_path, self.xml_path]:
            path.mkdir(parents=True, exist_ok=True)

        self.task_desc_path.write_text(task_desc)

        extra_config = config.extra
        valid_ops = (
            ActionOp.TAP.value,
            ActionOp.TEXT.value,
            ActionOp.LONG_PRESS.value,
            ActionOp.SWIPE.value,
            ActionOp.STOP.value,
        )
        swipe_dirs = (SwipeOp.UP.value, SwipeOp.DOWN.value, SwipeOp.LEFT.value, SwipeOp.RIGHT.value)

        # Opening with mode "w" truncates any previous record. The context manager makes sure
        # the file is flushed and closed on every exit path, including the early FAIL returns.
        with open(self.record_path, "w") as record_file:
            step = 0
            while True:
                step += 1
                screenshot_path: Path = env.observe(
                    EnvObsParams(
                        obs_type=EnvObsType.GET_SCREENSHOT,
                        ss_name=f"{step}",
                        local_save_dir=self.screenshot_before_path,
                    )
                )
                xml_path: Path = env.observe(
                    EnvObsParams(obs_type=EnvObsType.GET_XML, xml_name=f"{step}", local_save_dir=self.xml_path)
                )
                if not screenshot_path.exists() or not xml_path.exists():
                    return AndroidActionOutput(action_state=RunState.FAIL)

                elem_list = elem_list_from_xml_tree(xml_path, self.useless_list, extra_config.get("min_dist", 30))

                screenshot_labeled_path = Path(self.screenshot_after_path).joinpath(f"{step}_labeled.png")
                labeled_img = draw_bbox_multi(screenshot_path, screenshot_labeled_path, elem_list)

                # Show the labelled screenshot; waitKey(0) blocks until the operator presses a key.
                cv2.namedWindow("image", cv2.WINDOW_NORMAL)
                cv2.imshow("image", labeled_img)
                cv2.waitKey(0)
                cv2.destroyAllWindows()

                logger.info(
                    "Choose one of the following actions you want to perform on the current screen:\n"
                    "tap, text, long_press, swipe, stop"
                )
                user_input = "xxx"
                while user_input.lower() not in valid_ops:
                    user_input = input("user_input: ")
                op = user_input.lower()

                if op == ActionOp.STOP.value:
                    record_file.write("stop\n")
                    break

                if op == ActionOp.TAP.value:
                    logger.info(f"Which element do you want to tap? Choose a numeric tag from 1 to {len(elem_list)}:")
                    tag = self._ask_element_tag(len(elem_list))
                    x, y = self._bbox_center(elem_list[tag - 1].bbox)
                    action = EnvAction(action_type=EnvActionType.SYSTEM_TAP, coord=(x, y))
                    log_str = f"tap({tag}):::{elem_list[tag - 1].uid}\n"
                elif op == ActionOp.TEXT.value:
                    logger.info(
                        f"Which element do you want to input the text string? Choose a numeric tag from 1 to "
                        f"{len(elem_list)}:"
                    )
                    tag = self._ask_element_tag(len(elem_list))
                    logger.info("Enter your input text below:")
                    text = ""
                    while not text:
                        text = input("user_input: ")
                    action = EnvAction(action_type=EnvActionType.USER_INPUT, input_txt=text)
                    log_str = f"text({tag}:sep:'{text}'):::{elem_list[tag - 1].uid}\n"
                elif op == ActionOp.LONG_PRESS.value:
                    logger.info(
                        f"Which element do you want to long press? Choose a numeric tag from 1 to {len(elem_list)}:"
                    )
                    tag = self._ask_element_tag(len(elem_list))
                    x, y = self._bbox_center(elem_list[tag - 1].bbox)
                    action = EnvAction(action_type=EnvActionType.USER_LONGPRESS, coord=(x, y))
                    log_str = f"long_press({tag}):::{elem_list[tag - 1].uid}\n"
                else:
                    # Only ActionOp.SWIPE is left after the validation loop above.
                    logger.info(
                        "What is the direction of your swipe? Choose one from the following options:\n"
                        "up, down, left, right"
                    )
                    swipe_dir = ""
                    while swipe_dir not in swipe_dirs:
                        swipe_dir = input("user_input: ")
                    logger.info(f"Which element do you want to swipe? Choose a numeric tag from 1 to {len(elem_list)}:")
                    tag = self._ask_element_tag(len(elem_list))
                    x, y = self._bbox_center(elem_list[tag - 1].bbox)
                    action = EnvAction(action_type=EnvActionType.USER_SWIPE, coord=(x, y), orient=swipe_dir)
                    log_str = f"swipe({tag}:sep:{swipe_dir}):::{elem_list[tag - 1].uid}\n"

                obs, _, _, _, info = env.step(action)
                action_res = info["res"]
                if action_res == ADB_EXEC_FAIL:
                    return AndroidActionOutput(action_state=RunState.FAIL)
                # The step is recorded only after the device accepted it.
                record_file.write(log_str)

                time.sleep(1)

        return AndroidActionOutput(action_state=RunState.SUCCESS)
class ParseRecord(Action):
    """Turn a manual recording into per-element documentation files (stage=learn & mode=manual).

    Reads ``record.txt`` and ``task_desc.txt`` from ``task_dir``, asks the LLM to describe each
    recorded action from the labelled screenshots taken before and after it, and stores the
    answer in ``docs_dir/<resource_id>.txt`` as the ``str()`` of a dict keyed by action type.
    """

    name: str = "ParseRecord"
    record_path: Path = ""
    task_desc_path: Path = ""
    screenshot_before_path: Path = ""
    screenshot_after_path: Path = ""

    async def run(self, task_dir: Path, docs_dir: Path):
        """Generate documentation for every recorded step.

        Args:
            task_dir: Directory holding ``record.txt``, ``task_desc.txt`` and the screenshots.
            docs_dir: Directory where one ``<resource_id>.txt`` file per UI element is kept.

        Returns:
            ``AndroidActionOutput`` with ``RunState.FINISH`` when the record was processed, or
            ``RunState.FAIL`` when the LLM response content contains ``"error"``.
        """
        doc_count = 0
        self.record_path = Path(task_dir) / "record.txt"
        self.task_desc_path = Path(task_dir) / "task_desc.txt"
        self.screenshot_before_path = Path(task_dir) / "raw_screenshots"
        self.screenshot_after_path = Path(task_dir) / "labeled_screenshots"
        for path in [self.screenshot_before_path, self.screenshot_after_path]:
            path.mkdir(parents=True, exist_ok=True)
        # The doc files are written below, so the target directory has to exist.
        Path(docs_dir).mkdir(parents=True, exist_ok=True)

        task_desc = self.task_desc_path.read_text()
        extra_config = config.extra

        with open(self.record_path, "r") as record_file:
            # The last line of the record is skipped (ManualRecord writes "stop" there).
            records = record_file.readlines()[:-1]

        for step, raw_line in enumerate(records, start=1):
            img_before_base64 = encode_image(self.screenshot_after_path.joinpath(f"{step}_labeled.png"))
            img_after_base64 = encode_image(self.screenshot_after_path.joinpath(f"{step + 1}_labeled.png"))
            rec = raw_line.strip()
            # Split on the LAST ":::" so that text typed by the operator may itself contain ":::".
            action, sep, resource_id = rec.rpartition(":::")
            if not sep:
                logger.error(f"Malformed record line at step {step}: {rec!r}")
                break
            action_type = action.split("(")[0]
            # Build the prompt
            action_param = re.findall(r"\((.*?)\)", action)[0]
            if action_type == ActionOp.TAP.value:
                prompt_template = tap_doc_template
                context = prompt_template.format(ui_element=action_param)
            elif action_type == ActionOp.TEXT.value:
                # maxsplit=1: the recorded text may contain ":sep:" as well; only the area is used.
                input_area, _ = action_param.split(":sep:", 1)
                prompt_template = text_doc_template
                context = prompt_template.format(ui_element=input_area)
            elif action_type == ActionOp.LONG_PRESS.value:
                prompt_template = long_press_doc_template
                context = prompt_template.format(ui_element=action_param)
            elif action_type == ActionOp.SWIPE.value:
                swipe_area, swipe_dir = action_param.split(":sep:")
                # Swipes are documented per axis, so the action type becomes the axis-specific one.
                if swipe_dir == SwipeOp.UP.value or swipe_dir == SwipeOp.DOWN.value:
                    action_type = ActionOp.VERTICAL_SWIPE.value
                elif swipe_dir == SwipeOp.LEFT.value or swipe_dir == SwipeOp.RIGHT.value:
                    action_type = ActionOp.HORIZONTAL_SWIPE.value
                prompt_template = swipe_doc_template
                context = prompt_template.format(swipe_dir=swipe_dir, ui_element=swipe_area)
            else:
                break
            context = context.format(task_desc=task_desc)

            doc_name = resource_id + ".txt"
            doc_path = docs_dir.joinpath(doc_name)

            if doc_path.exists():
                try:
                    doc_content = ast.literal_eval(doc_path.read_text())
                except Exception as exp:
                    logger.error(f"ast parse doc: {doc_path} failed, exp: {exp}")
                    continue

                # .get(): a doc file without this key counts as "not documented yet".
                old_doc = doc_content.get(action_type)
                if old_doc:
                    if extra_config.get("doc_refine", False):
                        refine_context = refine_doc_suffix.format(old_doc=old_doc)
                        context += refine_context
                        logger.info(
                            f"Documentation for the element {resource_id} already exists. The doc will be "
                            f"refined based on the latest demo."
                        )
                    else:
                        logger.info(
                            f"Documentation for the element {resource_id} already exists. Turn on DOC_REFINE "
                            f"in the config file if needed."
                        )
                        continue
            else:
                doc_content = {"tap": "", "text": "", "v_swipe": "", "h_swipe": "", "long_press": ""}

            logger.info(f"Waiting for GPT-4V to generate documentation for the element {resource_id}")
            node = await RECORD_PARSE_NODE.fill(
                context=context, llm=self.llm, images=[img_before_base64, img_after_base64]
            )
            if "error" in node.content:
                return AndroidActionOutput(action_state=RunState.FAIL)
            log_path = task_dir.joinpath("log_parse_record.txt")
            prompt = node.compile(context=context, schema="json", mode="auto")
            msg = node.content
            doc_content[action_type] = msg

            with open(log_path, "a") as logfile:
                log_item = RecordLogItem(
                    step=step,
                    prompt=prompt,
                    image_before=img_before_base64,
                    image_after=img_after_base64,
                    response=node.content,
                )
                logfile.write(log_item.model_dump_json() + "\n")
            with open(doc_path, "w") as outfile:
                outfile.write(str(doc_content))
            doc_count += 1
            logger.info(f"Documentation generated and saved to {doc_path}")

        logger.info(f"Documentation generation phase completed. {doc_count} docs generated.")

        return AndroidActionOutput(action_state=RunState.FINISH)
" + "Subsequently, delineate the distinctions between the first image and the second one.", + example="", +) + +THOUGHT = ActionNode( + key="Thought", + expected_type=str, + instruction="Consider the impact of Action acting on UI elements.", + example="", +) + +DESCRIPTION = ActionNode( + key="Description", + expected_type=str, + instruction="Describe the functionality of the UI element concisely in one or two sentences Do not include " + "the numeric tag in your description", + example="", +) + +NODES = [OBSERVATION, THOUGHT, DESCRIPTION] + +RECORD_PARSE_NODE = ActionNode.from_children("RecordParse", NODES) diff --git a/notebook_dir/metagpt_yusin/ext/android_assistant/actions/screenshot_parse.py b/notebook_dir/metagpt_yusin/ext/android_assistant/actions/screenshot_parse.py new file mode 100644 index 0000000000000000000000000000000000000000..4d8bb0e1eb993cd285dd9e807da81638baa1a3fb --- /dev/null +++ b/notebook_dir/metagpt_yusin/ext/android_assistant/actions/screenshot_parse.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : LIKE scripts/task_executor.py in stage=act + +import ast +from pathlib import Path + +from metagpt.actions.action import Action +from metagpt.config2 import config +from metagpt.environment.android.android_env import AndroidEnv +from metagpt.environment.android.const import ADB_EXEC_FAIL +from metagpt.environment.android.env_space import ( + EnvAction, + EnvActionType, + EnvObsParams, + EnvObsType, +) +from metagpt.ext.android_assistant.actions.screenshot_parse_an import ( + SCREENSHOT_PARSE_NODE, +) +from metagpt.ext.android_assistant.prompts.assistant_prompt import ( + screenshot_parse_template, + screenshot_parse_with_grid_template, +) +from metagpt.ext.android_assistant.utils.schema import ( + AndroidActionOutput, + AndroidElement, + GridOpParam, + LongPressGridOpParam, + LongPressOpParam, + OpLogItem, + RunState, + SwipeGridOpParam, + SwipeOpParam, + TapGridOpParam, + TapOpParam, + TextOpParam, +) +from 
class ScreenshotParse(Action):
    """Decide and execute the next device action from a labelled screenshot (stage=act)."""

    name: str = "ScreenshotParse"

    def _makeup_ui_document(self, elem_list: list[AndroidElement], docs_idr: Path, use_exist_doc: bool = True) -> str:
        """Assemble the prompt section describing already documented UI elements.

        Args:
            elem_list: Elements in label order; the element at index ``i`` carries tag ``i + 1``.
            docs_idr: Directory holding one ``<uid>.txt`` doc file per element.
            use_exist_doc: When False no documentation is used and ``""`` is returned.

        Returns:
            The documentation text to embed into the prompt.
        """
        if not use_exist_doc:
            return ""

        ui_doc = """
You also have access to the following documentations that describes the functionalities of UI
elements you can interact on the screen. These docs are crucial for you to determine the target of your
next action. You should always prioritize these documented elements for interaction: """
        for i, elem in enumerate(elem_list):
            doc_path = docs_idr.joinpath(f"{elem.uid}.txt")
            if not doc_path.exists():
                continue
            try:
                doc_content = ast.literal_eval(doc_path.read_text())
            except Exception as exp:
                logger.error(f"ast parse doc: {doc_path} failed, exp: {exp}")
                continue

            ui_doc += f"Documentation of UI element labeled with the numeric tag '{i + 1}':\n"
            # .get(): a doc file that lacks one of the keys simply contributes nothing for it.
            if doc_content.get("tap"):
                ui_doc += f"This UI element is clickable. {doc_content['tap']}\n\n"
            if doc_content.get("text"):
                ui_doc += (
                    f"This UI element can receive text input. The text input is used for the following "
                    f"purposes: {doc_content['text']}\n\n"
                )
            if doc_content.get("long_press"):
                ui_doc += f"This UI element is long clickable. {doc_content['long_press']}\n\n"
            if doc_content.get("v_swipe"):
                ui_doc += (
                    f"This element can be swiped directly without tapping. You can swipe vertically on "
                    f"this UI element. {doc_content['v_swipe']}\n\n"
                )
            if doc_content.get("h_swipe"):
                ui_doc += (
                    f"This element can be swiped directly without tapping. You can swipe horizontally on "
                    f"this UI element. {doc_content['h_swipe']}\n\n"
                )
        return ui_doc

    def _elem_center(self, elem_list: list[AndroidElement], area: int):
        """Return the centre of the element tagged ``area`` (1-based), or None if out of range."""
        # Without this check a tag of 0 would silently address elem_list[-1].
        if not 1 <= area <= len(elem_list):
            logger.error(f"UI element tag {area} is out of range 1..{len(elem_list)}")
            return None
        return elem_bbox_to_xy(elem_list[area - 1].bbox)

    async def run(
        self,
        round_count: int,
        task_desc: str,
        last_act: str,
        task_dir: Path,
        docs_dir: Path,
        grid_on: bool,
        env: AndroidEnv,
    ):
        """Observe the screen, ask the LLM for the next operation and execute it.

        Args:
            round_count: Current round number; used to name the files saved in ``task_dir``.
            task_desc: Task description inserted into the prompt.
            last_act: Summary of the previous actions inserted into the prompt.
            task_dir: Directory for screenshots and XML dumps of this task.
            docs_dir: Directory with the element documentation files.
            grid_on: Whether the grid prompt template is used for this round.
            env: Android environment used to observe and to execute the action.

        Returns:
            ``AndroidActionOutput`` carrying ``{"grid_on", "last_act"}`` for the next round, or
            one whose ``action_state`` is ``RunState.FINISH`` / ``RunState.FAIL``.
        """
        extra_config = config.extra
        for path in [task_dir, docs_dir]:
            path.mkdir(parents=True, exist_ok=True)
        screenshot_path: Path = env.observe(
            EnvObsParams(obs_type=EnvObsType.GET_SCREENSHOT, ss_name=f"{round_count}_before", local_save_dir=task_dir)
        )
        xml_path: Path = env.observe(
            EnvObsParams(obs_type=EnvObsType.GET_XML, xml_name=f"{round_count}", local_save_dir=task_dir)
        )
        if not screenshot_path.exists() or not xml_path.exists():
            return AndroidActionOutput(action_state=RunState.FAIL)

        clickable_list = []
        focusable_list = []
        traverse_xml_tree(xml_path, clickable_list, "clickable", True)
        traverse_xml_tree(xml_path, focusable_list, "focusable", True)
        # Merge focusable elements into the clickable ones, dropping those whose centre lies
        # within min_dist of a clickable element's centre.
        min_dist = extra_config.get("min_dist", 30)
        elem_list: list[AndroidElement] = clickable_list.copy()
        for elem in focusable_list:
            bbox = elem.bbox
            center = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2
            close = False
            for e in clickable_list:
                bbox = e.bbox
                center_ = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2
                dist = (abs(center[0] - center_[0]) ** 2 + abs(center[1] - center_[1]) ** 2) ** 0.5
                if dist <= min_dist:
                    close = True
                    break
            if not close:
                elem_list.append(elem)

        screenshot_labeled_path = task_dir.joinpath(f"{round_count}_labeled.png")
        draw_bbox_multi(screenshot_path, screenshot_labeled_path, elem_list)
        img_base64 = encode_image(screenshot_labeled_path)

        parse_template = screenshot_parse_with_grid_template if grid_on else screenshot_parse_template

        if grid_on:
            # NOTE(review): the grid image is drawn, but the image sent to the LLM below is still
            # the tag-labelled screenshot - confirm whether the grid image should be sent instead.
            env.rows, env.cols = draw_grid(screenshot_path, task_dir / f"{round_count}_grid.png")

        ui_doc = self._makeup_ui_document(elem_list, docs_dir)
        context = parse_template.format(ui_document=ui_doc, task_description=task_desc, last_act=last_act)
        node = await SCREENSHOT_PARSE_NODE.fill(context=context, llm=self.llm, images=[img_base64])

        if "error" in node.content:
            return AndroidActionOutput(action_state=RunState.FAIL)

        prompt = node.compile(context=context, schema="json", mode="auto")
        OpLogItem(step=round_count, prompt=prompt, image=str(screenshot_labeled_path), response=node.content)

        op_param = screenshot_parse_extract(node.instruct_content.model_dump(), grid_on)
        if op_param.param_state == RunState.FINISH:
            logger.info(f"op_param: {op_param}")
            return AndroidActionOutput(action_state=RunState.FINISH)
        if op_param.param_state == RunState.FAIL:
            return AndroidActionOutput(action_state=RunState.FAIL)

        last_act = op_param.last_act
        action = None  # stays None when the model only asked for the grid overlay
        if isinstance(op_param, TapOpParam):
            coord = self._elem_center(elem_list, op_param.area)
            if coord is None:
                return AndroidActionOutput(action_state=RunState.FAIL)
            action = EnvAction(action_type=EnvActionType.SYSTEM_TAP, coord=coord)
        elif isinstance(op_param, TextOpParam):
            action = EnvAction(action_type=EnvActionType.USER_INPUT, input_txt=op_param.input_str)
        elif isinstance(op_param, LongPressOpParam):
            coord = self._elem_center(elem_list, op_param.area)
            if coord is None:
                return AndroidActionOutput(action_state=RunState.FAIL)
            action = EnvAction(action_type=EnvActionType.USER_LONGPRESS, coord=coord)
        elif isinstance(op_param, SwipeOpParam):
            coord = self._elem_center(elem_list, op_param.area)
            if coord is None:
                return AndroidActionOutput(action_state=RunState.FAIL)
            action = EnvAction(
                action_type=EnvActionType.USER_SWIPE, coord=coord, orient=op_param.swipe_orient, dist=op_param.dist
            )
        elif isinstance(op_param, GridOpParam):
            grid_on = True
        elif isinstance(op_param, TapGridOpParam) or isinstance(op_param, LongPressGridOpParam):
            x, y = area_to_xy(op_param.area, op_param.subarea, env.width, env.height, env.rows, env.cols)
            if isinstance(op_param, TapGridOpParam):
                action = EnvAction(action_type=EnvActionType.SYSTEM_TAP, coord=(x, y))
            else:
                # LongPressGridOpParam
                action = EnvAction(action_type=EnvActionType.USER_LONGPRESS, coord=(x, y))
        elif isinstance(op_param, SwipeGridOpParam):
            start_x, start_y = area_to_xy(
                op_param.start_area, op_param.start_subarea, env.width, env.height, env.rows, env.cols
            )
            end_x, end_y = area_to_xy(
                op_param.end_area, op_param.end_subarea, env.width, env.height, env.rows, env.cols
            )
            action = EnvAction(
                action_type=EnvActionType.USER_SWIPE_TO, coord=(start_x, start_y), tgt_coord=(end_x, end_y)
            )

        # Execute whenever an action was built. The previous `if not grid_on` test skipped the
        # step for every grid-based action, so those were computed but never sent to the device.
        if action is not None:
            obs, _, _, _, info = env.step(action)
            action_res = info["res"]
            if action_res == ADB_EXEC_FAIL:
                return AndroidActionOutput(action_state=RunState.FAIL)
        elif not isinstance(op_param, GridOpParam):
            logger.error(f"Unsupported operation parameter: {op_param}")
            return AndroidActionOutput(action_state=RunState.FAIL)

        if op_param.act_name != "grid":
            grid_on = False

        return AndroidActionOutput(data={"grid_on": grid_on, "last_act": last_act})
Do not include " + "the numeric tag in your summary", + example="", +) + +SUMMARY_GRID = ActionNode( + key="Summary", + expected_type=str, + instruction="Summarize your past actions along with your latest action in one or two sentences. Do not include " + "the grid area number in your summary", + example="", +) + +NODES = [OBSERVATION, THOUGHT, ACTION, SUMMARY] + +NODES_GRID = [OBSERVATION, THOUGHT, ACTION, SUMMARY_GRID] + +SCREENSHOT_PARSE_NODE = ActionNode.from_children("ScreenshotParse", NODES) +SCREENSHOT_PARSE_GRID_NODE = ActionNode.from_children("ScreenshotParseGrid", NODES_GRID) diff --git a/notebook_dir/metagpt_yusin/ext/android_assistant/actions/self_learn_and_reflect.py b/notebook_dir/metagpt_yusin/ext/android_assistant/actions/self_learn_and_reflect.py new file mode 100644 index 0000000000000000000000000000000000000000..5e9cfbb4547599d388598ac53d0c187542645197 --- /dev/null +++ b/notebook_dir/metagpt_yusin/ext/android_assistant/actions/self_learn_and_reflect.py @@ -0,0 +1,231 @@ +# !/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : LIKE scripts/self_explorer.py in stage=learn & mode=auto self_explore_task stage + +import ast +from pathlib import Path + +from metagpt.actions.action import Action +from metagpt.config2 import config +from metagpt.environment.android.android_env import AndroidEnv +from metagpt.environment.android.const import ADB_EXEC_FAIL +from metagpt.environment.android.env_space import ( + EnvAction, + EnvActionType, + EnvObsParams, + EnvObsType, +) +from metagpt.ext.android_assistant.actions.screenshot_parse_an import ( + SCREENSHOT_PARSE_NODE, +) +from metagpt.ext.android_assistant.actions.self_learn_reflect_an import ( + SELF_LEARN_REFLECT_NODE, +) +from metagpt.ext.android_assistant.prompts.assistant_prompt import ( + screenshot_parse_self_explore_reflect_template as reflect_template, +) +from metagpt.ext.android_assistant.prompts.assistant_prompt import ( + screenshot_parse_self_explore_template, +) +from 
metagpt.ext.android_assistant.utils.schema import ( + ActionOp, + AndroidActionOutput, + AndroidElement, + Decision, + DocContent, + LongPressOpParam, + OpLogItem, + ReflectLogItem, + RunState, + SwipeOp, + SwipeOpParam, + TapOpParam, + TextOpParam, +) +from metagpt.ext.android_assistant.utils.utils import ( + draw_bbox_multi, + elem_bbox_to_xy, + elem_list_from_xml_tree, + reflect_parse_extarct, + screenshot_parse_extract, +) +from metagpt.logs import logger +from metagpt.utils.common import encode_image + + +class SelfLearnAndReflect(Action): + name: str = "SelfLearnAndReflect" + + useless_list: list[str] = [] # store useless elements uid + + screenshot_before_path: str = "" + screenshot_before_base64: str = "" + elem_list: list[AndroidElement] = [] + swipe_orient: str = "up" + act_name: str = "" + ui_area: int = -1 + + async def run( + self, round_count: int, task_desc: str, last_act: str, task_dir: Path, docs_dir: Path, env: AndroidEnv + ) -> AndroidActionOutput: + for path in [task_dir, docs_dir]: + path.mkdir(parents=True, exist_ok=True) + resp = await self.run_self_learn(round_count, task_desc, last_act, task_dir, env) + if resp.action_state != RunState.SUCCESS: + return resp + + resp = await self.run_reflect(round_count, task_desc, last_act, task_dir, docs_dir, env) + return resp + + async def run_self_learn( + self, round_count: int, task_desc: str, last_act: str, task_dir: Path, env: AndroidEnv + ) -> AndroidActionOutput: + extra_config = config.extra + screenshot_path: Path = env.observe( + EnvObsParams(obs_type=EnvObsType.GET_SCREENSHOT, ss_name=f"{round_count}_before", local_save_dir=task_dir) + ) + xml_path: Path = env.observe( + EnvObsParams(obs_type=EnvObsType.GET_XML, xml_name=f"{round_count}", local_save_dir=task_dir) + ) + if not screenshot_path.exists() or not xml_path.exists(): + return AndroidActionOutput(action_state=RunState.FAIL) + + elem_list = elem_list_from_xml_tree(xml_path, self.useless_list, extra_config.get("min_dist", 30)) + + 
screenshot_before_labeled_path = task_dir.joinpath(f"{round_count}_before_labeled.png") + draw_bbox_multi(screenshot_path, screenshot_before_labeled_path, elem_list) + img_base64 = encode_image(screenshot_before_labeled_path) + self.screenshot_before_base64 = img_base64 + self.screenshot_before_path = screenshot_before_labeled_path + + self_explore_template = screenshot_parse_self_explore_template + context = self_explore_template.format(task_description=task_desc, last_act=last_act) + + node = await SCREENSHOT_PARSE_NODE.fill(context=context, llm=self.llm, images=[img_base64]) + logger.debug(f"fill result:{node}") + if "error" in node.content: + return AndroidActionOutput(action_state=RunState.FAIL) + prompt = node.compile(context=context, schema="json", mode="auto") + # Modify WindowsPath to Str + OpLogItem(step=round_count, prompt=prompt, image=str(screenshot_before_labeled_path), response=node.content) + op_param = screenshot_parse_extract(node.instruct_content.model_dump(), grid_on=False) + # TODO Modify Op_param. When op_param.action is FINISH, how to solve this ? 
async def run_reflect(
    self, round_count: int, task_desc: str, last_act: str, task_dir: Path, docs_dir: Path, env: AndroidEnv
) -> AndroidActionOutput:
    """Judge the effect of the previously executed action and document the UI element it touched.

    Takes an "after" screenshot, asks the LLM to compare it with the "before" screenshot stored on
    ``self`` by the explore step, and acts on the returned decision: INEFFECTIVE/BACK/CONTINUE mark
    the element as useless, BACK additionally sends a system "back" to the device, and
    BACK/CONTINUE/SUCCESS persist the generated documentation to ``docs_dir/<resource_id>.txt``.

    Args:
        round_count: current round number, used to name screenshots and log items.
        task_desc: description of the overall task, inserted into the prompt.
        last_act: summary of the last action, inserted into the prompt.
        task_dir: directory where screenshots are saved.
        docs_dir: directory holding one documentation file per UI element.
        env: Android environment used to take screenshots and execute actions.

    Returns:
        ``AndroidActionOutput`` with ``data={"last_act": ...}`` on success, or with
        ``action_state`` FAIL/FINISH when any step fails or the parser reports FINISH.
    """
    screenshot_path: Path = env.observe(
        EnvObsParams(obs_type=EnvObsType.GET_SCREENSHOT, ss_name=f"{round_count}_after", local_save_dir=task_dir)
    )
    if not screenshot_path.exists():
        return AndroidActionOutput(action_state=RunState.FAIL)

    screenshot_after_labeled_path = task_dir.joinpath(f"{round_count}_after_labeled.png")
    draw_bbox_multi(screenshot_path, screenshot_after_labeled_path, elem_list=self.elem_list)
    img_base64 = encode_image(screenshot_after_labeled_path)

    # `action` is the verb inserted into the prompt; `doc_key` is the DocContent field the
    # description is stored under. DocContent has no "swipe" field, only v_swipe / h_swipe,
    # so a swipe is mapped to the matching field while the prompt keeps the verb "swiping".
    doc_key = self.act_name
    if self.act_name == ActionOp.TAP.value:
        action = "tapping"
    elif self.act_name == ActionOp.LONG_PRESS.value:
        action = "long pressing"
    elif self.act_name == ActionOp.SWIPE.value:
        action = "swiping"
        if self.swipe_orient in (SwipeOp.UP.value, SwipeOp.DOWN.value):
            doc_key = ActionOp.VERTICAL_SWIPE.value
        elif self.swipe_orient in (SwipeOp.LEFT.value, SwipeOp.RIGHT.value):
            doc_key = ActionOp.HORIZONTAL_SWIPE.value
    else:
        # TODO: decide how unsupported actions should be reflected on; for now the prompt
        # is still built (with "None" as the action) exactly as before.
        logger.warning(f"Current action name parse failed, it's `{self.act_name}`")
        action = None

    context = reflect_template.format(
        action=action, ui_element=str(self.ui_area), task_desc=task_desc, last_act=last_act
    )
    node = await SELF_LEARN_REFLECT_NODE.fill(
        context=context, llm=self.llm, images=[self.screenshot_before_base64, img_base64]
    )

    if "error" in node.content:
        return AndroidActionOutput(action_state=RunState.FAIL)

    prompt = node.compile(context=context, schema="json", mode="auto")
    ReflectLogItem(
        step=round_count,
        prompt=prompt,
        image_before=str(self.screenshot_before_path),
        image_after=str(screenshot_after_labeled_path),
        response=node.content,
    )

    op_param = reflect_parse_extarct(node.instruct_content.model_dump())
    if op_param.param_state == RunState.FINISH:
        return AndroidActionOutput(action_state=RunState.FINISH)
    if op_param.param_state == RunState.FAIL:
        return AndroidActionOutput(action_state=RunState.FAIL)

    logger.info(
        f"reflect_parse_extarct decision: {op_param.decision}, "
        f"elem_list size: {len(self.elem_list)}, ui_area: {self.ui_area}"
    )

    # ui_area is a 1-based label. Reject anything outside the element list instead of raising
    # IndexError (or silently picking a wrong element through a negative index).
    elem_index = int(self.ui_area) - 1
    if not 0 <= elem_index < len(self.elem_list):
        logger.error(f"ui_area {self.ui_area} is out of range for elem_list of size {len(self.elem_list)}")
        return AndroidActionOutput(action_state=RunState.FAIL)
    resource_id = self.elem_list[elem_index].uid

    if op_param.decision == Decision.INEFFECTIVE.value:
        self.useless_list.append(resource_id)
        last_act = "NONE"  # TODO global
    elif op_param.decision in [Decision.BACK.value, Decision.CONTINUE.value, Decision.SUCCESS.value]:
        if op_param.decision in [Decision.BACK.value, Decision.CONTINUE.value]:
            self.useless_list.append(resource_id)
            last_act = "NONE"
            if op_param.decision == Decision.BACK.value:
                action = EnvAction(action_type=EnvActionType.SYSTEM_BACK)
                _, _, _, _, info = env.step(action)
                if info["res"] == ADB_EXEC_FAIL:
                    return AndroidActionOutput(action_state=RunState.FAIL)

        doc = op_param.documentation
        doc_path = docs_dir.joinpath(f"{resource_id}.txt")
        empty_doc = DocContent().model_dump()
        if doc_key not in empty_doc:
            logger.error(f"No documentation field for action `{doc_key}`")
            return AndroidActionOutput(action_state=RunState.FAIL)

        if doc_path.exists():
            try:
                doc_content = ast.literal_eval(doc_path.read_text())
            except Exception as exp:
                logger.error(f"ast parse doc: {doc_path} failed, exp: {exp}")
                return AndroidActionOutput(action_state=RunState.FAIL)
            if not isinstance(doc_content, dict):
                logger.error(f"ast parse doc: {doc_path} failed, exp: not a dict literal")
                return AndroidActionOutput(action_state=RunState.FAIL)

            if doc_content.get(doc_key):
                logger.info(f"Documentation for the element {resource_id} already exists.")
                return AndroidActionOutput(action_state=RunState.FAIL)
        else:
            doc_content = empty_doc

        # The file is read back with ast.literal_eval, so it must hold a plain dict literal.
        doc_content[doc_key] = doc
        doc_path.write_text(str(doc_content))
    return AndroidActionOutput(data={"last_act": last_act})
expected_type=str, instruction="explain why you made this decision", example="BACK" +) + + +THOUGHT = ActionNode(key="Thought", expected_type=str, instruction="explain why you made this decision", example="") + + +DOCUMENTATION = ActionNode( + key="Documentation", expected_type=str, instruction="describe the function of the UI element", example="" +) + + +NODES = [DECISION, THOUGHT, DOCUMENTATION] +SELF_LEARN_REFLECT_NODE = ActionNode.from_children("SelfLearnReflect", NODES) diff --git a/notebook_dir/metagpt_yusin/ext/android_assistant/prompts/__init__.py b/notebook_dir/metagpt_yusin/ext/android_assistant/prompts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bcf8efd09712339308e72659e84450d3fa829fd --- /dev/null +++ b/notebook_dir/metagpt_yusin/ext/android_assistant/prompts/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/notebook_dir/metagpt_yusin/ext/android_assistant/prompts/assistant_prompt.py b/notebook_dir/metagpt_yusin/ext/android_assistant/prompts/assistant_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..34baf58417ca1e0206d2d5730ce29fe11e78530a --- /dev/null +++ b/notebook_dir/metagpt_yusin/ext/android_assistant/prompts/assistant_prompt.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the prompt templates of assistant learning and acting + +screenshot_parse_template = """You are an agent that is trained to perform some basic tasks on a smartphone. You will be given a +smartphone screenshot. The interactive UI elements on the screenshot are labeled with numeric tags starting from 1. The +numeric tag of each interactive element is located in the center of the element. + +You can call the following functions to control the smartphone: + +1. tap(element: int) +This function is used to tap an UI element shown on the smartphone screen. 
+"element" is a numeric tag assigned to an UI element shown on the smartphone screen. +A simple use case can be tap(5), which taps the UI element labeled with the number 5. + +2. text(text_input: str) +This function is used to insert text input in an input field/box. text_input is the string you want to insert and must +be wrapped with double quotation marks. A simple use case can be text("Hello, world!"), which inserts the string +"Hello, world!" into the input area on the smartphone screen. This function is usually callable when you see a keyboard +showing in the lower half of the screen. + +3. long_press(element: int) +This function is used to long press an UI element shown on the smartphone screen. +"element" is a numeric tag assigned to an UI element shown on the smartphone screen. +A simple use case can be long_press(5), which long presses the UI element labeled with the number 5. + +4. swipe(element: int, direction: str, dist: str) +This function is used to swipe an UI element shown on the smartphone screen, usually a scroll view or a slide bar. +"element" is a numeric tag assigned to an UI element shown on the smartphone screen. "direction" is a string that +represents one of the four directions: up, down, left, right. "direction" must be wrapped with double quotation +marks. "dist" determines the distance of the swipe and can be one of the three options: short, medium, long. You should +choose the appropriate distance option according to your need. +A simple use case can be swipe(21, "up", "medium"), which swipes up the UI element labeled with the number 21 for a +medium distance. + +5. grid() +You should call this function when you find the element you want to interact with is not labeled with a numeric tag and +other elements with numeric tags cannot help with the task. 
The function will bring up a grid overlay to divide the +smartphone screen into small areas and this will give you more freedom to choose any part of the screen to tap, long +press, or swipe. +{ui_document} +The task you need to complete is to: {task_description}. Your past actions to proceed with this task are summarized as +follows: {last_act} +Now, given the documentation and the following labeled screenshot, you need to think and call the function needed to +proceed with the task. Your output should include three parts in the given format: + +You can only take one action at a time, so please directly call the function.""" + +screenshot_parse_with_grid_template = """You are an agent that is trained to perform some basic tasks on a smartphone. You will be given +a smartphone screenshot overlaid by a grid. The grid divides the screenshot into small square areas. Each area is +labeled with an integer in the top-left corner. + +You can call the following functions to control the smartphone: + +1. tap(area: int, subarea: str) +This function is used to tap a grid area shown on the smartphone screen. "area" is the integer label assigned to a grid +area shown on the smartphone screen. "subarea" is a string representing the exact location to tap within the grid area. +It can take one of the nine values: center, top-left, top, top-right, left, right, bottom-left, bottom, and +bottom-right. +A simple use case can be tap(5, "center"), which taps the exact center of the grid area labeled with the number 5. + +2. long_press(area: int, subarea: str) +This function is used to long press a grid area shown on the smartphone screen. "area" is the integer label assigned to +a grid area shown on the smartphone screen. "subarea" is a string representing the exact location to long press within +the grid area. It can take one of the nine values: center, top-left, top, top-right, left, right, bottom-left, bottom, +and bottom-right. 
+A simple use case can be long_press(7, "top-left"), which long presses the top left part of the grid area labeled with +the number 7. + +3. swipe(start_area: int, start_subarea: str, end_area: int, end_subarea: str) +This function is used to perform a swipe action on the smartphone screen, especially when you want to interact with a +scroll view or a slide bar. "start_area" is the integer label assigned to the grid area which marks the starting +location of the swipe. "start_subarea" is a string representing the exact location to begin the swipe within the grid +area. "end_area" is the integer label assigned to the grid area which marks the ending location of the swipe. +"end_subarea" is a string representing the exact location to end the swipe within the grid area. +The two subarea parameters can take one of the nine values: center, top-left, top, top-right, left, right, bottom-left, +bottom, and bottom-right. +A simple use case can be swipe(21, "center", 25, "right"), which performs a swipe starting from the center of grid area +21 to the right part of grid area 25. + +The task you need to complete is to: {task_description}. Your past actions to proceed with this task are summarized as +follows: {last_act} +Now, given the following labeled screenshot, you need to think and call the function needed to proceed with the task. +Your output should include three parts in the given format: + +You can only take one action at a time, so please directly call the function.""" + +screenshot_parse_self_explore_template = """You are an agent that is trained to complete certain tasks on a smartphone. You will be +given a screenshot of a smartphone app. The interactive UI elements on the screenshot are labeled with numeric tags +starting from 1. + +You can call the following functions to interact with those labeled elements to control the smartphone: + +1. tap(element: int) +This function is used to tap an UI element shown on the smartphone screen. 
+"element" is a numeric tag assigned to an UI element shown on the smartphone screen. +A simple use case can be tap(5), which taps the UI element labeled with the number 5. + +2. text(text_input: str) +This function is used to insert text input in an input field/box. text_input is the string you want to insert and must +be wrapped with double quotation marks. A simple use case can be text("Hello, world!"), which inserts the string +"Hello, world!" into the input area on the smartphone screen. This function is only callable when you see a keyboard +showing in the lower half of the screen. + +3. long_press(element: int) +This function is used to long press an UI element shown on the smartphone screen. +"element" is a numeric tag assigned to an UI element shown on the smartphone screen. +A simple use case can be long_press(5), which long presses the UI element labeled with the number 5. + +4. swipe(element: int, direction: str, dist: str) +This function is used to swipe an UI element shown on the smartphone screen, usually a scroll view or a slide bar. +"element" is a numeric tag assigned to an UI element shown on the smartphone screen. "direction" is a string that +represents one of the four directions: up, down, left, right. "direction" must be wrapped with double quotation +marks. "dist" determines the distance of the swipe and can be one of the three options: short, medium, long. You should +choose the appropriate distance option according to your need. +A simple use case can be swipe(21, "up", "medium"), which swipes up the UI element labeled with the number 21 for a +medium distance. + +The task you need to complete is to {task_description}. Your past actions to proceed with this task are summarized as +follows: {last_act} +Now, given the following labeled screenshot, you need to think and call the function needed to proceed with the task. 
+Your output should include three parts in the given format: + +You can only take one action at a time, so please directly call the function.""" + +screenshot_parse_self_explore_reflect_template = """I will give you screenshots of a mobile app before and after {action} the UI +element labeled with the number '{ui_element}' on the first screenshot. The numeric tag of each element is located at +the center of the element. The action of {action} this UI element was described as follows: +{last_act} +The action was also an attempt to proceed with a larger task, which is to {task_desc}. Your job is to carefully analyze +the difference between the two screenshots to determine if the action is in accord with the description above and at +the same time effectively moved the task forward. Your output should be determined based on the following situations: +1. BACK +If you think the action navigated you to a page where you cannot proceed with the given task, you should go back to the +previous interface. At the same time, describe the functionality of the UI element concisely in one or two sentences by +observing the difference between the two screenshots. Notice that your description of the UI element should focus on +the general function. Never include the numeric tag of the UI element in your description. You can use pronouns such as +"the UI element" to refer to the element. Your output should be in the following format: +Decision: BACK +Thought: +Documentation: +2. INEFFECTIVE +If you find the action changed nothing on the screen (screenshots before and after the action are identical), you +should continue to interact with other elements on the screen. Notice that if you find the location of the cursor +changed between the two screenshots, then they are not identical. Your output should be in the following format: +Decision: INEFFECTIVE +Thought: +Documentation: +3. 
CONTINUE +If you find the action changed something on the screen but does not reflect the action description above and did not +move the given task forward, you should continue to interact with other elements on the screen. At the same time, +describe the functionality of the UI element concisely in one or two sentences by observing the difference between the +two screenshots. Notice that your description of the UI element should focus on the general function. Never include the +numeric tag of the UI element in your description. You can use pronouns such as "the UI element" to refer to the +element. Your output should be in the following format: +Decision: CONTINUE +Thought: +Documentation: +4. SUCCESS +If you think the action successfully moved the task forward (even though it did not completed the task), you should +describe the functionality of the UI element concisely in one or two sentences. Notice that your description of the UI +element should focus on the general function. Never include the numeric tag of the UI element in your description. You +can use pronouns such as "the UI element" to refer to the element. Your output should be in the following format: +Decision: SUCCESS +Thought: +Documentation: +""" diff --git a/notebook_dir/metagpt_yusin/ext/android_assistant/prompts/operation_prompt.py b/notebook_dir/metagpt_yusin/ext/android_assistant/prompts/operation_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..1bde53f04197b50e75d6caf3ce1847402b4a3a9d --- /dev/null +++ b/notebook_dir/metagpt_yusin/ext/android_assistant/prompts/operation_prompt.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the prompt templates of phone operation + +tap_doc_template = """I will give you the screenshot of a mobile app before and after tapping the UI element labeled +with the number {ui_element} on the screen. The numeric tag of each element is located at the center of the element. 
+Tapping this UI element is a necessary part of proceeding with a larger task, which is to . Your task is to +describe the functionality of the UI element concisely in one or two sentences. Notice that your description of the UI +element should focus on the general function. For example, if the UI element is used to navigate to the chat window +with John, your description should not include the name of the specific person. Just say: "Tapping this area will +navigate the user to the chat window". Never include the numeric tag of the UI element in your description. You can use +pronouns such as "the UI element" to refer to the element.""" + +text_doc_template = """I will give you the screenshot of a mobile app before and after typing in the input area labeled +with the number {ui_element} on the screen. The numeric tag of each element is located at the center of the element. +Typing in this UI element is a necessary part of proceeding with a larger task, which is to . Your task is +to describe the functionality of the UI element concisely in one or two sentences. Notice that your description of the +UI element should focus on the general function. For example, if the change of the screenshot shows that the user typed +"How are you?" in the chat box, you do not need to mention the actual text. Just say: "This input area is used for the +user to type a message to send to the chat window.". Never include the numeric tag of the UI element in your +description. You can use pronouns such as "the UI element" to refer to the element.""" + +long_press_doc_template = """I will give you the screenshot of a mobile app before and after long pressing the UI +element labeled with the number {ui_element} on the screen. The numeric tag of each element is located at the center of +the element. Long pressing this UI element is a necessary part of proceeding with a larger task, which is to +. Your task is to describe the functionality of the UI element concisely in one or two sentences. 
Notice +that your description of the UI element should focus on the general function. For example, if long pressing the UI +element redirects the user to the chat window with John, your description should not include the name of the specific +person. Just say: "Long pressing this area will redirect the user to the chat window". Never include the numeric tag of +the UI element in your description. You can use pronouns such as "the UI element" to refer to the element.""" + +swipe_doc_template = """I will give you the screenshot of a mobile app before and after swiping the UI +element labeled with the number {ui_element} on the screen. The numeric tag of each element is located at the center of +the element. Swiping this UI element is a necessary part of proceeding with a larger task, which is to . +Your task is to describe the functionality of the UI element concisely in one or two sentences. Notice that your +description of the UI element should be as general as possible. For example, if swiping the UI element increases the +contrast ratio of an image of a building, your description should be just like this: "Swiping this area enables the +user to tune a specific parameter of the image". Never include the numeric tag of the UI element in your description. +You can use pronouns such as "the UI element" to refer to the element.""" + +refine_doc_suffix = """\nA documentation of this UI element generated from previous demos is shown below. Your +generated description should be based on this previous doc and optimize it. Notice that it is possible that your +understanding of the function of the UI element derived from the given screenshots conflicts with the previous doc, +because the function of a UI element can be flexible. In this case, your generated description should combine both. 
class AndroidAssistant(Role):
    """Role that learns how Android apps work and operates them.

    The stage ("learn" or "act") and mode ("manual" or "auto") are read from ``config.extra``
    and determine which actions are run and where task output and element docs are stored.
    """

    name: str = "Nick"
    profile: str = "AndroidAssistant"
    goal: str = "operate the mobile phone's apps with self-learn"

    task_desc: str = ""
    round_count: int = 0
    last_act: str = "None"
    output_root_dir: Optional[Path] = Field(default=None)
    task_dir: Optional[Path] = Field(default=None)
    docs_dir: Optional[Path] = Field(default=None)
    grid_on: bool = Field(default=False)

    def __init__(self, **data):
        """Select actions and output directories from ``config.extra``.

        Raises:
            ValueError: if the configured stage/mode combination is not supported.
        """
        super().__init__(**data)

        self._watch([UserRequirement, AndroidActionOutput])
        extra_config = config.extra
        self.task_desc = extra_config.get("task_desc", "Just explore any app in this phone!")
        app_name = extra_config.get("app_name", "demo")
        # output_root_dir defaults to None; only then fall back to the examples directory.
        # (A Path object is always truthy, so an `or` fallback would never be taken.)
        if self.output_root_dir is not None:
            data_dir = self.output_root_dir.absolute().joinpath("output")
        else:
            data_dir = EXAMPLE_PATH.joinpath("android_assistant/output")
        cur_datetime = datetime.fromtimestamp(int(time.time())).strftime("%Y-%m-%d_%H-%M-%S")

        # Firstly, we decide the state with user config, further, we can do it automatically, like if it's new app,
        # run the learn first and then do the act stage or learn it during the action.
        stage = extra_config.get("stage")
        mode = extra_config.get("mode")
        if stage == "learn" and mode == "manual":
            # choose ManualRecord and then run ParseRecord
            # Remember, only run each action only one time, no need to run n_round.
            self.set_actions([ManualRecord, ParseRecord])
            self.task_dir = data_dir.joinpath(app_name, f"manual_learn_{cur_datetime}")
            self.docs_dir = data_dir.joinpath(app_name, "manual_docs")
        elif stage == "learn" and mode == "auto":
            # choose SelfLearnAndReflect to run
            self.set_actions([SelfLearnAndReflect])
            self.task_dir = data_dir.joinpath(app_name, f"auto_learn_{cur_datetime}")
            self.docs_dir = data_dir.joinpath(app_name, "auto_docs")
        elif stage == "act":
            # choose ScreenshotParse to run
            self.set_actions([ScreenshotParse])
            self.task_dir = data_dir.joinpath(app_name, f"act_{cur_datetime}")
            if mode == "manual":
                self.docs_dir = data_dir.joinpath(app_name, "manual_docs")
            else:
                self.docs_dir = data_dir.joinpath(app_name, "auto_docs")
        else:
            raise ValueError(f"invalid stage: {stage}, mode: {mode}")

        self._check_dir()

        self._set_react_mode(RoleReactMode.BY_ORDER)

    def _check_dir(self):
        """Create the task and docs directories if they do not exist yet."""
        self.task_dir.mkdir(parents=True, exist_ok=True)
        self.docs_dir.mkdir(parents=True, exist_ok=True)

    async def react(self) -> Message:
        """Advance the round counter and run one react cycle of the base role."""
        self.round_count += 1
        result = await super().react()
        logger.debug(f"react result {result}")
        return result

    async def _observe(self, ignore_memory=True) -> int:
        """ignore old memory to make it run multi rounds inside a role"""
        newest_msgs = self.rc.memory.get(k=1)
        newest_msg = newest_msgs[0] if newest_msgs else None
        if newest_msg and (RunState.SUCCESS.value.upper() not in newest_msg.content):
            # The latest round did not succeed: stop ignoring memory for the base-class observe.
            ignore_memory = False
            state_val = newest_msg.content.split(".")[-1]  # RoundCount: 1, action_state: RunState.SUCCESS
            logger.warning(f"Latest action_state is {state_val}, will run in the remainder rounds without `react`")
        return await super()._observe(ignore_memory)

    async def _act(self) -> Message:
        """Run the current todo action and record its state as a message in memory.

        Raises:
            ValueError: if the todo is not one of the actions configured in ``__init__``.
        """
        logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})")
        todo = self.rc.todo
        if isinstance(todo, ManualRecord):
            resp = await todo.run(task_dir=self.task_dir, task_desc=self.task_desc, env=self.rc.env)
        elif isinstance(todo, ParseRecord):
            resp = await todo.run(
                task_dir=self.task_dir,
                docs_dir=self.docs_dir,
            )
        elif isinstance(todo, SelfLearnAndReflect):
            resp = await todo.run(
                round_count=self.round_count,
                task_desc=self.task_desc,
                last_act=self.last_act,
                task_dir=self.task_dir,
                docs_dir=self.docs_dir,
                env=self.rc.env,
            )
            if resp.action_state == RunState.SUCCESS:
                # Same fallback as the ScreenshotParse branch so last_act never becomes None.
                self.last_act = resp.data.get("last_act", "None")
        elif isinstance(todo, ScreenshotParse):
            resp = await todo.run(
                round_count=self.round_count,
                task_desc=self.task_desc,
                last_act=self.last_act,
                task_dir=self.task_dir,
                docs_dir=self.docs_dir,
                grid_on=self.grid_on,
                env=self.rc.env,
            )
            if resp.action_state == RunState.SUCCESS:
                logger.info(f"grid_on: {resp.data.get('grid_on')}")
                self.grid_on = resp.data.get("grid_on", False)
                self.last_act = resp.data.get("last_act", "None")
        else:
            # Without this branch `resp` would be unbound below.
            raise ValueError(f"unsupported action: {todo}")
        msg = Message(
            content=f"RoundCount: {self.round_count}, action_state: {resp.action_state}",
            role=self.profile,
            cause_by=type(resp),
            send_from=self.name,
            send_to=self.name,
        )

        self.rc.memory.add(msg)
        return msg
class AndroidElement(BaseModel):
    """UI Element"""

    # identifier of the element
    uid: str = Field(default="")
    # ((x1, y1), (x2, y2)) corner pairs; the default matches the annotated shape
    # (the previous default `{}` was a dict and broke bbox[0][0]-style indexing).
    bbox: tuple[tuple[int, int], tuple[int, int]] = Field(default=((0, 0), (0, 0)))
    # name of the XML attribute the element was selected by
    attrib: str = Field(default="")
class AndroidActionOutput(BaseModel):
    """Result returned by an Android assistant action."""

    # free-form payload of the action; default_factory builds a fresh dict per instance
    data: dict = Field(default_factory=dict)
    # outcome of the action; SUCCESS unless the action reports otherwise
    action_state: RunState = Field(default=RunState.SUCCESS)
def get_id_from_element(elem: Element) -> str:
    """Derive a readable identifier for a UI-hierarchy XML node.

    The identifier is the node's ``resource-id`` (with ":" turned into "."
    and "/" into "_") when that attribute is non-empty; otherwise it is
    ``<class>_<width>_<height>`` computed from the ``bounds`` attribute.
    A non-empty ``content-desc`` shorter than 20 characters is appended as a
    suffix after sanitising "/", " " and ":".

    Args:
        elem: XML element carrying at least a ``bounds`` attribute of the
            form ``[x1,y1][x2,y2]``.

    Returns:
        The identifier string.
    """
    attrs = elem.attrib

    # Bounds are parsed unconditionally, so a missing or malformed value
    # fails here even when the resource-id branch would not need the size.
    corners = attrs["bounds"][1:-1].split("][")
    left, top = (int(v) for v in corners[0].split(","))
    right, bottom = (int(v) for v in corners[1].split(","))

    resource_id = attrs.get("resource-id")
    if resource_id:
        ident = resource_id.replace(":", ".").replace("/", "_")
    else:
        ident = f"{attrs['class']}_{right - left}_{bottom - top}"

    description = attrs.get("content-desc")
    if description and len(description) < 20:
        suffix = description.replace("/", "_").replace(" ", "").replace(":", "_")
        ident += f"_{suffix}"
    return ident
def elem_list_from_xml_tree(xml_path: Path, useless_list: list[str], min_dist: int) -> list[AndroidElement]:
    """Collect the interactive elements of a UI-hierarchy XML dump.

    Clickable elements are gathered first, then focusable ones. Elements
    whose ``uid`` is listed in ``useless_list`` are dropped. A focusable
    element is also dropped when its centre lies within ``min_dist`` pixels
    of the centre of any clickable element (including clickable elements
    that were themselves filtered out as useless).

    Args:
        xml_path: Path of the XML dump to parse.
        useless_list: ``uid`` values to exclude from the result.
        min_dist: Centre-to-centre distance at or below which a focusable
            element counts as a duplicate of a clickable one.

    Returns:
        Kept clickable elements followed by kept focusable elements.
    """

    def midpoint(box):
        # box is ((x1, y1), (x2, y2)); integer midpoint of the rectangle.
        return (box[0][0] + box[1][0]) // 2, (box[0][1] + box[1][1]) // 2

    clickables = []
    focusables = []
    traverse_xml_tree(xml_path, clickables, "clickable", True)
    traverse_xml_tree(xml_path, focusables, "focusable", True)

    kept = [item for item in clickables if item.uid not in useless_list]
    for candidate in focusables:
        if candidate.uid in useless_list:
            continue
        cx, cy = midpoint(candidate.bbox)
        shadowed = any(
            (abs(cx - ox) ** 2 + abs(cy - oy) ** 2) ** 0.5 <= min_dist
            for ox, oy in (midpoint(other.bbox) for other in clickables)
        )
        if not shadowed:
            kept.append(candidate)
    return kept
def area_to_xy(area: int, subarea: str, width: int, height: int, rows: int, cols: int) -> tuple[int, int]:
    """Convert a 1-based grid cell number and a sub-area name to pixel coordinates.

    The screen is split into ``rows`` x ``cols`` cells numbered row by row
    starting at 1. Within the chosen cell the point sits at one quarter, one
    half or three quarters of the cell width/height depending on ``subarea``.

    Args:
        area: 1-based cell label.
        subarea: One of "top-left", "top", "top-right", "left", "right",
            "bottom-left", "bottom", "bottom-right"; any other value selects
            the cell centre.
        width: Screen width in pixels.
        height: Screen height in pixels.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        ``(x, y)`` pixel coordinates.
    """
    quarter, half, three_quarters = (1, 4), (1, 2), (3, 4)
    # subarea -> ((x numerator, x denominator), (y numerator, y denominator))
    anchors = {
        "top-left": (quarter, quarter),
        "top": (half, quarter),
        "top-right": (three_quarters, quarter),
        "left": (quarter, half),
        "right": (three_quarters, half),
        "bottom-left": (quarter, three_quarters),
        "bottom": (half, three_quarters),
        "bottom-right": (three_quarters, three_quarters),
    }

    cell_w, cell_h = width // cols, height // rows
    row, col = divmod(area - 1, cols)
    (x_num, x_den), (y_num, y_den) = anchors.get(subarea, (half, half))

    x = col * cell_w + cell_w * x_num // x_den
    y = row * cell_h + cell_h * y_num // y_den
    return x, y
def screenshot_parse_extract_with_grid(act_name: str, act: str, last_act: str) -> Union[BaseGridOpParam, GridOpParam]:
    """Turn a grid-mode action string into its typed op-param object.

    Args:
        act_name: Action name (the caller passes the text of ``act`` before
            the first "(").
        act: Full action call text, e.g. ``tap(5, "top-left")``.
        last_act: Summary of the previous step; stored on tap, long_press and
            swipe ops.

    Returns:
        ``TapGridOpParam`` / ``LongPressGridOpParam`` / ``SwipeGridOpParam``
        for the matching action, ``GridOpParam`` for the grid action, or a
        ``BaseGridOpParam`` with ``param_state=RunState.FAIL`` when
        ``act_name`` is not recognised.

    Raises:
        IndexError: if ``act`` lacks the expected ``name(...)`` call or has
            fewer arguments than the action requires.
        ValueError: if an unquoted argument is not an integer (raised by
            ``op_params_clean``).
    """

    def call_args(pattern: str) -> list:
        # First "name(...)" occurrence, split on commas, quotes removed / ints parsed.
        return op_params_clean(re.findall(pattern, act)[0].split(","))

    if act_name == ActionOp.TAP.value:
        params = call_args(r"tap\((.*?)\)")
        op = TapGridOpParam(act_name=act_name, area=params[0], subarea=params[1], last_act=last_act)
    elif act_name == ActionOp.LONG_PRESS.value:
        params = call_args(r"long_press\((.*?)\)")
        op = LongPressGridOpParam(act_name=act_name, area=params[0], subarea=params[1], last_act=last_act)
    elif act_name == ActionOp.SWIPE.value:
        params = call_args(r"swipe\((.*?)\)")
        op = SwipeGridOpParam(
            act_name=act_name,
            start_area=params[0],
            start_subarea=params[1],
            end_area=params[2],
            end_subarea=params[3],
            # Previously omitted here, so swipe ops lost the step summary that
            # the tap and long_press branches record.
            last_act=last_act,
        )
    elif act_name == ActionOp.GRID.value:
        op = GridOpParam(act_name=act_name)
    else:
        op = BaseGridOpParam(param_state=RunState.FAIL)
    return op
+ +### Backend service startup +The execution entry is `python3 run_st_game.py "Host a open lunch party at 13:00 pm" "base_the_ville_isabella_maria_klaus" "test_sim" 10` +or +`python3 run_st_game.py "Host a open lunch party at 13:00 pm" "base_the_ville_isabella_maria_klaus" "test_sim" 10 --temp_storage_path path/to/ga/temp_storage` + +`idea` is the user's instruction to the first agent; it then spreads from agent to agent, which lets you check whether the agents end up hosting or joining the event. + +### Frontend service startup +Enter the `generative_agents` project folder. + +Enter `environment/frontend_server` and use `python3 manage.py runserver` to start the front-end service. +Visit `http://localhost:8000/simulator_home` to enter the current simulation interface. + +## Acknowledgements +This reproduction is based on [generative_agents](https://github.com/joonspk-research/generative_agents); many thanks to its authors. + +### Citation +```bib +@inproceedings{Park2023GenerativeAgents, +author = {Park, Joon Sung and O'Brien, Joseph C. and Cai, Carrie J. 
and Morris, Meredith Ringel and Liang, Percy and Bernstein, Michael S.}, +title = {Generative Agents: Interactive Simulacra of Human Behavior}, +year = {2023}, +publisher = {Association for Computing Machinery}, +address = {New York, NY, USA}, +booktitle = {In the 36th Annual ACM Symposium on User Interface Software and Technology (UIST '23)}, +keywords = {Human-AI interaction, agents, generative AI, large language models}, +location = {San Francisco, CA, USA}, +series = {UIST '23} +} +``` \ No newline at end of file diff --git a/notebook_dir/metagpt_yusin/ext/stanford_town/README_CN.md b/notebook_dir/metagpt_yusin/ext/stanford_town/README_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..3daf68d08f4494a1137cf3ff4a981c85ed41f4cd --- /dev/null +++ b/notebook_dir/metagpt_yusin/ext/stanford_town/README_CN.md @@ -0,0 +1,50 @@ +## Stanford Town Game + +### 前置 +为了方便GA( [generative_agents](https://github.com/joonspk-research/generative_agents) )的前端对接数据(避免改动它那块的代码),可在启动`run_st_game.py`加上`temp_storage_path`指向`generative_agents`对应的`temp_storage`路径。比如 + +`python3 run_st_game.py --temp_storage_path path/to/ga/temp_storage xxx` + +或将`const.py`下的 + +``` +STORAGE_PATH = EXAMPLE_PATH.joinpath("storage") +TEMP_STORAGE_PATH = EXAMPLE_PATH.joinpath("temp_storage") +# 更新为 +STORAGE_PATH = Path("{path/to/ga/storage}") +TEMP_STORAGE_PATH = Path("{path/to/ga/temp_storage}") +``` +这样可用实现不改变GA代码情况下,实现仿真数据的对接。不然得修改GA的代码来适配MG的输出路径。 + +如果你不想从0开始启动,拷贝`generative_agents/environment/frontend_server/storage/`下的其他仿真目录到`examples/stanford_town/storage`,并选择一个目录名作为`fork_sim_code`。 + +### 后端服务启动 +执行入口为:`python3 run_st_game.py "Host a open lunch party at 13:00 pm" "base_the_ville_isabella_maria_klaus" "test_sim" 10` +或者 +`python3 run_st_game.py "Host a open lunch party at 13:00 pm" "base_the_ville_isabella_maria_klaus" "test_sim" 10 --temp_storage_path path/to/ga/temp_storage` + +`idea`为用户给第一个Agent的用户心声,并通过这个心声进行传播,看最后多智能体是否达到举办、参加活动的目标。 + +### 前端服务启动 +进入`generative_agents`项目目录 + 
+进入`environment/frontend_server`,使用`python3 manage.py runserver`启动前端服务。 +访问`http://localhost:8000/simulator_home` 进入当前的仿真界面。 + +## 致谢 +复现工作参考了 [generative_agents](https://github.com/joonspk-research/generative_agents), 感谢相关作者们。 + +### 引用 +```bib +@inproceedings{Park2023GenerativeAgents, +author = {Park, Joon Sung and O'Brien, Joseph C. and Cai, Carrie J. and Morris, Meredith Ringel and Liang, Percy and Bernstein, Michael S.}, +title = {Generative Agents: Interactive Simulacra of Human Behavior}, +year = {2023}, +publisher = {Association for Computing Machinery}, +address = {New York, NY, USA}, +booktitle = {In the 36th Annual ACM Symposium on User Interface Software and Technology (UIST '23)}, +keywords = {Human-AI interaction, agents, generative AI, large language models}, +location = {San Francisco, CA, USA}, +series = {UIST '23} +} +``` diff --git a/notebook_dir/metagpt_yusin/ext/stanford_town/__init__.py b/notebook_dir/metagpt_yusin/ext/stanford_town/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..56ea35c9f719f30ad6e8b0accf7f4480cefc98bb --- /dev/null +++ b/notebook_dir/metagpt_yusin/ext/stanford_town/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : stanford town implement diff --git a/notebook_dir/metagpt_yusin/ext/stanford_town/actions/__init__.py b/notebook_dir/metagpt_yusin/ext/stanford_town/actions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bcf8efd09712339308e72659e84450d3fa829fd --- /dev/null +++ b/notebook_dir/metagpt_yusin/ext/stanford_town/actions/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/notebook_dir/metagpt_yusin/ext/stanford_town/actions/agent_chat_sum_rel.py b/notebook_dir/metagpt_yusin/ext/stanford_town/actions/agent_chat_sum_rel.py new file mode 100644 index 0000000000000000000000000000000000000000..98d370bb075b6933a7b53c964b253a2d311c2f97 --- /dev/null +++ 
b/notebook_dir/metagpt_yusin/ext/stanford_town/actions/agent_chat_sum_rel.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : summarize relationship in a agent chat + +from metagpt.ext.stanford_town.actions.st_action import STAction +from metagpt.logs import logger + + +class AgentChatSumRel(STAction): + name: str = "AgentChatSumRel" + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + resp = False + try: + _ = llm_resp.split('"')[0].strip() + resp = True + except Exception: + pass + return resp + + def _func_cleanup(self, llm_resp: str, prompt: str) -> str: + return llm_resp.split('"')[0].strip() + + def _func_fail_default_resp(self) -> str: + pass + + async def run(self, init_role: "STRole", target_role: "STRole", statements: str) -> str: + def create_prompt_input(init_role: "STRole", target_role: "STRole", statements: str) -> str: + prompt_input = [statements, init_role.name, target_role.name] + return prompt_input + + prompt_input = create_prompt_input(init_role, target_role, statements) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "summarize_chat_relationship_v2.txt") + + example_output = "Jane Doe is working on a project" + special_instruction = "The output should be a string that responds to the question." 
class DecideToTalk(STAction):
    """Ask the LLM whether one role should start a conversation with another."""

    name: str = "DecideToTalk"

    def _func_validate(self, llm_resp: str, prompt: str) -> bool:
        """Return True when the cleaned response is exactly "yes" or "no"."""
        return self._func_cleanup(llm_resp, prompt) in ("yes", "no")

    def _func_cleanup(self, llm_resp: str, prompt: str) -> str:
        """Return the lower-cased text that follows the last "Answer in yes or no:" cue."""
        return llm_resp.split("Answer in yes or no:")[-1].strip().lower()

    def _func_fail_default_resp(self) -> str:
        """Fallback answer stored in ``fail_default_resp`` before the LLM call."""
        return "yes"

    async def run(self, init_role: "STRole", target_role: "STRole", retrieved: dict, *args, **kwargs) -> bool:
        """Decide whether ``init_role`` should talk to ``target_role``.

        Args:
            init_role: Role considering whether to start the conversation.
            target_role: Role that would be addressed.
            retrieved: Dict with "events" and "thoughts" entries; each item
                exposes a ``description`` string.

        Returns:
            True if the LLM output equals "yes", otherwise False.
        """

        def inner_description(desc: str) -> str:
            # "outer (inner)" -> "inner": text after the last "(" minus the final character.
            return desc.split("(")[-1][:-1] if "(" in desc else desc

        def status_sentence(role_name: str, role_scratch) -> str:
            # One sentence describing what a role is doing, from that role's own scratch.
            act_desc = inner_description(role_scratch.act_description)
            if "waiting" in act_desc:
                return f"{role_name} is {act_desc}"
            if len(role_scratch.planned_path) == 0:
                return f"{role_name} is already {act_desc}"
            return f"{role_name} is on the way to {act_desc}"

        def create_prompt_input(init_role: "STRole", target_role: "STRole", retrieved: dict) -> list[str]:
            scratch = init_role.rc.scratch
            target_scratch = target_role.rc.scratch

            last_chat = init_role.rc.memory.get_last_chat(target_role.name)
            last_chatted_time = ""
            last_chat_about = ""
            if last_chat:
                last_chatted_time = last_chat.created.strftime("%B %d, %Y, %H:%M:%S")
                last_chat_about = last_chat.description

            event_sentences = []
            for c_node in retrieved["events"]:
                words = c_node.description.split(" ")
                words[2:3] = ["was"]  # put the third word into past tense
                event_sentences.append(" ".join(words) + ". ")
            thought_sentences = [f"{c_node.description}. " for c_node in retrieved["thoughts"]]
            context = "".join(event_sentences) + "\n" + "".join(thought_sentences)

            curr_time = scratch.curr_time.strftime("%B %d, %Y, %H:%M:%S %p")

            init_p_desc = status_sentence(init_role.name, scratch)
            # The target sentence used to be built from the initiator's scratch
            # (and, in the "waiting" case, the initiator's name), so the prompt
            # described the initiator twice. It now uses the target's own state.
            target_p_desc = status_sentence(target_role.name, target_scratch)

            return [
                context,
                curr_time,
                init_role.name,
                target_role.name,
                last_chatted_time,
                last_chat_about,
                init_p_desc,
                target_p_desc,
                init_role.name,
                target_role.name,
            ]

        prompt_input = create_prompt_input(init_role, target_role, retrieved)
        prompt = self.generate_prompt_with_tmpl_filename(
            prompt_input=prompt_input, tmpl_filename="decide_to_talk_v2.txt"
        )
        self.fail_default_resp = self._func_fail_default_resp()
        output = await self._run_gpt35_max_tokens(prompt, max_tokens=20)  # yes or no
        result = output == "yes"
        logger.info(f"Role: {init_role.name} Action: {self.cls_name} output: {result}")
        return result
class GenActionSector(STAction):
    """Ask the LLM which accessible sector a role should go to for an action."""

    name: str = "GenActionSector"

    def _func_cleanup(self, llm_resp: str, prompt: str):
        """Keep only the text before the first "}"."""
        return llm_resp.split("}")[0]

    def _func_validate(self, llm_resp: str, prompt: str):
        """Accept a non-blank response that contains "}" and no ","."""
        return bool(llm_resp.strip()) and "}" in llm_resp and "," not in llm_resp

    def _func_fail_default_resp(self):
        """Fallback sector stored in ``fail_default_resp`` before the LLM call."""
        return "kitchen"

    async def run(self, role: "STRole", access_tile: dict[str, str], act_desp: str):
        """Choose the sector for ``act_desp``.

        Args:
            role: Acting role; its ``s_mem`` and ``scratch`` are consulted.
            access_tile: Tile info; only the "world" entry is used.
            act_desp: Action description, optionally "outer (inner)".

        Returns:
            The LLM's sector if it is one of the world's accessible sectors,
            otherwise the sector part of the role's ``living_area``.
        """

        def create_prompt_input(role, access_tile: dict[str, str], act_desp):
            # An earlier version also assembled living-area / current-sector
            # details here and then discarded them by resetting the list; that
            # dead work has been removed. NOTE(review): this assumes the
            # skipped s_mem / scratch getters have no side effects - confirm.
            act_world = access_tile["world"]
            sectors = role.s_mem.get_str_accessible_sectors(act_world).split(", ")
            # Other people's houses are hidden; only houses containing the
            # role's own last name stay in the list.
            visible = [s for s in sectors if "'s house" not in s or role.scratch.last_name in s]
            accessible_sector_str = ", ".join(visible)

            act_desp_1 = act_desp
            act_desp_2 = act_desp
            if "(" in act_desp:
                act_desp_1 = act_desp.split("(")[0].strip()
                act_desp_2 = act_desp.split("(")[-1][:-1]

            return [
                accessible_sector_str,
                role.scratch.get_str_name(),
                act_desp_1,
                act_desp_2,
                role.scratch.get_str_name(),
            ]

        prompt_template = "action_location_sector_v1.txt"
        prompt_input = create_prompt_input(role, access_tile, act_desp)
        prompt = self.generate_prompt_with_tmpl_filename(prompt_input, prompt_template)

        self.fail_default_resp = self._func_fail_default_resp()
        output = await self._run_gpt35_max_tokens(prompt, max_tokens=15)
        world = f"{access_tile['world']}"
        known_sectors = [s.strip() for s in role.s_mem.get_str_accessible_sectors(world).split(",")]
        if output not in known_sectors:
            # Unknown sector: fall back to the sector the role lives in.
            output = role.scratch.living_area.split(":")[1]
        logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}")
        return output
class GenActionObject(STAction):
    """Ask the LLM which game object in an arena a role should use for an action."""

    name: str = "GenActionObject"

    def _func_validate(self, llm_resp: str, prompt: str):
        """Accept any response that is not blank."""
        return len(llm_resp.strip()) >= 1

    def _func_cleanup(self, llm_resp: str, prompt: str):
        """Trim surrounding whitespace."""
        return llm_resp.strip()

    def _func_fail_default_resp(self):
        """Fallback object stored in ``fail_default_resp`` before the LLM call."""
        return "bed"

    async def run(self, role: "STRole", act_desp: str, temp_address: str):
        """Choose a game object at ``temp_address`` for ``act_desp``.

        Returns the LLM's answer when it names one of the accessible objects,
        otherwise a randomly chosen accessible object.
        """
        # Only the parenthesised detail of "outer (inner)" goes into the prompt.
        detail = act_desp.split("(")[-1][:-1] if "(" in act_desp else act_desp
        prompt_input = [detail, role.s_mem.get_str_accessible_arena_game_objects(temp_address)]

        prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "action_object_v2.txt")
        self.fail_default_resp = self._func_fail_default_resp()
        output = await self._run_gpt35_max_tokens(prompt, max_tokens=15)

        listing = role.s_mem.get_str_accessible_arena_game_objects(temp_address)
        candidates = [part.strip() for part in listing.split(",")]
        if output not in candidates:
            output = random.choice(candidates)
        logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}")
        return output
class GenEventTriple(STAction):
    """Ask the LLM for the (subject, predicate, object) triple of a role's action."""

    name: str = "GenEventTriple"

    def _func_cleanup(self, llm_resp: str, prompt: str):
        """Split the text before the first ")" on commas and trim each part."""
        head = llm_resp.strip().split(")")[0]
        return [part.strip() for part in head.split(",")]

    def _func_validate(self, llm_resp: str, prompt: str):
        """Accept a response that cleans up to exactly two parts (predicate, object)."""
        try:
            return len(self._func_cleanup(llm_resp, prompt="")) == 2
        except Exception:
            return False

    def _func_fail_default_resp(self, role):
        """Complete fallback triple ``(role.name, "is", "idle")``."""
        return (role.name, "is", "idle")

    async def run(self, role: "STRole", act_desp: str):
        """Return ``(role.name, predicate, object)`` describing ``act_desp``."""
        if "(" in act_desp:
            # Use only the parenthesised detail of "outer (inner)".
            act_desp = act_desp.split("(")[-1].split(")")[0]
        prompt_input = [role.name, act_desp, role.name]

        prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "generate_event_triple_v1.txt")
        self.fail_default_resp = self._func_fail_default_resp(role)
        output = await self._run_gpt35_max_tokens(prompt, max_tokens=30)

        if isinstance(output, (list, tuple)) and len(output) == 3:
            # A 3-item result is already a full triple - presumably the
            # fail-safe default handed back by the LLM helper (TODO confirm in
            # STAction). Re-wrapping it would yield (name, name, "is").
            output = tuple(output)
        else:
            output = (role.name, output[0], output[1])
        logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}")
        return output
class GenObjEventTriple(STAction):
    """Produce a (subject, predicate, object) triple describing a game object's state.

    The prompt is rendered from ``generate_event_triple_v1.txt``. The LLM reply
    is cut at the first ")" and split on ","; exactly two parts are accepted and
    the object's name is prepended as the subject.
    """

    name: str = "GenObjEventTriple"

    def _func_cleanup(self, llm_resp: str, prompt: str):
        """Return the comma-separated parts found before the first ")"."""
        cr = llm_resp.strip()
        cr = [i.strip() for i in cr.split(")")[0].split(",")]
        return cr

    def _func_validate(self, llm_resp: str, prompt: str):
        """Accept the reply only when cleanup yields exactly two parts."""
        try:
            llm_resp = self._func_cleanup(llm_resp, prompt="")
            if len(llm_resp) != 2:
                return False
        except Exception:
            return False
        return True

    def _func_fail_default_resp(self, act_game_object: str):
        """Return the complete fallback triple ``(act_game_object, "is", "idle")``."""
        fs = (act_game_object, "is", "idle")
        return fs

    async def run(self, role: "STRole", act_game_object, act_obj_desp):
        """Query the LLM and return the event triple for the object.

        Args:
            role: the acting role (used for logging only).
            act_game_object: name of the object; becomes the triple's subject.
            act_obj_desp: description of what is happening to the object.

        Returns:
            A 3-tuple ``(act_game_object, predicate, object)``; the fallback
            ``(act_game_object, "is", "idle")`` when no valid reply was obtained.
        """

        def create_prompt_input(act_game_object, act_obj_desp):
            prompt_input = [act_game_object, act_obj_desp, act_game_object]
            return prompt_input

        prompt_template = "generate_event_triple_v1.txt"
        prompt_input = create_prompt_input(act_game_object, act_obj_desp)
        prompt = self.generate_prompt_with_tmpl_filename(prompt_input, prompt_template)
        self.fail_default_resp = self._func_fail_default_resp(act_game_object)
        output = await self._run_gpt35_max_tokens(prompt, max_tokens=30)
        if len(output) == 2:
            # Validated replies are always [predicate, object]; add the subject.
            output = (act_game_object, output[0], output[1])
        else:
            # The fallback is already a full triple. Re-wrapping it used to
            # produce (obj, obj, "is"), so pass it through unchanged.
            output = tuple(output)
        logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}")
        return output
obs_params=EnvObsParams(obs_type=EnvObsType.GET_TITLE, coord=role.scratch.curr_tile) + ) + act_world = access_tile["world"] + act_sector = await GenActionSector().run(role, access_tile, act_desp) + act_arena = await GenActionArena().run(role, act_desp, act_world, act_sector) + act_address = f"{act_world}:{act_sector}:{act_arena}" + if not role.s_mem.get_str_accessible_arena_game_objects(act_address): + act_game_object = "" + else: + act_game_object = await GenActionObject().run(role, act_desp, act_address) + new_address = f"{act_world}:{act_sector}:{act_arena}:{act_game_object}" + act_pron = await GenPronunciatio().run(role, act_desp) + act_event = await GenEventTriple().run(role, act_desp) + # Persona's actions also influence the object states. We set those up here. + act_obj_desp = await GenActObjDescription().run(role, act_game_object, act_desp) + act_obj_pron = await GenPronunciatio().run(role, act_obj_desp) + act_obj_event = await GenObjEventTriple().run(role, act_game_object, act_obj_desp) + result_dict = { + "action_address": new_address, + "action_duration": int(act_dura), + "action_description": act_desp, + "action_pronunciatio": act_pron, + "action_event": act_event, + "chatting_with": None, + "chat": None, + "chatting_with_buffer": None, + "chatting_end_time": None, + "act_obj_description": act_obj_desp, + "act_obj_pronunciatio": act_obj_pron, + "act_obj_event": act_obj_event, + } + logger.info(f"Role: {role.name} Action: {self.cls_name} output: {result_dict}") + return result_dict diff --git a/notebook_dir/metagpt_yusin/ext/stanford_town/actions/gen_daily_schedule.py b/notebook_dir/metagpt_yusin/ext/stanford_town/actions/gen_daily_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..5dffa8995260467c76f9e9810eefd748f960d334 --- /dev/null +++ b/notebook_dir/metagpt_yusin/ext/stanford_town/actions/gen_daily_schedule.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : gen_daily_schedule + + +from 
metagpt.ext.stanford_town.actions.st_action import STAction +from metagpt.logs import logger + + +class GenDailySchedule(STAction): + name: str = "GenDailySchedule" + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + try: + self._func_cleanup(llm_resp, prompt="") + except Exception: + return False + return True + + def _func_cleanup(self, llm_resp: str, prompt: str) -> list: + cr = [] + _cr = llm_resp.split(")") + for i in _cr: + if i[-1].isdigit(): + i = i[:-1].strip() + if i[-1] == "." or i[-1] == ",": + cr += [i[:-1].strip()] + return cr + + def _func_fail_default_resp(self) -> int: + fs = [ + "wake up and complete the morning routine at 6:00 am", + "eat breakfast at 7:00 am", + "read a book from 8:00 am to 12:00 pm", + "have lunch at 12:00 pm", + "take a nap from 1:00 pm to 4:00 pm", + "relax and watch TV from 7:00 pm to 8:00 pm", + "go to bed at 11:00 pm", + ] + return fs + + async def run(self, role: "STRole", wake_up_hour: str): + def create_prompt_input(role, wake_up_hour): + prompt_input = [] + prompt_input += [role.scratch.get_str_iss()] + prompt_input += [role.scratch.get_str_lifestyle()] + prompt_input += [role.scratch.get_str_curr_date_str()] + prompt_input += [role.scratch.get_str_firstname()] + prompt_input += [f"{str(wake_up_hour)}:00 am"] + return prompt_input + + wake_up_hour = int(wake_up_hour) + prompt_template = "daily_planning_v6.txt" + prompt_input = create_prompt_input(role, wake_up_hour) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, prompt_template) + self.fail_default_resp = self._func_fail_default_resp() + output = await self._run_gpt35_max_tokens(prompt, max_tokens=500) + output = [f"wake up and complete the morning routine at {wake_up_hour}:00 am"] + output + logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}") + return output diff --git a/notebook_dir/metagpt_yusin/ext/stanford_town/actions/gen_hourly_schedule.py 
def get_random_alphanumeric(i=6, j=6):
    """
    Build a random alphanumeric string whose length lies between i and j.

    INPUT:
        i: smallest allowed length
        j: largest allowed length
    OUTPUT:
        a str of ASCII letters and digits, of a length drawn from [i, j].
    """
    length = random.randint(i, j)
    alphabet = string.ascii_letters + string.digits
    picked = random.choices(alphabet, k=length)
    return "".join(picked)
schedule today: " + for count, i in enumerate(persona.scratch.daily_req): + intermission_str += f"{str(count + 1)}) {i}, " + intermission_str = intermission_str[:-2] + + prior_schedule = "" + if p_f_ds_hourly_org: + prior_schedule = "\n" + for count, i in enumerate(p_f_ds_hourly_org): + prior_schedule += f"[(ID:{get_random_alphanumeric()})" + prior_schedule += f" {persona.scratch.get_str_curr_date_str()} --" + prior_schedule += f" {hour_str[count]}] Activity:" + prior_schedule += f" {persona.scratch.get_str_firstname()}" + prior_schedule += f" is {i}\n" + + prompt_ending = f"[(ID:{get_random_alphanumeric()})" + prompt_ending += f" {persona.scratch.get_str_curr_date_str()}" + prompt_ending += f" -- {curr_hour_str}] Activity:" + prompt_ending += f" {persona.scratch.get_str_firstname()} is" + + if intermission2: + intermission2 = f"\n{intermission2}" + + prompt_input = [] + prompt_input += [schedule_format] + prompt_input += [persona.scratch.get_str_iss()] + + prompt_input += [prior_schedule + "\n"] + prompt_input += [intermission_str] + if intermission2: + prompt_input += [intermission2] + else: + prompt_input += [""] + prompt_input += [prompt_ending] + + return prompt_input + + prompt_template = "generate_hourly_schedule_v2.txt" + prompt_input = create_prompt_input(role, curr_hour_str, p_f_ds_hourly_org, hour_str, intermission2) + prompt_input_str = "\n".join(prompt_input) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, prompt_template) + self.fail_default_resp = self._func_fail_default_resp() + output = await self._run_gpt35_max_tokens(prompt, max_tokens=50) + logger.info( + f"Role: {role.name} _generate_schedule_for_given_hour prompt_input: {prompt_input_str}, " + f"output: {output}" + ) + return output + + async def run(self, role: "STRole", wake_up_hour: int): + hour_str = [ + "00:00 AM", + "01:00 AM", + "02:00 AM", + "03:00 AM", + "04:00 AM", + "05:00 AM", + "06:00 AM", + "07:00 AM", + "08:00 AM", + "09:00 AM", + "10:00 AM", + "11:00 AM", + 
"12:00 PM", + "01:00 PM", + "02:00 PM", + "03:00 PM", + "04:00 PM", + "05:00 PM", + "06:00 PM", + "07:00 PM", + "08:00 PM", + "09:00 PM", + "10:00 PM", + "11:00 PM", + ] + n_m1_activity = [] + diversity_repeat_count = 1 # TODO mg 1->3 + for i in range(diversity_repeat_count): + logger.info(f"diversity_repeat_count idx: {i}") + n_m1_activity_set = set(n_m1_activity) + if len(n_m1_activity_set) < 5: + n_m1_activity = [] + for count, curr_hour_str in enumerate(hour_str): + if wake_up_hour > 0: + n_m1_activity += ["sleeping"] + wake_up_hour -= 1 + else: + logger.info(f"_generate_schedule_for_given_hour idx: {count}, n_m1_activity: {n_m1_activity}") + n_m1_activity += [ + await self._generate_schedule_for_given_hour(role, curr_hour_str, n_m1_activity, hour_str) + ] + + # Step 1. Compressing the hourly schedule to the following format: + # The integer indicates the number of hours. They should add up to 24. + # [['sleeping', 6], ['waking up and starting her morning routine', 1], + # ['eating breakfast', 1], ['getting ready for the day', 1], + # ['working on her painting', 2], ['taking a break', 1], + # ['having lunch', 1], ['working on her painting', 3], + # ['taking a break', 2], ['working on her painting', 2], + # ['relaxing and watching TV', 1], ['going to bed', 1], ['sleeping', 2]] + _n_m1_hourly_compressed = [] + prev = None + prev_count = 0 + for i in n_m1_activity: + if i != prev: + prev_count = 1 + _n_m1_hourly_compressed += [[i, prev_count]] + prev = i + elif _n_m1_hourly_compressed: + _n_m1_hourly_compressed[-1][1] += 1 + + # Step 2. Expand to min scale (from hour scale) + # [['sleeping', 360], ['waking up and starting her morning routine', 60], + # ['eating breakfast', 60],.. 
def _func_cleanup(self, llm_resp: str, prompt: str) -> dict:
    """Turn the first JSON dict in the reply into ``{"utterance", "end"}``.

    The dict's first value is taken as the utterance and its second value as
    the end-of-conversation flag. The flag is treated as False when its string
    form contains "f" or "F", and as True otherwise.
    """
    values = list(extract_first_json_dict(llm_resp).values())
    utterance = values[0]
    flag_text = str(values[1])
    looks_false = "f" in flag_text or "F" in flag_text
    return {"utterance": utterance, "end": not looks_false}
+ cleaned_dict["end"] = False + return cleaned_dict + + async def run( + self, + init_role: "STRole", + target_role: "STRole", + retrieved: dict, + curr_context: str, + curr_chat: list[str], + *args, + **kwargs, + ) -> dict: + def create_prompt_input( + access_tile: dict[str, str], + init_role: "STRole", + target_role: "STRole", + retrieved: dict, + curr_context: str, + curr_chat: list[str], + ): + role = init_role + scratch = role.rc.scratch + target_scratch = target_role.rc.scratch + prev_convo_insert = "\n" + if role.rc.memory.chat_list: + for i in role.rc.memory.chat_list: + if i.object == target_role.name: + v1 = int((scratch.curr_time - i.created).total_seconds() / 60) + prev_convo_insert += ( + f"{str(v1)} minutes ago, {scratch.name} and " + f"{target_scratch.name} were already {i.description} " + f"This context takes place after that conversation." + ) + break + if prev_convo_insert == "\n": + prev_convo_insert = "" + if role.rc.memory.chat_list: + if int((scratch.curr_time - role.rc.memory.chat_list[-1].created).total_seconds() / 60) > 480: + prev_convo_insert = "" + logger.info(f"prev_convo_insert: {prev_convo_insert}") + + curr_sector = f"{access_tile['sector']}" + curr_arena = f"{access_tile['arena']}" + curr_location = f"{curr_arena} in {curr_sector}" + + retrieved_str = "" + for key, vals in retrieved.items(): + for v in vals: + retrieved_str += f"- {v.description}\n" + + convo_str = "" + for i in curr_chat: + convo_str += ": ".join(i) + "\n" + if convo_str == "": + convo_str = "[The conversation has not started yet -- start it!]" + + init_iss = f"Here is Here is a brief description of {scratch.name}.\n{scratch.get_str_iss()}" + prompt_input = [ + init_iss, + scratch.name, + retrieved_str, + prev_convo_insert, + curr_location, + curr_context, + scratch.name, + target_scratch.name, + convo_str, + scratch.name, + target_scratch.name, + scratch.name, + scratch.name, + scratch.name, + ] + return prompt_input + + access_tile = init_role.rc.env.observe( + 
obs_params=EnvObsParams(obs_type=EnvObsType.GET_TITLE, coord=init_role.scratch.curr_tile) + ) + prompt_input = create_prompt_input(access_tile, init_role, target_role, retrieved, curr_context, curr_chat) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "iterative_convo_v1.txt") + # original using `ChatGPT_safe_generate_response_OLD` + self.fail_default_resp = self._func_fail_default_resp() + output = await self._run_gpt35_wo_extra_prompt(prompt) + logger.info(f"Role: {init_role.name} Action: {self.cls_name} output: {output}") + return output diff --git a/notebook_dir/metagpt_yusin/ext/stanford_town/actions/inner_voice_action.py b/notebook_dir/metagpt_yusin/ext/stanford_town/actions/inner_voice_action.py new file mode 100644 index 0000000000000000000000000000000000000000..83cfa037ba8de69309a1f9438509b2a5ec8de8b6 --- /dev/null +++ b/notebook_dir/metagpt_yusin/ext/stanford_town/actions/inner_voice_action.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +from metagpt.ext.stanford_town.actions.st_action import STAction +from metagpt.logs import logger + + +class AgentWhisperThoughtAction(STAction): + name: str = "AgentWhisperThoughtAction" + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + try: + self._func_cleanup(llm_resp, prompt) + return True + except Exception: + return False + + def _func_cleanup(self, llm_resp: str, prompt: str = "") -> list: + return llm_resp.split('"')[0].strip() + + def _func_fail_default_resp(self) -> str: + pass + + async def run(self, role: "STRole", statements: str, test_input=None, verbose=False) -> str: + def create_prompt_input(role: "STRole", statements, test_input=None): + prompt_input = [role.scratch.name, statements] + return prompt_input + + prompt_input = create_prompt_input(role, statements) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "whisper_inner_thought_v1.txt") + + output = await self._run_gpt35_max_tokens(prompt, max_tokens=50) + 
logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}") + return output diff --git a/notebook_dir/metagpt_yusin/ext/stanford_town/actions/new_decomp_schedule.py b/notebook_dir/metagpt_yusin/ext/stanford_town/actions/new_decomp_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..759ec170f464622a304aac85269efe043568d16d --- /dev/null +++ b/notebook_dir/metagpt_yusin/ext/stanford_town/actions/new_decomp_schedule.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : new_decomp_schedule + +import datetime + +from metagpt.ext.stanford_town.actions.st_action import STAction +from metagpt.logs import logger + + +class NewDecompSchedule(STAction): + name: str = "NewDecompSchedule" + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + resp = False + try: + llm_resp = self._func_cleanup(llm_resp, prompt) + dur_sum = 0 + for act, dur in llm_resp: + dur_sum += dur + if isinstance(act, str): + return False + if isinstance(dur, int): + return False + x = prompt.split("\n")[0].split("originally planned schedule from")[-1].strip()[:-1] + x = [datetime.datetime.strptime(i.strip(), "%H:%M %p") for i in x.split(" to ")] + delta_min = int((x[1] - x[0]).total_seconds() / 60) + + if int(dur_sum) != int(delta_min): + return False + except Exception: + pass + return resp + + def _func_cleanup(self, llm_resp: str, prompt: str) -> list: + new_schedule = prompt + " " + llm_resp.strip() + new_schedule = new_schedule.split("The revised schedule:")[-1].strip() + new_schedule = new_schedule.split("\n") + + ret_temp = [] + for i in new_schedule: + ret_temp += [i.split(" -- ")] + + ret = [] + for time_str, action in ret_temp: + start_time = time_str.split(" ~ ")[0].strip() + end_time = time_str.split(" ~ ")[1].strip() + delta = datetime.datetime.strptime(end_time, "%H:%M") - datetime.datetime.strptime(start_time, "%H:%M") + delta_min = int(delta.total_seconds() / 60) + if delta_min < 0: + delta_min = 0 + ret += 
[[action, delta_min]] + + return ret + + def _func_fail_default_resp(self, main_act_dur: int, truncated_act_dur: int) -> int: + dur_sum = 0 + for act, dur in main_act_dur: + dur_sum += dur + + ret = truncated_act_dur[:] + ret += main_act_dur[len(ret) - 1 :] + + # If there are access, we need to trim... + ret_dur_sum = 0 + count = 0 + over = None + for act, dur in ret: + ret_dur_sum += dur + if ret_dur_sum == dur_sum: + break + if ret_dur_sum > dur_sum: + over = ret_dur_sum - dur_sum + break + count += 1 + + if over: + ret = ret[: count + 1] + ret[-1][1] -= over + + return ret + + async def run( + self, + role: "STRole", + main_act_dur: int, + truncated_act_dur: int, + start_time_hour: datetime, + end_time_hour: datetime, + inserted_act: str, + inserted_act_dur: int, + *args, + **kwargs, + ): + def create_prompt_input( + role: "STRole", + main_act_dur: int, + truncated_act_dur: int, + start_time_hour: datetime, + end_time_hour: datetime, + inserted_act: str, + inserted_act_dur: int, + ): + persona_name = role.name + start_hour_str = start_time_hour.strftime("%H:%M %p") + end_hour_str = end_time_hour.strftime("%H:%M %p") + + original_plan = "" + for_time = start_time_hour + for i in main_act_dur: + original_plan += ( + f'{for_time.strftime("%H:%M")} ~ ' + f'{(for_time + datetime.timedelta(minutes=int(i[1]))).strftime("%H:%M")} -- ' + i[0] + ) + original_plan += "\n" + for_time += datetime.timedelta(minutes=int(i[1])) + + new_plan_init = "" + for_time = start_time_hour + for count, i in enumerate(truncated_act_dur): + new_plan_init += ( + f'{for_time.strftime("%H:%M")} ~ ' + f'{(for_time + datetime.timedelta(minutes=int(i[1]))).strftime("%H:%M")} -- ' + i[0] + ) + new_plan_init += "\n" + if count < len(truncated_act_dur) - 1: + for_time += datetime.timedelta(minutes=int(i[1])) + + new_plan_init += (for_time + datetime.timedelta(minutes=int(i[1]))).strftime("%H:%M") + " ~" + + prompt_input = [ + persona_name, + start_hour_str, + end_hour_str, + original_plan, + 
persona_name, + inserted_act, + inserted_act_dur, + persona_name, + start_hour_str, + end_hour_str, + end_hour_str, + new_plan_init, + ] + return prompt_input + + prompt_input = create_prompt_input( + role, main_act_dur, truncated_act_dur, start_time_hour, end_time_hour, inserted_act, inserted_act_dur + ) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "new_decomp_schedule_v1.txt") + self.fail_default_resp = self._func_fail_default_resp(main_act_dur, truncated_act_dur) + output = await self._run_gpt35_max_tokens(prompt, max_tokens=1000) + logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}") + return output diff --git a/notebook_dir/metagpt_yusin/ext/stanford_town/actions/run_reflect_action.py b/notebook_dir/metagpt_yusin/ext/stanford_town/actions/run_reflect_action.py new file mode 100644 index 0000000000000000000000000000000000000000..895f6828f03d629b08f8b8f2207250ce1940c736 --- /dev/null +++ b/notebook_dir/metagpt_yusin/ext/stanford_town/actions/run_reflect_action.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : Integration Reflect Action + +import re + +from metagpt.ext.stanford_town.actions.st_action import STAction +from metagpt.logs import logger + + +# Run GPT Prompt Focal Point method +class AgentFocusPt(STAction): + name: str = "AgentFocusPt" + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + try: + self._func_cleanup(llm_resp, prompt) + return True + except Exception: + return False + + def _func_cleanup(self, llm_resp: str, prompt: str = "") -> str: + try: + """ + Cleanup handling has been completed for run_v2 + """ + return llm_resp + except Exception as exp: + logger.error(f"{self.cls_name} with error {exp}") + + def _func_fail_default_resp(self) -> str: + pass + + async def run(self, role: "STRole", statements: str, n: int, test_input=None) -> str: + def create_prompt_input(role: "STRole", statements, n, test_input=None): + prompt_input = [statements, str(n)] + 
return prompt_input + + prompt_input = create_prompt_input(role, statements, n) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "generate_focal_pt_v1.txt") + + example_output = '["What should Jane do for lunch", "Does Jane like strawberry", "Who is Jane"]' + special_instruction = "Output must be a list of str." + output = await self._run_gpt35(prompt, example_output, special_instruction) + logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}") + return output + + +# Run GPT Prompt Insight and Guidance +class AgentInsightAndGuidance(STAction): + name: str = "AgentInsightAndGuidance" + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + try: + self._func_cleanup(llm_resp, prompt) + return True + except Exception: + return False + + def _func_cleanup(self, llm_resp: str, prompt: str = "") -> dict: + try: + llm_resp = "1. " + llm_resp.strip() + ret = dict() + for i in llm_resp.split("\n"): + row = " ".join(i.split(". ")[1:]) + if "(because of " not in row: + continue + thought = row.split("(because of ")[0].strip() + if ")" not in row.split("(because of ")[1]: + continue + evi_raw = row.split("(because of ")[1].split(")")[0].strip() + evi_raw = re.findall(r"\d+", evi_raw) + evi_raw = [int(i.strip()) for i in evi_raw] + ret[thought] = evi_raw + return ret + except Exception as exp: + logger.error(f"{self.cls_name} with error {exp}") + + def _func_fail_default_resp(self, n: int) -> str: + return ["I am hungry"] * n + + async def run(self, role: "STRole", statements: str, n: int, test_input=None) -> dict: + def create_prompt_input(role, statements, n, test_input=None): + prompt_input = [statements, str(n)] + return prompt_input + + prompt_input = create_prompt_input(role, statements, n) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "insight_and_evidence_v1.txt") + + self.fail_default_resp = self._func_fail_default_resp(n) + output = await self._run_gpt35_max_tokens(prompt, max_tokens=150) + 
class AgentEventTriple(STAction):
    """Produce a (subject, predicate, object) triple for a statement about a role.

    The prompt is rendered from ``generate_event_triple_v1.txt``. The LLM reply
    is cut at the first ")" and split on ","; the last two parts are used and
    ``role.scratch.name`` is prepended as the subject.
    """

    name: str = "AgentEventTriple"

    def _func_validate(self, llm_resp: str, prompt: str) -> bool:
        """Accept the reply only when cleanup yields exactly two parts."""
        try:
            llm_resp = self._func_cleanup(llm_resp, prompt="")
            if len(llm_resp) != 2:
                return False
        except Exception:
            return False
        return True

    def _func_cleanup(self, llm_resp: str, prompt: str = "") -> list:
        """Return the last two comma-separated parts found before the first ")".

        If parsing raises, the error is logged and None is returned implicitly.
        """
        try:
            cr = llm_resp.strip()
            cr = [i.strip() for i in cr.split(")")[0].split(",")]
            if len(cr) != 2:
                return cr[-2:]
            return cr
        except Exception as exp:
            logger.error(f"{self.cls_name} with error {exp}")

    def _func_fail_default_resp(self) -> list:
        """Return the fallback ``[predicate, object]`` pair used when the LLM fails."""
        return ["is", "idle"]

    async def run(self, statements: str, role: "STRole", verbose=False) -> tuple:
        """Query the LLM and return ``(role.scratch.name, predicate, object)``.

        Args:
            statements: the statement; if it contains "(", only the text of the
                last parenthesised group is sent to the LLM.
            role: the role the statement is about.
            verbose: unused.

        Returns:
            The event triple; ``(role.scratch.name, "is", "idle")`` when no
            valid reply was obtained.
        """

        def create_prompt_input(statements, role):
            if "(" in statements:
                statements = statements.split("(")[-1].split(")")[0]
            prompt_input = [role.scratch.name, statements, role.scratch.name]
            return prompt_input

        prompt_input = create_prompt_input(statements, role)
        prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "generate_event_triple_v1.txt")

        # Without a fallback the helper returns None after exhausting its
        # retries, and indexing that result below raised TypeError.
        self.fail_default_resp = self._func_fail_default_resp()
        output = await self._run_gpt35_max_tokens(prompt, max_tokens=30)
        output = (role.scratch.name, output[0], output[1])
        logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}")
        return output
str, test_input=None, verbose=False) -> str: + def create_prompt_input(role: "STRole", statements: str, test_input=None): + prompt_input = [role.scratch.name, role.scratch.get_str_iss(), role.scratch.name, statements] + return prompt_input + + prompt_input = create_prompt_input(role, statements) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "poignancy_event_v1.txt") + + example_output = "5" # ######## + special_instruction = "The output should ONLY contain ONE integer value on the scale of 1 to 10." + output = await self._run_gpt35(prompt, example_output, special_instruction) + logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}") + return output + + +# Run GPT Prompt Chat Poignancy +class AgentChatPoignancy(STAction): + name: str = "AgentChatPoignancy" + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + try: + self._func_cleanup(llm_resp, prompt) + return True + except Exception: + return False + + def _func_cleanup(self, llm_resp: str, prompt: str = "") -> int: + try: + llm_resp = int(llm_resp.strip()) + return llm_resp + except Exception as exp: + logger.error(f"{self.cls_name} with error {exp}") + + def _func_fail_default_resp(self) -> str: + pass + + async def run(self, role: "STRole", statements: str, test_input=None, verbose=False) -> str: + def create_prompt_input(role: "STRole", statements, test_input=None): + prompt_input = [role.scratch.name, role.scratch.get_str_iss(), role.scratch.name, statements] + return prompt_input + + prompt_input = create_prompt_input(role, statements) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "poignancy_chat_v1.txt") + + example_output = "5" # ######## + special_instruction = "The output should ONLY contain ONE integer value on the scale of 1 to 10." 
class AgentPlanThoughtOnConvo(STAction):
    """Ask the LLM for a planning thought derived from a conversation.

    The prompt is rendered from ``planning_thought_on_convo_v1.txt``; the reply
    is reduced to the stripped text before its first double quote.
    """

    name: str = "AgentPlanThoughtOnConvo"

    def _func_validate(self, llm_resp: str, prompt: str) -> bool:
        """Report whether cleanup ran on the reply without raising."""
        try:
            self._func_cleanup(llm_resp, prompt)
        except Exception:
            return False
        return True

    def _func_cleanup(self, llm_resp: str, prompt: str = "") -> str:
        """Return the stripped text before the first double quote.

        If that raises, the error is logged and None is returned implicitly.
        """
        try:
            before_quote = llm_resp.split('"')[0]
            return before_quote.strip()
        except Exception as exp:
            logger.error(f"{self.cls_name} with error {exp}")

    def _func_fail_default_resp(self) -> str:
        """No fallback is defined for this action."""
        return None

    async def run(self, role: "STRole", statements: str, test_input=None, verbose=False) -> str:
        """Render the prompt, query the LLM (max 50 tokens) and return the cleaned reply."""
        # Template slots: the statements, then the role's name three times.
        template_args = [statements] + [role.scratch.name for _ in range(3)]
        prompt = self.generate_prompt_with_tmpl_filename(template_args, "planning_thought_on_convo_v1.txt")

        output = await self._run_gpt35_max_tokens(prompt, max_tokens=50)
        logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}")
        return output
class STAction(Action):
    """Base class for Stanford Town actions that prompt an LLM from a template file.

    Subclasses supply ``_func_validate`` / ``_func_cleanup`` /
    ``_func_fail_default_resp`` and a ``run`` coroutine. This class provides
    template rendering and three retrying request helpers.
    """

    name: str = "STAction"
    prompt_dir: Path = PROMPTS_DIR
    # Returned by `_run_gpt35_max_tokens` and `_run_gpt35_wo_extra_prompt` when
    # every retry fails.
    fail_default_resp: Optional[str] = None

    @property
    def cls_name(self):
        """Name of the concrete action class, used in log messages."""
        return self.__class__.__name__

    @abstractmethod
    def _func_validate(self, llm_resp: str, prompt: str):
        """Return whether ``llm_resp`` is an acceptable reply to ``prompt``."""
        raise NotImplementedError

    @abstractmethod
    def _func_cleanup(self, llm_resp: str, prompt: str):
        """Convert a validated ``llm_resp`` into the action's result value."""
        raise NotImplementedError

    @abstractmethod
    def _func_fail_default_resp(self):
        """Return the value to use when the LLM gives no valid reply."""
        raise NotImplementedError

    def generate_prompt_with_tmpl_filename(self, prompt_input: Union[str, list], tmpl_filename) -> str:
        """
        same with `generate_prompt`
        Args:
            prompt_input: the input we want to feed in (IF THERE ARE MORE THAN ONE INPUT, THIS CAN BE A LIST.)
            tmpl_filename: prompt template filename
        Returns:
            a str prompt that will be sent to LLM server.
        """
        if isinstance(prompt_input, str):
            prompt_input = [prompt_input]
        prompt_input = [str(i) for i in prompt_input]

        # Context manager so the handle is closed even if read() raises.
        with open(str(self.prompt_dir.joinpath(tmpl_filename)), "r") as f:
            prompt = f.read()
        for count, i in enumerate(prompt_input):
            # NOTE(review): the replace target previously ignored `count`, so
            # every input targeted the same literal and only the first could
            # land. Confirm the `!<INPUT n>!` spelling against the templates in
            # prompt_dir.
            prompt = prompt.replace(f"!<INPUT {count}>!", i)
        # Keep only the text between the first and second "###".
        # NOTE(review): confirm the marker spelling against the template files.
        if "###" in prompt:
            prompt = prompt.split("###")[1]
        return prompt.strip()

    async def _aask(self, prompt: str) -> str:
        """Send ``prompt`` to the configured LLM and return its reply."""
        return await self.llm.aask(prompt)

    async def _run_gpt35_max_tokens(self, prompt: str, max_tokens: int = 50, retry: int = 3):
        """Ask the LLM with a temporary ``config.llm.max_token`` override.

        Up to ``retry`` attempts are made. The first reply that passes
        ``_func_validate`` is returned after ``_func_cleanup``.

        Returns:
            The cleaned reply, or ``self.fail_default_resp`` if no attempt
            produced a valid reply.
        """
        import asyncio

        for idx in range(retry):
            try:
                tmp_max_tokens_rsp = getattr(config.llm, "max_token", 1500)
                setattr(config.llm, "max_token", max_tokens)
                self.llm.use_system_prompt = False  # to make it behave like a non-chat completions

                try:
                    llm_resp = await self._aask(prompt)
                finally:
                    # Restore the shared setting even when the request raises;
                    # otherwise the small limit leaks into later LLM calls.
                    setattr(config.llm, "max_token", tmp_max_tokens_rsp)
                logger.info(f"Action: {self.cls_name} llm _run_gpt35_max_tokens raw resp: {llm_resp}")
                if self._func_validate(llm_resp, prompt):
                    return self._func_cleanup(llm_resp, prompt)
            except Exception as exp:
                logger.warning(f"Action: {self.cls_name} _run_gpt35_max_tokens exp: {exp}")
                # Non-blocking back-off: time.sleep() here would stall the
                # whole event loop.
                await asyncio.sleep(5)
        return self.fail_default_resp

    async def _run_gpt35(
        self, prompt: str, example_output: str, special_instruction: str, retry: int = 3
    ) -> Union[bool, Any]:
        """same with `gpt_structure.ChatGPT_safe_generate_response`

        Wraps ``prompt`` with an instruction to answer in JSON as
        ``{"output": ...}``, then validates and cleans the ``output`` field.

        Returns:
            The cleaned value, or False if no attempt produced a valid reply.
        """
        import asyncio

        prompt = '"""\n' + prompt + '\n"""\n'
        prompt += f"Output the response to the prompt above in json. {special_instruction}\n"
        prompt += "Example output json:\n"
        prompt += '{"output": "' + str(example_output) + '"}'

        for idx in range(retry):
            try:
                llm_resp = await self._aask(prompt)
                logger.info(f"Action: {self.cls_name} llm _run_gpt35 raw resp: {llm_resp}")
                # Strip before slicing: the end index used to be computed on
                # the stripped text but applied to the unstripped text, which
                # cut off the closing brace when the reply had leading
                # whitespace.
                llm_resp = llm_resp.strip()
                end_idx = llm_resp.rfind("}") + 1
                llm_resp = llm_resp[:end_idx]
                llm_resp = json.loads(llm_resp)["output"]

                if self._func_validate(llm_resp, prompt):
                    return self._func_cleanup(llm_resp, prompt)
            except Exception as exp:
                logger.warning(f"Action: {self.cls_name} _run_gpt35 exp: {exp}")
                await asyncio.sleep(5)  # usually avoid `Rate limit`
        return False

    async def _run_gpt35_wo_extra_prompt(self, prompt: str, retry: int = 3) -> str:
        """Ask the LLM with ``prompt`` as-is; validate and clean the stripped reply.

        Returns:
            The cleaned reply, or ``self.fail_default_resp`` if no attempt
            produced a valid reply.
        """
        import asyncio

        for idx in range(retry):
            try:
                llm_resp = await self._aask(prompt)
                llm_resp = llm_resp.strip()
                logger.info(f"Action: {self.cls_name} llm _run_gpt35_wo_extra_prompt raw resp: {llm_resp}")
                if self._func_validate(llm_resp, prompt):
                    return self._func_cleanup(llm_resp, prompt)
            except Exception as exp:
                logger.warning(f"Action: {self.cls_name} _run_gpt35_wo_extra_prompt exp: {exp}")
                await asyncio.sleep(5)  # usually avoid `Rate limit`
        return self.fail_default_resp

    async def run(self, *args, **kwargs):
        """Run action"""
        raise NotImplementedError("The run method should be implemented in a subclass.")
_func_validate(self, llm_resp: str, prompt: str) -> bool: + resp = False + try: + _ = self._func_cleanup(llm_resp, prompt) + resp = True + except Exception: + pass + return resp + + def _func_cleanup(self, llm_resp: str, prompt: str) -> str: + ret = "conversing about " + llm_resp.strip() + return ret + + def _func_fail_default_resp(self) -> str: + return "conversing with a housemate about morning greetings" + + async def run(self, conv: list): + def create_prompt_input(conversation: list): + convo_str = "" + for row in conversation: + convo_str += f'{row[0]}: "{row[1]}"\n' + prompt_input = [convo_str] + return prompt_input + + prompt_input = create_prompt_input(conv) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "summarize_conversation_v1.txt") + + example_output = "conversing about what to eat for lunch" + special_instruction = ( + "The output must continue the sentence above by filling in the tag. " + "Don't start with 'this is a conversation about...' Just finish the sentence " + "but do not miss any important details (including who are chatting)." + ) + output = await self._run_gpt35(prompt, example_output, special_instruction) + logger.info(f"Action: {self.cls_name} output: {output}") + return output diff --git a/notebook_dir/metagpt_yusin/ext/stanford_town/actions/task_decomp.py b/notebook_dir/metagpt_yusin/ext/stanford_town/actions/task_decomp.py new file mode 100644 index 0000000000000000000000000000000000000000..3a23a73456e190811609f095c0552cc80050932b --- /dev/null +++ b/notebook_dir/metagpt_yusin/ext/stanford_town/actions/task_decomp.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : task_decomp + +import datetime + +from metagpt.ext.stanford_town.actions.st_action import STAction +from metagpt.logs import logger + + +class TaskDecomp(STAction): + name: str = "TaskDecomp" + + def _func_cleanup(self, llm_resp: str, prompt: str) -> list: + # TODO SOMETHING HERE sometimes fails... 
See screenshot + temp = [i.strip() for i in llm_resp.split("\n")] + _cr = [] + cr = [] + for count, i in enumerate(temp): + if count != 0: + _cr += [" ".join([j.strip() for j in i.split(" ")][3:])] + else: + _cr += [i] + for count, i in enumerate(_cr): + k = [j.strip() for j in i.split("(duration in minutes:")] + task = k[0] + if task[-1] == ".": + task = task[:-1] + duration = int(k[1].split(",")[0].strip()) + cr += [[task, duration]] + + total_expected_min = int(prompt.split("(total duration in minutes")[-1].split("):")[0].strip()) + + # TODO -- now, you need to make sure that this is the same as the sum of + # the current action sequence. + curr_min_slot = [ + ["dummy", -1], + ] # (task_name, task_index) + for count, i in enumerate(cr): + i_task = i[0] + i_duration = i[1] + + i_duration -= i_duration % 5 + if i_duration > 0: + for j in range(i_duration): + curr_min_slot += [(i_task, count)] + curr_min_slot = curr_min_slot[1:] + + if len(curr_min_slot) > total_expected_min: + last_task = curr_min_slot[60] + for i in range(1, 6): + curr_min_slot[-1 * i] = last_task + elif len(curr_min_slot) < total_expected_min: + last_task = curr_min_slot[-1] + for i in range(total_expected_min - len(curr_min_slot)): + curr_min_slot += [last_task] + + cr_ret = [ + ["dummy", -1], + ] + for task, task_index in curr_min_slot: + if task != cr_ret[-1][0]: + cr_ret += [[task, 1]] + else: + cr_ret[-1][1] += 1 + cr = cr_ret[1:] + + return cr + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + # TODO -- this sometimes generates error + try: + self._func_cleanup(llm_resp, prompt) + except Exception: + return False + return True + + def _func_fail_default_resp(self) -> int: + fs = [["asleep", 0]] + return fs + + async def run(self, role: "STRole", task_desc: int, truncated_act_dur: int, *args, **kwargs): + def create_prompt_input(role, task, duration): + """ + Today is Saturday June 25. 
From 00:00 ~ 06:00am, Maeve is + planning on sleeping, 06:00 ~ 07:00am, Maeve is + planning on waking up and doing her morning routine, + and from 07:00am ~08:00am, Maeve is planning on having breakfast. + """ + + curr_f_org_index = role.scratch.get_f_daily_schedule_hourly_org_index() + all_indices = [] + # if curr_f_org_index > 0: + # all_indices += [curr_f_org_index-1] + all_indices += [curr_f_org_index] + if curr_f_org_index + 1 <= len(role.scratch.f_daily_schedule_hourly_org): + all_indices += [curr_f_org_index + 1] + if curr_f_org_index + 2 <= len(role.scratch.f_daily_schedule_hourly_org): + all_indices += [curr_f_org_index + 2] + + curr_time_range = "" + + logger.debug("DEBUG") + logger.debug(role.scratch.f_daily_schedule_hourly_org) + logger.debug(all_indices) + + summ_str = f'Today is {role.scratch.curr_time.strftime("%B %d, %Y")}. ' + summ_str += "From " + for index in all_indices: + logger.debug(f"index {index}") + if index < len(role.scratch.f_daily_schedule_hourly_org): + start_min = 0 + for i in range(index): + start_min += role.scratch.f_daily_schedule_hourly_org[i][1] + end_min = start_min + role.scratch.f_daily_schedule_hourly_org[index][1] + start_time = datetime.datetime.strptime("00:00:00", "%H:%M:%S") + datetime.timedelta( + minutes=start_min + ) + end_time = datetime.datetime.strptime("00:00:00", "%H:%M:%S") + datetime.timedelta( + minutes=end_min + ) + start_time_str = start_time.strftime("%H:%M%p") + end_time_str = end_time.strftime("%H:%M%p") + summ_str += ( + f"{start_time_str} ~ {end_time_str}, {role.name} is planning " + f"on {role.scratch.f_daily_schedule_hourly_org[index][0]}, " + ) + if curr_f_org_index + 1 == index: + curr_time_range = f"{start_time_str} ~ {end_time_str}" + summ_str = summ_str[:-2] + "." 
+ + prompt_input = [] + prompt_input += [role.scratch.get_str_iss()] + prompt_input += [summ_str] + # prompt_input += [role.scratch.get_str_curr_date_str()] + prompt_input += [role.scratch.get_str_firstname()] + prompt_input += [role.scratch.get_str_firstname()] + prompt_input += [task] + prompt_input += [curr_time_range] + prompt_input += [duration] + prompt_input += [role.scratch.get_str_firstname()] + return prompt_input + + prompt_input = create_prompt_input(role, task_desc, truncated_act_dur) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "task_decomp_v3.txt") + self.fail_default_resp = self._func_fail_default_resp() + output = await self._run_gpt35_max_tokens(prompt, max_tokens=1000) + logger.info(f"Role: {role.name} {self.cls_name} output: {output}") + + fin_output = [] + time_sum = 0 + for i_task, i_duration in output: + time_sum += i_duration + # HM????????? + # if time_sum < duration: + if time_sum <= truncated_act_dur: + fin_output += [[i_task, i_duration]] + else: + break + ftime_sum = 0 + for fi_task, fi_duration in fin_output: + ftime_sum += fi_duration + + fin_output[-1][1] += truncated_act_dur - ftime_sum + output = fin_output + + task_decomp = output + ret = [] + for decomp_task, duration in task_decomp: + ret += [[f"{task_desc} ({decomp_task})", duration]] + output = ret + logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}") + return output diff --git a/notebook_dir/metagpt_yusin/ext/stanford_town/actions/wake_up.py b/notebook_dir/metagpt_yusin/ext/stanford_town/actions/wake_up.py new file mode 100644 index 0000000000000000000000000000000000000000..ea44cd3a427d3526cd673e7619ca462db870d873 --- /dev/null +++ b/notebook_dir/metagpt_yusin/ext/stanford_town/actions/wake_up.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : wake_up + + +from metagpt.ext.stanford_town.actions.st_action import STAction +from metagpt.logs import logger + + +class WakeUp(STAction): + name: str = "WakeUp" + 
+ def _func_validate(self, llm_resp: str, prompt: str = None) -> bool: + try: + self._func_cleanup(llm_resp, prompt="") + except Exception: + return False + return True + + def _func_cleanup(self, llm_resp: str, prompt: str) -> int: + cr = int(llm_resp.strip().lower().split("am")[0]) + return cr + + def _func_fail_default_resp(self) -> int: + fs = 8 + return fs + + async def run(self, role: "STRole"): + def create_prompt_input(role): + prompt_input = [ + role.scratch.get_str_iss(), + role.scratch.get_str_lifestyle(), + role.scratch.get_str_firstname(), + ] + return prompt_input + + prompt_input = create_prompt_input(role) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "wake_up_hour_v1.txt") + self.fail_default_resp = self._func_fail_default_resp() + output = await self._run_gpt35_max_tokens(prompt, max_tokens=5) + logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}") + return output diff --git a/notebook_dir/metagpt_yusin/ext/stanford_town/memory/__init__.py b/notebook_dir/metagpt_yusin/ext/stanford_town/memory/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/notebook_dir/metagpt_yusin/ext/stanford_town/memory/agent_memory.py b/notebook_dir/metagpt_yusin/ext/stanford_town/memory/agent_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..d212232f42c560c90d264eb2e2ebc63d6674b2a6 --- /dev/null +++ b/notebook_dir/metagpt_yusin/ext/stanford_town/memory/agent_memory.py @@ -0,0 +1,378 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : BasicMemory,AgentMemory实现 + +from datetime import datetime +from pathlib import Path +from typing import Optional + +from pydantic import Field, field_serializer, model_validator + +from metagpt.logs import logger +from metagpt.memory.memory import Memory +from metagpt.schema import Message +from metagpt.utils.common import read_json_file, write_json_file + + +class 
BasicMemory(Message): + """ + BasicMemory继承于MG的Message类,其中content属性替代description属性 + Message类中对于Chat类型支持的非常好,对于Agent个体的Perceive,Reflection,Plan支持的并不多 + 在Type设计上,我们延续GA的三个种类,但是对于Chat种类的对话进行特别设计(具体怎么设计还没想好) + """ + + memory_id: Optional[str] = Field(default=None) # 记忆ID + memory_count: int = -1 # 第几个记忆,实际数值与Memory相等 + type_count: int = -1 # 第几种记忆,类型为整数 + memory_type: Optional[str] = Field(default=None) # 记忆类型,包含 event,thought,chat三种类型 + depth: int = -1 # 记忆深度,类型为整数 + created: Optional[datetime] = Field(default=None) # 创建时间 + expiration: Optional[datetime] = Field(default=None) # 记忆失效时间,默认为空() + last_accessed: Optional[datetime] = Field(default=None) # 上一次调用的时间,初始化时候与self.created一致 + subject: Optional[str] = Field(default=None) # 主语 + predicate: Optional[str] = Field(default=None) # 谓语 + object: Optional[str] = Field(default=None) # 宾语 + + description: Optional[str] = Field(default=None) + embedding_key: Optional[str] = Field(default=None) # 内容与self.content一致 + poignancy: int = -1 # importance值 + keywords: list[str] = Field(default=[]) # keywords + filling: list = Field(default=[]) # 装的与之相关联的memory_id的列表 + + __hash__ = object.__hash__ # support hash in AgentMemory + + @model_validator(mode="before") + @classmethod + def check_values(cls, values): + if "created" in values: + values["last_accessed"] = values["created"] + if "content" in values: + values["description"] = values["content"] + if "filling" in values: + values["filling"] = values["filling"] or [] + return values + + @field_serializer("created", "expiration") + def transform_time_field(self, time_field: Optional[datetime]) -> str: + if time_field: + time_field = time_field.strftime("%Y-%m-%d %H:%M:%S") + return time_field + + def summary(self): + return self.subject, self.predicate, self.object + + def save_to_dict(self) -> dict: + """ + 将MemoryBasic类转化为字典,用于存储json文件 + 这里需要注意,cause_by跟GA不兼容,所以需要做一个格式转换 + """ + memory_dict = dict() + node_id = self.memory_id + basic_mem_obj = self.model_dump( + include=[ + 
"node_count", + "type_count", + "type", + "depth", + "created", + "expiration", + "subject", + "predicate", + "object", + "description", + "embedding_key", + "poignancy", + "keywords", + "filling", + "cause_by", + ] + ) + + memory_dict[node_id] = basic_mem_obj + return memory_dict + + +class AgentMemory(Memory): + """ + GA中主要存储三种JSON + 1. embedding.json (Dict embedding_key:embedding) + 2. Node.json (Dict Node_id:Node) + 3. kw_strength.json + """ + + storage: list[BasicMemory] = [] # 重写Storage,存储BasicMemory所有节点 + event_list: list[BasicMemory] = [] # 存储event记忆 + thought_list: list[BasicMemory] = [] # 存储thought记忆 + chat_list: list[BasicMemory] = [] # chat-related memory + + event_keywords: dict[str, list[BasicMemory]] = dict() # 存储keywords + thought_keywords: dict[str, list[BasicMemory]] = dict() + chat_keywords: dict[str, list[BasicMemory]] = dict() + + kw_strength_event: dict[str, int] = dict() + kw_strength_thought: dict[str, int] = dict() + + memory_saved: Optional[Path] = Field(default=None) + embeddings: dict[str, list[float]] = dict() + + def set_mem_path(self, memory_saved: Path): + self.memory_saved = memory_saved + self.load(memory_saved) + + def save(self, memory_saved: Path): + """ + 将MemoryBasic类存储为Nodes.json形式。复现GA中的Kw Strength.json形式 + 这里添加一个路径即可 + TODO 这里在存储时候进行倒序存储,之后需要验证(test_memory通过) + """ + memory_json = dict() + for i in range(len(self.storage)): + memory_node = self.storage[len(self.storage) - i - 1] + memory_node = memory_node.save_to_dict() + memory_json.update(memory_node) + write_json_file(memory_saved.joinpath("nodes.json"), memory_json) + write_json_file(memory_saved.joinpath("embeddings.json"), self.embeddings) + + strength_json = dict() + strength_json["kw_strength_event"] = self.kw_strength_event + strength_json["kw_strength_thought"] = self.kw_strength_thought + write_json_file(memory_saved.joinpath("kw_strength.json"), strength_json) + + def load(self, memory_saved: Path): + """ + 将GA的JSON解析,填充到AgentMemory类之中 + """ + self.embeddings = 
read_json_file(memory_saved.joinpath("embeddings.json")) + memory_load = read_json_file(memory_saved.joinpath("nodes.json")) + for count in range(len(memory_load.keys())): + node_id = f"node_{str(count + 1)}" + node_details = memory_load[node_id] + node_type = node_details["type"] + created = datetime.strptime(node_details["created"], "%Y-%m-%d %H:%M:%S") + expiration = None + if node_details["expiration"]: + expiration = datetime.strptime(node_details["expiration"], "%Y-%m-%d %H:%M:%S") + + s = node_details["subject"] + p = node_details["predicate"] + o = node_details["object"] + + description = node_details["description"] + embedding_pair = (node_details["embedding_key"], self.embeddings[node_details["embedding_key"]]) + poignancy = node_details["poignancy"] + keywords = set(node_details["keywords"]) + filling = node_details["filling"] + if node_type == "thought": + self.add_thought( + created, expiration, s, p, o, description, keywords, poignancy, embedding_pair, filling + ) + if node_type == "event": + self.add_event(created, expiration, s, p, o, description, keywords, poignancy, embedding_pair, filling) + if node_type == "chat": + self.add_chat(created, expiration, s, p, o, description, keywords, poignancy, embedding_pair, filling) + + strength_keywords_load = read_json_file(memory_saved.joinpath("kw_strength.json")) + if strength_keywords_load["kw_strength_event"]: + self.kw_strength_event = strength_keywords_load["kw_strength_event"] + if strength_keywords_load["kw_strength_thought"]: + self.kw_strength_thought = strength_keywords_load["kw_strength_thought"] + + def add(self, memory_basic: BasicMemory): + """ + Add a new message to storage, while updating the index + 重写add方法,修改原有的Message类为BasicMemory类,并添加不同的记忆类型添加方式 + """ + if memory_basic.memory_id in self.storage: + return + self.storage.append(memory_basic) + if memory_basic.memory_type == "chat": + self.chat_list[0:0] = [memory_basic] + return + if memory_basic.memory_type == "thought": + 
self.thought_list[0:0] = [memory_basic] + return + if memory_basic.memory_type == "event": + self.event_list[0:0] = [memory_basic] + return + + def add_chat( + self, created, expiration, s, p, o, content, keywords, poignancy, embedding_pair, filling, cause_by="" + ): + """ + 调用add方法,初始化chat,在创建的时候就需要调用embedding函数 + """ + memory_count = len(self.storage) + 1 + type_count = len(self.thought_list) + 1 + memory_type = "chat" + memory_id = f"node_{str(memory_count)}" + depth = 1 + + memory_node = BasicMemory( + memory_id=memory_id, + memory_count=memory_count, + type_count=type_count, + memory_type=memory_type, + depth=depth, + created=created, + expiration=expiration, + subject=s, + predicate=p, + object=o, + description=content, + embedding_key=embedding_pair[0], + poignancy=poignancy, + keywords=keywords, + filling=filling, + cause_by=cause_by, + ) + + keywords = [i.lower() for i in keywords] + for kw in keywords: + if kw in self.chat_keywords: + self.chat_keywords[kw][0:0] = [memory_node] + else: + self.chat_keywords[kw] = [memory_node] + + self.add(memory_node) + + self.embeddings[embedding_pair[0]] = embedding_pair[1] + return memory_node + + def add_thought(self, created, expiration, s, p, o, content, keywords, poignancy, embedding_pair, filling): + """ + 调用add方法,初始化thought + """ + memory_count = len(self.storage) + 1 + type_count = len(self.thought_list) + 1 + memory_type = "thought" + memory_id = f"node_{str(memory_count)}" + depth = 1 + + try: + if filling: + depth_list = [memory_node.depth for memory_node in self.storage if memory_node.memory_id in filling] + depth += max(depth_list) + except Exception as exp: + logger.warning(f"filling init occur {exp}") + pass + + memory_node = BasicMemory( + memory_id=memory_id, + memory_count=memory_count, + type_count=type_count, + memory_type=memory_type, + depth=depth, + created=created, + expiration=expiration, + subject=s, + predicate=p, + object=o, + description=content, + embedding_key=embedding_pair[0], + 
poignancy=poignancy, + keywords=keywords, + filling=filling, + ) + + keywords = [i.lower() for i in keywords] + for kw in keywords: + if kw in self.thought_keywords: + self.thought_keywords[kw][0:0] = [memory_node] + else: + self.thought_keywords[kw] = [memory_node] + + self.add(memory_node) + + if f"{p} {o}" != "is idle": + for kw in keywords: + if kw in self.kw_strength_thought: + self.kw_strength_thought[kw] += 1 + else: + self.kw_strength_thought[kw] = 1 + + self.embeddings[embedding_pair[0]] = embedding_pair[1] + return memory_node + + def add_event(self, created, expiration, s, p, o, content, keywords, poignancy, embedding_pair, filling): + """ + 调用add方法,初始化event + """ + memory_count = len(self.storage) + 1 + type_count = len(self.event_list) + 1 + memory_type = "event" + memory_id = f"node_{str(memory_count)}" + depth = 0 + + if "(" in content: + content = " ".join(content.split()[:3]) + " " + content.split("(")[-1][:-1] + + memory_node = BasicMemory( + memory_id=memory_id, + memory_count=memory_count, + type_count=type_count, + memory_type=memory_type, + depth=depth, + created=created, + expiration=expiration, + subject=s, + predicate=p, + object=o, + description=content, + embedding_key=embedding_pair[0], + poignancy=poignancy, + keywords=keywords, + filling=filling, + ) + + keywords = [i.lower() for i in keywords] + for kw in keywords: + if kw in self.event_keywords: + self.event_keywords[kw][0:0] = [memory_node] + else: + self.event_keywords[kw] = [memory_node] + + self.add(memory_node) + + if f"{p} {o}" != "is idle": + for kw in keywords: + if kw in self.kw_strength_event: + self.kw_strength_event[kw] += 1 + else: + self.kw_strength_event[kw] = 1 + + self.embeddings[embedding_pair[0]] = embedding_pair[1] + return memory_node + + def get_summarized_latest_events(self, retention): + ret_set = set() + for e_node in self.event_list[:retention]: + ret_set.add(e_node.summary()) + return ret_set + + def get_last_chat(self, target_role_name: str): + if 
target_role_name.lower() in self.chat_keywords: + return self.chat_keywords[target_role_name.lower()][0] + else: + return False + + def retrieve_relevant_thoughts(self, s_content: str, p_content: str, o_content: str) -> set: + contents = [s_content, p_content, o_content] + + ret = [] + for i in contents: + if i in self.thought_keywords: + ret += self.thought_keywords[i.lower()] + + ret = set(ret) + return ret + + def retrieve_relevant_events(self, s_content: str, p_content: str, o_content: str) -> set: + contents = [s_content, p_content, o_content] + + ret = [] + for i in contents: + if i in self.event_keywords: + ret += self.event_keywords[i] + + ret = set(ret) + return ret diff --git a/notebook_dir/metagpt_yusin/ext/stanford_town/memory/retrieve.py b/notebook_dir/metagpt_yusin/ext/stanford_town/memory/retrieve.py new file mode 100644 index 0000000000000000000000000000000000000000..c4b32f965037f2ba1b931dd8aca77ae56d0a9b36 --- /dev/null +++ b/notebook_dir/metagpt_yusin/ext/stanford_town/memory/retrieve.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : Retrieve函数实现 + +import datetime + +from numpy import dot +from numpy.linalg import norm + +from metagpt.ext.stanford_town.memory.agent_memory import BasicMemory +from metagpt.ext.stanford_town.utils.utils import get_embedding + + +def agent_retrieve( + agent_memory, + curr_time: datetime.datetime, + memory_forget: float, + query: str, + nodes: list[BasicMemory], + topk: int = 4, +) -> list[BasicMemory]: + """ + Retrieve需要集合Role使用,原因在于Role才具有AgentMemory,scratch + 逻辑:Role调用该函数,self.rc.AgentMemory,self.rc.scratch.curr_time,self.rc.scratch.memory_forget + 输入希望查询的内容与希望回顾的条数,返回TopK条高分记忆,即List[BasicMemory] + + Score_lists示例 + { + "memory": memories[i], BasicMemory类 + "importance": memories[i].poignancy + "recency": 衰减因子计算结果 + "relevance": 搜索结果 + } + """ + memories = nodes + agent_memory_embedding = agent_memory.embeddings + memories = sorted(memories, key=lambda memory_node: 
memory_node.last_accessed, reverse=True) + + score_list = [] + score_list = extract_importance(memories, score_list) + score_list = extract_recency(curr_time, memory_forget, score_list) + score_list = extract_relevance(agent_memory_embedding, query, score_list) + score_list = normalize_score_floats(score_list, 0, 1) + + total_dict = {} + gw = [1, 1, 1] # 三个因素的权重,重要性,近因性,相关性, + for i in range(len(score_list)): + total_score = ( + score_list[i]["importance"] * gw[0] + score_list[i]["recency"] * gw[1] + score_list[i]["relevance"] * gw[2] + ) + total_dict[score_list[i]["memory"].memory_id] = total_score + + result = top_highest_x_values(total_dict, topk) + + return result # 返回的是一个BasicMemory列表 + + +def new_agent_retrieve(role, focus_points: list, n_count=30) -> dict: + """ + 输入为role,关注点列表,返回记忆数量 + 输出为字典,键为focus_point,值为对应的记忆列表 + """ + retrieved = dict() + for focal_pt in focus_points: + nodes = [ + [i.last_accessed, i] + for i in role.memory.event_list + role.memory.thought_list + if "idle" not in i.embedding_key + ] + nodes = sorted(nodes, key=lambda x: x[0]) + nodes = [i for created, i in nodes] + results = agent_retrieve( + role.memory, role.scratch.curr_time, role.scratch.recency_decay, focal_pt, nodes, n_count + ) + final_result = [] + for n in results: + for i in role.memory.storage: + if i.memory_id == n: + i.last_accessed = role.scratch.curr_time + final_result.append(i) + + retrieved[focal_pt] = final_result + + return retrieved + + +def top_highest_x_values(d, x): + """ + 输入字典,Topx + 返回以字典值排序,字典键组成的List[BasicMemory] + """ + top_v = [item[0] for item in sorted(d.items(), key=lambda item: item[1], reverse=True)[:x]] + return top_v + + +def extract_importance(memories, score_list): + """ + 抽取重要性 + """ + for i in range(len(memories)): + score = {"memory": memories[i], "importance": memories[i].poignancy} + score_list.append(score) + return score_list + + +def extract_relevance(agent_memory_embedding, query, score_list): + """ + 抽取相关性 + """ + query_embedding = 
get_embedding(query) + # 进行 + for i in range(len(score_list)): + node_embedding = agent_memory_embedding[score_list[i]["memory"].embedding_key] + result = cos_sim(node_embedding, query_embedding) + score_list[i]["relevance"] = result + + return score_list + + +def extract_recency(curr_time, memory_forget, score_list): + """ + 抽取近因性,目前使用的现实世界过一天走一个衰减因子 + """ + for i in range(len(score_list)): + day_count = (curr_time - score_list[i]["memory"].created).days + score_list[i]["recency"] = memory_forget**day_count + return score_list + + +def cos_sim(a, b): + """ + 计算余弦相似度 + """ + return dot(a, b) / (norm(a) * norm(b)) + + +def normalize_list_floats(single_list, target_min, target_max): + """ + 单个列表归一化 + """ + if len(single_list) == 0: + return [] + + min_val = min(single_list) + max_val = max(single_list) + range_val = max_val - min_val + + if range_val == 0: + for i in range(len(single_list)): + single_list[i] = (target_max - target_min) / 2 + else: + for i in range(len(single_list)): + single_list[i] = (single_list[i] - min_val) * (target_max - target_min) / range_val + target_min + return single_list + + +def normalize_score_floats(score_list, target_min, target_max): + """ + 整体归一化 + """ + importance_list = [] + relevance_list = [] + recency_list = [] + + for i in range(len(score_list)): + importance_list.append(score_list[i]["importance"]) + relevance_list.append(score_list[i]["relevance"]) + recency_list.append(score_list[i]["recency"]) + + # 进行归一化操作 + importance_list = normalize_list_floats(importance_list, target_min, target_max) + relevance_list = normalize_list_floats(relevance_list, target_min, target_max) + recency_list = normalize_list_floats(recency_list, target_min, target_max) + + for i in range(len(score_list)): + score_list[i]["importance"] = importance_list[i] + score_list[i]["relevance"] = relevance_list[i] + score_list[i]["recency"] = recency_list[i] + + return score_list diff --git a/notebook_dir/metagpt_yusin/ext/stanford_town/memory/scratch.py 
class Scratch(BaseModel):
    """Short-term ("scratch") state of a persona.

    Groups the persona's hyper-parameters, identity, reflection settings,
    daily plan and the action currently being executed. The three datetime
    fields (``curr_time``, ``act_start_time``, ``chatting_end_time``) are
    parsed from and serialized to strings in the ``"%B %d, %Y, %H:%M:%S"``
    format.
    """

    # Category 1: persona hyper-parameters
    vision_r: int = 4
    att_bandwidth: int = 3
    retention: int = 5

    # Category 2: world information
    curr_time: Optional[datetime] = Field(default=None)
    curr_tile: Optional[list[int]] = Field(default=None)
    daily_plan_req: Optional[str] = Field(default=None)

    # Category 3: core identity of the persona
    name: Optional[str] = Field(default=None)
    first_name: Optional[str] = Field(default=None)
    last_name: Optional[str] = Field(default=None)
    age: Optional[int] = Field(default=None)
    innate: Optional[str] = Field(default=None)  # L0 permanent core traits.
    learned: Optional[str] = Field(default=None)  # L1 stable traits.
    currently: Optional[str] = Field(default=None)  # L2 external implementation.
    lifestyle: Optional[str] = Field(default=None)
    living_area: Optional[str] = Field(default=None)

    # Category 4: legacy reflection variables
    concept_forget: int = 100
    daily_reflection_time: int = 60 * 3
    daily_reflection_size: int = 5
    overlap_reflect_th: int = 2
    kw_strg_event_reflect_th: int = 4
    kw_strg_thought_reflect_th: int = 4

    # Category 5: new reflection variables
    recency_w: int = 1
    relevance_w: int = 1
    importance_w: int = 1
    recency_decay: float = 0.99
    importance_trigger_max: int = 150
    importance_trigger_curr: int = 150
    importance_ele_n: int = 0
    thought_count: int = 5

    # Category 6: personal plan
    daily_req: list[str] = Field(default=[])
    f_daily_schedule: list[list[Union[int, str]]] = Field(default=[])
    f_daily_schedule_hourly_org: list[list[Union[int, str]]] = Field(default=[])

    # Category 7: current action
    act_address: Optional[str] = Field(default=None)
    act_start_time: Optional[datetime] = Field(default=None)
    act_duration: Optional[int] = Field(default=None)
    act_description: Optional[str] = Field(default=None)
    act_pronunciatio: Optional[str] = Field(default=None)
    act_event: list[Optional[str]] = [None, None, None]

    act_obj_description: Optional[str] = Field(default=None)
    act_obj_pronunciatio: Optional[str] = Field(default=None)
    act_obj_event: list[Optional[str]] = [None, None, None]

    chatting_with: Optional[str] = Field(default=None)
    chat: Optional[str] = Field(default=None)
    chatting_with_buffer: dict = dict()
    chatting_end_time: Optional[datetime] = Field(default=None)

    act_path_set: bool = False
    planned_path: list[list[int]] = Field(default=[])

    @field_validator("curr_time", "act_start_time", "chatting_end_time", mode="before")
    @classmethod
    def check_time_filed(cls, time_filed):
        """Parse a time field given as a ``"%B %d, %Y, %H:%M:%S"`` string.

        Falsy input becomes ``None``. A value that is already a ``datetime``
        is passed through unchanged, so the model can be built from in-memory
        values as well as from saved JSON.
        """
        if not time_filed:
            return None
        if isinstance(time_filed, datetime):
            return time_filed
        return datetime.strptime(time_filed, "%B %d, %Y, %H:%M:%S")

    @field_serializer("curr_time", "act_start_time", "chatting_end_time")
    def transform_time_field(self, time_filed: Optional[datetime]) -> Optional[str]:
        """Serialize a time field with the format the validator parses; unset stays ``None``."""
        if time_filed:
            time_filed = time_filed.strftime("%B %d, %Y, %H:%M:%S")
        return time_filed

    @classmethod
    def init_scratch_from_path(cls, f_saved: Path):
        """Build a scratch from the JSON file at ``f_saved``."""
        scratch_load = read_json_file(f_saved)
        return cls(**scratch_load)

    def save(self, out_json: Path):
        """Dump the persona's scratch state as JSON to ``out_json``."""
        write_json_file(out_json, self.model_dump(), encoding="utf-8")

    def get_f_daily_schedule_index(self, advance=0):
        """Return the index of the ``f_daily_schedule`` entry covering the current time.

        Each entry is ``[task, duration_in_minutes]``. Durations are summed
        until the running total exceeds the number of minutes elapsed today
        (taken from ``curr_time``) plus ``advance``.

        Args:
            advance: Minutes to look into the future.

        Returns:
            The index of the matching entry, or ``len(f_daily_schedule)`` when
            the schedule ends before that point.
        """
        today_min_elapsed = self.curr_time.hour * 60 + self.curr_time.minute + advance

        elapsed = 0
        for curr_index, (_task, duration) in enumerate(self.f_daily_schedule):
            elapsed += duration
            if elapsed > today_min_elapsed:
                return curr_index
        return len(self.f_daily_schedule)

    def get_f_daily_schedule_hourly_org_index(self, advance=0):
        """Same as ``get_f_daily_schedule_index`` but over ``f_daily_schedule_hourly_org``.

        Args:
            advance: Minutes to look into the future.

        Returns:
            The index of the matching entry, or the schedule length when the
            schedule ends before that point.
        """
        today_min_elapsed = self.curr_time.hour * 60 + self.curr_time.minute + advance

        elapsed = 0
        for curr_index, (_task, duration) in enumerate(self.f_daily_schedule_hourly_org):
            elapsed += duration
            if elapsed > today_min_elapsed:
                return curr_index
        return len(self.f_daily_schedule_hourly_org)

    def get_str_iss(self):
        """Return the "identity stable set" summary of the persona.

        This is a newline-terminated, multi-line string with one
        ``Label: value`` line each for name, age, innate traits, learned
        traits, current state, lifestyle, daily plan requirement and current
        date. The date is empty when ``curr_time`` is unset.

        Example output (abridged):
            "Name: Dolores Heitmiller
             Age: 28
             Innate traits: hard-edged, independent, loyal
             ..."
        """
        curr_date = self.curr_time.strftime("%A %B %d") if self.curr_time else ""
        lines = [
            f"Name: {self.name}",
            f"Age: {self.age}",
            f"Innate traits: {self.innate}",
            f"Learned traits: {self.learned}",
            f"Currently: {self.currently}",
            f"Lifestyle: {self.lifestyle}",
            f"Daily plan requirement: {self.daily_plan_req}",
            f"Current Date: {curr_date}",
        ]
        return "\n".join(lines) + "\n"

    def get_str_name(self):
        return self.name

    def get_str_firstname(self):
        return self.first_name

    def get_str_lastname(self):
        return self.last_name

    def get_str_age(self):
        return str(self.age)

    def get_str_innate(self):
        return self.innate

    def get_str_learned(self):
        return self.learned

    def get_str_currently(self):
        return self.currently

    def get_str_lifestyle(self):
        return self.lifestyle

    def get_str_daily_plan_req(self):
        return self.daily_plan_req

    def get_str_curr_date_str(self):
        """Return ``curr_time`` formatted as ``"%A %B %d"`` (requires ``curr_time`` to be set)."""
        return self.curr_time.strftime("%A %B %d")

    def get_curr_event(self):
        """Return the current action event, or ``(name, None, None)`` when no action address is set."""
        if not self.act_address:
            return self.name, None, None
        return self.act_event

    def get_curr_event_and_desc(self):
        """Return the current action event triple followed by the action description."""
        if not self.act_address:
            return self.name, None, None, None
        return self.act_event[0], self.act_event[1], self.act_event[2], self.act_description

    def get_curr_obj_event_and_desc(self):
        """Return ``(act_address, obj predicate, obj object, obj description)`` for the current action."""
        if not self.act_address:
            return "", None, None, None
        return self.act_address, self.act_obj_event[1], self.act_obj_event[2], self.act_obj_description

    def add_new_action(
        self,
        action_address,
        action_duration,
        action_description,
        action_pronunciatio,
        action_event,
        chatting_with,
        chat,
        chatting_with_buffer,
        chatting_end_time,
        act_obj_description,
        act_obj_pronunciatio,
        act_obj_event,
        act_start_time=None,
    ):
        """Install a new current action and its chat / object state.

        A non-empty ``chatting_with_buffer`` is merged into the existing
        buffer rather than replacing it. The action start time is always set
        to ``curr_time`` and ``act_path_set`` is reset to ``False``.

        NOTE(review): ``act_start_time`` is accepted but not used; confirm
        whether callers that pass it expect it to be honoured.
        """
        self.act_address = action_address
        self.act_duration = action_duration
        self.act_description = action_description
        self.act_pronunciatio = action_pronunciatio
        self.act_event = action_event

        self.chatting_with = chatting_with
        self.chat = chat
        if chatting_with_buffer:
            self.chatting_with_buffer.update(chatting_with_buffer)
        self.chatting_end_time = chatting_end_time

        self.act_obj_description = act_obj_description
        self.act_obj_pronunciatio = act_obj_pronunciatio
        self.act_obj_event = act_obj_event

        self.act_start_time = self.curr_time

        self.act_path_set = False

    def act_time_str(self):
        """Return the action start time formatted with ``"%H:%M %p"``, e.g. ``"14:05 PM"``."""
        return self.act_start_time.strftime("%H:%M %p")

    def act_check_finished(self):
        """Tell whether the current action has finished.

        Returns ``True`` when no action address is set. Otherwise the end time
        is ``chatting_end_time`` while chatting, or the start time (rounded up
        to the next full minute when it has seconds) plus ``act_duration``
        minutes. The action counts as finished only when the ``HH:MM:SS`` of
        that end time equals the ``HH:MM:SS`` of ``curr_time``.
        """
        if not self.act_address:
            return True

        if self.chatting_with:
            end_time = self.chatting_end_time
        else:
            start = self.act_start_time
            if start.second != 0:
                start = start.replace(second=0) + timedelta(minutes=1)
            end_time = start + timedelta(minutes=self.act_duration)

        return end_time.strftime("%H:%M:%S") == self.curr_time.strftime("%H:%M:%S")

    def act_summarize(self):
        """Return the current action as a dict (persona, address, start, duration, description, pronunciatio)."""
        return {
            "persona": self.name,
            "address": self.act_address,
            "start_datetime": self.act_start_time,
            "duration": self.act_duration,
            "description": self.act_description,
            "pronunciatio": self.act_pronunciatio,
        }

    def act_summary_str(self):
        """Return a human-readable, multi-line summary of the current action."""
        start_datetime_str = self.act_start_time.strftime("%A %B %d -- %H:%M %p")
        ret = f"[{start_datetime_str}]\n"
        ret += f"Activity: {self.name} is {self.act_description}\n"
        ret += f"Address: {self.act_address}\n"
        ret += f"Duration in minutes (e.g., x min): {str(self.act_duration)} min\n"
        return ret

    def get_daily_schedule(self, daily_schedule: list[list[Union[int, str]]]):
        """Render a schedule as ``"HH:MM || task"`` lines.

        Each row is ``[task, duration_in_minutes]``; the time printed on a row
        is the cumulative duration up to and including that row.
        """
        rows = []
        curr_min_sum = 0
        for row in daily_schedule:
            curr_min_sum += row[1]
            hour = int(curr_min_sum / 60)
            minute = curr_min_sum % 60
            rows.append(f"{hour:02}:{minute:02} || {row[0]}\n")
        return "".join(rows)

    def get_str_daily_schedule_summary(self):
        return self.get_daily_schedule(self.f_daily_schedule)

    def get_str_daily_schedule_hourly_org_summary(self):
        return self.get_daily_schedule(self.f_daily_schedule_hourly_org)
class MemoryTree(BaseModel):
    """Spatial memory of an agent, stored as a nested dict.

    As built by ``add_tile_info`` the layout is
    ``tree[world][sector][arena] -> list of game objects``.
    """

    # default_factory gives every instance its own empty dict; `default=dict`
    # would store the `dict` type itself as the value.
    tree: dict = Field(default_factory=dict)

    def set_mem_path(self, f_saved: Path):
        """Replace the tree with the contents of the JSON file at ``f_saved``."""
        self.tree = read_json_file(f_saved)

    def print_tree(self) -> None:
        """Log the tree, one line per non-empty key or non-empty leaf list, indented by depth."""

        def _print_tree(tree, depth):
            dash = " >" * depth
            if isinstance(tree, list):
                if tree:
                    logger.info(f"{dash} {tree}")
                return

            for key, val in tree.items():
                if key:
                    # Log the key of this level; the children are logged by the recursion.
                    logger.info(f"{dash} {key}")
                _print_tree(val, depth + 1)

        _print_tree(self.tree, 0)

    def save(self, out_json: Path) -> None:
        """Write the tree as JSON to ``out_json``."""
        write_json_file(out_json, self.tree)

    def get_str_accessible_sectors(self, curr_world: str) -> str:
        """Return a comma-separated list of the sectors stored under ``curr_world``.

        Raises:
            KeyError: if ``curr_world`` is not in the tree.
        """
        return ", ".join(self.tree[curr_world])

    def get_str_accessible_sector_arenas(self, sector: str) -> str:
        """Return a comma-separated list of the arenas stored under a sector.

        Args:
            sector: Address of the form ``"world:sector"``.

        Returns:
            e.g. ``"bedroom, kitchen, dining room, office, bathroom"``, or an
            empty string when the sector part is empty.
        """
        curr_world, curr_sector = sector.split(":")
        if not curr_sector:
            return ""
        return ", ".join(self.tree[curr_world][curr_sector])

    def get_str_accessible_arena_game_objects(self, arena: str) -> str:
        """Return a comma-separated list of the game objects stored under an arena.

        Args:
            arena: Address of the form ``"world:sector:arena"``.

        Returns:
            e.g. ``"phone, charger, bed, nightstand"``, or an empty string when
            the arena part is empty. If the arena name is not found as given,
            its lower-cased form is tried.
        """
        curr_world, curr_sector, curr_arena = arena.split(":")

        if not curr_arena:
            return ""

        arenas = self.tree[curr_world][curr_sector]
        try:
            game_objects = arenas[curr_arena]
        except KeyError:
            game_objects = arenas[curr_arena.lower()]
        return ", ".join(game_objects)

    def add_tile_info(self, tile_info: dict) -> None:
        """Record the world / sector / arena / game object of one tile.

        Levels are created on demand. Recording stops at the first empty
        level, so a lower level is never stored under a missing parent.

        NOTE(review): this block was reconstructed from a copy without
        indentation; confirm the intended nesting of the level checks.
        """
        world = tile_info["world"]
        if not world:
            return
        sectors = self.tree.setdefault(world, {})

        sector = tile_info["sector"]
        if not sector:
            return
        arenas = sectors.setdefault(sector, {})

        arena = tile_info["arena"]
        if not arena:
            return
        game_objects = arenas.setdefault(arena, [])

        game_object = tile_info["game_object"]
        if game_object and game_object not in game_objects:
            game_objects.append(game_object)
async def _converse_turn(speaker: "STRole", listener: "STRole", curr_chat: list):
    """Let ``speaker`` say one utterance to ``listener`` and append it to ``curr_chat``.

    Returns the "end" flag produced together with the utterance.
    """
    speaker_scratch = speaker.rc.scratch
    listener_scratch = listener.rc.scratch

    # First retrieval: what the speaker remembers about the listener.
    retrieved = new_agent_retrieve(speaker, [f"{listener_scratch.name}"], 50)
    relationship = await generate_summarize_agent_relationship(speaker, listener, retrieved)
    logger.info(f"The relationship between {speaker.name} and {listener.name}: {relationship}")

    # Second retrieval: relationship, what the listener is doing, and the last few chat rows.
    focal_points = [f"{relationship}", f"{listener_scratch.name} is {listener_scratch.act_description}"]
    last_chat = "".join(": ".join(row) + "\n" for row in curr_chat[-4:])
    if last_chat:
        focal_points.append(last_chat)
    retrieved = new_agent_retrieve(speaker, focal_points, 15)

    utt, end = await generate_one_utterance(speaker, listener, retrieved, curr_chat)
    curr_chat.append([speaker_scratch.name, utt])
    return end


async def agent_conversation(init_role: "STRole", target_role: "STRole", conv_rounds: int = 8) -> list[list[str]]:
    """Run a conversation of at most ``conv_rounds`` rounds between two roles.

    In every round ``init_role`` speaks first and ``target_role`` answers; the
    conversation stops as soon as an utterance comes with a truthy end flag.

    Returns:
        The chat as a list of ``[speaker name, utterance]`` rows.
    """
    curr_chat = []
    logger.info(f"Role: {init_role.name} starts a conversation with Role: {target_role.name}")

    for idx in range(conv_rounds):
        logger.info(f"Conv round: {idx} between {init_role.name} and {target_role.name}")
        if await _converse_turn(init_role, target_role, curr_chat):
            break
        if await _converse_turn(target_role, init_role, curr_chat):
            break

    logger.warning(f"Conversations between {target_role.name} and {init_role.name}:")
    for row in curr_chat:
        logger.info(row)

    return curr_chat


async def generate_summarize_agent_relationship(init_role: "STRole", target_role: "STRole", retrieved: dict) -> str:
    """Summarize the relationship between two roles from retrieved memories.

    The ``embedding_key`` of every retrieved item is written on its own line
    and handed to ``AgentChatSumRel``.
    """
    all_embedding_key_str = "".join(f"{node.embedding_key}\n" for nodes in retrieved.values() for node in nodes)
    return await AgentChatSumRel().run(init_role, target_role, all_embedding_key_str)


async def generate_one_utterance(init_role, target_role, retrieved: dict, curr_chat: list) -> Tuple[str, str]:
    """Generate the next utterance of ``init_role`` towards ``target_role``.

    Returns:
        The ``"utterance"`` and ``"end"`` entries of the ``GenIterChatUTT`` result.
    """
    # Chat version optimized for speed via batch generation
    scratch = init_role.rc.scratch
    target_scratch = target_role.rc.scratch
    curr_context = (
        f"{scratch.name} was {scratch.act_description} when {scratch.name} saw {target_scratch.name} "
        f"in the middle of {target_scratch.act_description}.\n"
        f"{scratch.name} is initiating a conversation with {target_scratch.name}."
    )

    result = await GenIterChatUTT().run(init_role, target_role, retrieved, curr_context, curr_chat)

    return result["utterance"], result["end"]
+ f"when {scratch.name} " + + f"saw {target_scratch.name} " + + f"in the middle of {target_scratch.act_description}.\n" + ) + curr_context += f"{scratch.name} " + "is initiating a conversation with " + f"{target_scratch.name}." + + x = await GenIterChatUTT().run(init_role, target_role, retrieved, curr_context, curr_chat) + + return x["utterance"], x["end"] diff --git a/notebook_dir/metagpt_yusin/ext/stanford_town/plan/st_plan.py b/notebook_dir/metagpt_yusin/ext/stanford_town/plan/st_plan.py new file mode 100644 index 0000000000000000000000000000000000000000..f63052fc5324f06b67d8426f687046852d76952d --- /dev/null +++ b/notebook_dir/metagpt_yusin/ext/stanford_town/plan/st_plan.py @@ -0,0 +1,706 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : st' planning execution + +import datetime +import math +import random +from typing import Tuple, Union + +from metagpt.ext.stanford_town.actions.decide_to_talk import DecideToTalk +from metagpt.ext.stanford_town.actions.gen_action_details import GenActionDetails +from metagpt.ext.stanford_town.actions.gen_daily_schedule import GenDailySchedule +from metagpt.ext.stanford_town.actions.gen_hourly_schedule import GenHourlySchedule +from metagpt.ext.stanford_town.actions.new_decomp_schedule import NewDecompSchedule +from metagpt.ext.stanford_town.actions.summarize_conv import SummarizeConv +from metagpt.ext.stanford_town.actions.task_decomp import TaskDecomp +from metagpt.ext.stanford_town.actions.wake_up import WakeUp +from metagpt.ext.stanford_town.memory.retrieve import new_agent_retrieve +from metagpt.ext.stanford_town.plan.converse import agent_conversation +from metagpt.ext.stanford_town.utils.utils import get_embedding +from metagpt.llm import LLM +from metagpt.logs import logger + + +async def plan(role: "STRole", roles: dict["STRole"], new_day: bool, retrieved: dict) -> str: + # PART 1: Generate the hourly schedule. 
+ if new_day: + await _long_term_planning(role, new_day) + + # PART 2: If the current action has expired, we want to create a new plan. + act_check_finished = role.scratch.act_check_finished() + logger.info(f"Role: {role.name} act_check_finished is {act_check_finished}") + if act_check_finished: + await _determine_action(role) + + # PART 3: If you perceived an event that needs to be responded to (saw + # another role), and retrieved relevant information. + # Step 1: Retrieved may have multiple events represented in it. The first + # job here is to determine which of the events we want to focus + # on for the role. + # takes the form of a dictionary like this: + # dictionary {["curr_event"] = , + # ["events"] = [, ...], + # ["thoughts"] = [, ...]} + focused_event = False + if retrieved.keys(): + focused_event = _choose_retrieved(role.name, retrieved) + + # Step 2: Once we choose an event, we need to determine whether the + # role will take any actions for the perceived event. There are + # three possible modes of reaction returned by _should_react. + # a) "chat with {target_role.name}" + # b) "react" + # c) False + logger.info(f"Role: {role.name} focused_event: {focused_event}") + if focused_event: + reaction_mode = await _should_react(role, focused_event, roles) + logger.info(f"Role: {role.name} reaction_mode: {reaction_mode}") + if reaction_mode: + # If we do want to chat, then we generate conversation + if reaction_mode[:9] == "chat with": + await _chat_react(role, reaction_mode, roles) + elif reaction_mode[:4] == "wait": + await _wait_react(role, reaction_mode) + + # Step 3: Chat-related state clean up. + # If the persona is not chatting with anyone, we clean up any of the + # chat-related states here. 
+ if role.rc.scratch.act_event[1] != "chat with": + role.rc.scratch.chatting_with = None + role.rc.scratch.chat = None + role.rc.scratch.chatting_end_time = None + # We want to make sure that the persona does not keep conversing with each + # other in an infinite loop. So, chatting_with_buffer maintains a form of + # buffer that makes the persona wait from talking to the same target + # immediately after chatting once. We keep track of the buffer value here. + curr_persona_chat_buffer = role.rc.scratch.chatting_with_buffer + for persona_name, buffer_count in curr_persona_chat_buffer.items(): + if persona_name != role.rc.scratch.chatting_with: + role.rc.scratch.chatting_with_buffer[persona_name] -= 1 + + return role.rc.scratch.act_address + + +def _choose_retrieved(role_name: str, retrieved: dict) -> Union[None, dict]: + """ + Retrieved elements have multiple core "curr_events". We need to choose one + event to which we are going to react to. We pick that event here. + Args: + role_name: Current role instance's name whose action we are determining. + retrieved: A dictionary of that were retrieved from the + the role's associative memory. This dictionary takes the + following form: + dictionary[event.description] = + {["curr_event"] = , + ["events"] = [, ...], + ["thoughts"] = [, ...] } + """ + # Once we are done with the reflection, we might want to build a more + # complex structure here. + + # We do not want to take self events... for now + copy_retrieved = retrieved.copy() + for event_desc, rel_ctx in copy_retrieved.items(): + curr_event = rel_ctx["curr_event"] + if curr_event.subject == role_name: + del retrieved[event_desc] + + # Always choose role first. + priority = [] + for event_desc, rel_ctx in retrieved.items(): + curr_event = rel_ctx["curr_event"] + if ":" not in curr_event.subject and curr_event.subject != role_name: + priority += [rel_ctx] + if priority: + return random.choice(priority) + + # Skip idle. 
+ for event_desc, rel_ctx in retrieved.items(): + if "is idle" not in event_desc: + priority += [rel_ctx] + if priority: + return random.choice(priority) + return None + + +async def _should_react(role: "STRole", retrieved: dict, roles: dict): + """ + Determines what form of reaction the role should exihibit given the + retrieved values. + INPUT + role: Current <"STRole"> instance whose action we are determining. + retrieved: A dictionary of that were retrieved from the + the role's associative memory. This dictionary takes the + following form: + dictionary[event.description] = + {["curr_event"] = , + ["events"] = [, ...], + ["thoughts"] = [, ...] } + roles: A dictionary that contains all role names as keys, and the + <"STRole"> instance as values. + """ + + async def lets_talk(init_role: "STRole", target_role: "STRole", retrieved: dict): + if init_role.name == target_role.name: + logger.info(f"Role: {role.name} _should_react lets_talk meet same role, return False") + return False + + scratch = init_role.rc.scratch + target_scratch = target_role.rc.scratch + if ( + not target_scratch.act_address + or not target_scratch.act_description + or not scratch.act_address + or not scratch.act_description + ): + return False + + if "sleeping" in target_scratch.act_description or "sleeping" in scratch.act_description: + return False + + if scratch.curr_time.hour == 23: + return False + + if "" in target_scratch.act_address: + return False + + if target_scratch.chatting_with or scratch.chatting_with: + return False + + if target_role.name in scratch.chatting_with_buffer: + if scratch.chatting_with_buffer[target_role.name] > 0: + return False + + if await DecideToTalk().run(init_role, target_role, retrieved): + return True + + return False + + async def lets_react(init_role: "STRole", target_role: "STRole", retrieved: dict): + if init_role.name == target_role.name: + logger.info(f"Role: {role.name} _should_react lets_react meet same role, return False") + return False + + 
scratch = init_role.rc.scratch + target_scratch = target_role.rc.scratch + if ( + not target_scratch.act_address + or not target_scratch.act_description + or not scratch.act_address + or not scratch.act_description + ): + return False + + if "sleeping" in target_scratch.act_description or "sleeping" in scratch.act_description: + return False + + # return False + if scratch.curr_time.hour == 23: + return False + + if "waiting" in target_scratch.act_description: + return False + if scratch.planned_path == []: + return False + + if scratch.act_address != target_scratch.act_address: + return False + + react_mode = await DecideToTalk().run(init_role, target_role, retrieved) + + if react_mode == "1": + wait_until = ( + target_scratch.act_start_time + datetime.timedelta(minutes=target_scratch.act_duration - 1) + ).strftime("%B %d, %Y, %H:%M:%S") + return f"wait: {wait_until}" + elif react_mode == "2": + return False + return "do other things" + else: + return False # "keep" + + # If the role is chatting right now, default to no reaction + scratch = role.rc.scratch + if scratch.chatting_with: + return False + if "" in scratch.act_address: + return False + + # Recall that retrieved takes the following form: + # dictionary {["curr_event"] = } + curr_event = retrieved["curr_event"] + logger.info(f"Role: {role.name} _should_react curr_event.subject: {curr_event.subject}") + + if ":" not in curr_event.subject: + # this is a role event. + if await lets_talk(role, roles[curr_event.subject], retrieved): + return f"chat with {curr_event.subject}" + react_mode = await lets_react(role, roles[curr_event.subject], retrieved) + return react_mode + return False + + +async def _chat_react(role: "STRole", reaction_mode: str, roles: dict["STRole"]): + # There are two roles -- the role who is initiating the conversation + # and the role who is the target. We get the role instances here. 
+ init_role = role + target_role = roles[reaction_mode[9:].strip()] + + # Actually creating the conversation here. + convo, duration_min = await generate_convo(init_role, target_role) # 2222 + convo_summary = await generate_convo_summary(convo) + inserted_act = convo_summary + inserted_act_dur = duration_min + + act_start_time = target_role.rc.scratch.act_start_time + + curr_time = target_role.rc.scratch.curr_time + if curr_time.second != 0: + temp_curr_time = curr_time + datetime.timedelta(seconds=60 - curr_time.second) + chatting_end_time = temp_curr_time + datetime.timedelta(minutes=inserted_act_dur) + else: + chatting_end_time = curr_time + datetime.timedelta(minutes=inserted_act_dur) + + for role, p in [("init", init_role), ("target", target_role)]: + if role == "init": + act_address = f" {target_role.name}" + act_event = (p.name, "chat with", target_role.name) + chatting_with = target_role.name + chatting_with_buffer = {} + chatting_with_buffer[target_role.name] = 800 + elif role == "target": + act_address = f" {init_role.name}" + act_event = (p.name, "chat with", init_role.name) + chatting_with = init_role.name + chatting_with_buffer = {} + chatting_with_buffer[init_role.name] = 800 + + act_pronunciatio = "💬" + act_obj_description = None + act_obj_pronunciatio = None + act_obj_event = (None, None, None) + + await _create_react( + p, + inserted_act, + inserted_act_dur, + act_address, + act_event, + chatting_with, + convo, + chatting_with_buffer, + chatting_end_time, + act_pronunciatio, + act_obj_description, + act_obj_pronunciatio, + act_obj_event, + act_start_time, + ) + + +async def _create_react( + role: "STRole", + inserted_act: str, + inserted_act_dur: int, + act_address: str, + act_event: Tuple, + chatting_with: str, + chat: list, + chatting_with_buffer: dict, + chatting_end_time: datetime, + act_pronunciatio: str, + act_obj_description: str, + act_obj_pronunciatio: str, + act_obj_event: Tuple, + act_start_time=None, +): + p = role + scratch = 
role.rc.scratch + + min_sum = 0 + for i in range(scratch.get_f_daily_schedule_hourly_org_index()): + min_sum += scratch.f_daily_schedule_hourly_org[i][1] + start_hour = int(min_sum / 60) + + if scratch.f_daily_schedule_hourly_org[scratch.get_f_daily_schedule_hourly_org_index()][1] >= 120: + end_hour = ( + start_hour + scratch.f_daily_schedule_hourly_org[scratch.get_f_daily_schedule_hourly_org_index()][1] / 60 + ) + + elif ( + scratch.f_daily_schedule_hourly_org[scratch.get_f_daily_schedule_hourly_org_index()][1] + + scratch.f_daily_schedule_hourly_org[scratch.get_f_daily_schedule_hourly_org_index() + 1][1] + ): + end_hour = start_hour + ( + ( + scratch.f_daily_schedule_hourly_org[scratch.get_f_daily_schedule_hourly_org_index()][1] + + scratch.f_daily_schedule_hourly_org[scratch.get_f_daily_schedule_hourly_org_index() + 1][1] + ) + / 60 + ) + + else: + end_hour = start_hour + 2 + end_hour = int(end_hour) + + dur_sum = 0 + count = 0 + start_index = None + end_index = None + for act, dur in scratch.f_daily_schedule: + if dur_sum >= start_hour * 60 and start_index is None: + start_index = count + if dur_sum >= end_hour * 60 and end_index is None: + end_index = count + dur_sum += dur + count += 1 + + ret = await generate_new_decomp_schedule(p, inserted_act, inserted_act_dur, start_hour, end_hour) + scratch.f_daily_schedule[start_index:end_index] = ret + scratch.add_new_action( + act_address, + inserted_act_dur, + inserted_act, + act_pronunciatio, + act_event, + chatting_with, + chat, + chatting_with_buffer, + chatting_end_time, + act_obj_description, + act_obj_pronunciatio, + act_obj_event, + act_start_time, + ) + + +async def _wait_react(role: "STRole", reaction_mode: str): + scratch = role.rc.scratch + + inserted_act = f'waiting to start {scratch.act_description.split("(")[-1][:-1]}' + end_time = datetime.datetime.strptime(reaction_mode[6:].strip(), "%B %d, %Y, %H:%M:%S") + inserted_act_dur = ( + (end_time.minute + end_time.hour * 60) - (scratch.curr_time.minute + 
scratch.curr_time.hour * 60) + 1 + ) + + act_address = f" {scratch.curr_tile[0]} {scratch.curr_tile[1]}" + act_event = (role.name, "waiting to start", scratch.act_description.split("(")[-1][:-1]) + chatting_with = None + chat = None + chatting_with_buffer = None + chatting_end_time = None + + act_pronunciatio = "⌛" + act_obj_description = None + act_obj_pronunciatio = None + act_obj_event = (None, None, None) + + await _create_react( + role, + inserted_act, + inserted_act_dur, + act_address, + act_event, + chatting_with, + chat, + chatting_with_buffer, + chatting_end_time, + act_pronunciatio, + act_obj_description, + act_obj_pronunciatio, + act_obj_event, + ) + + +async def generate_convo(init_role: "STRole", target_role: "STRole") -> Union[list, int]: + convo = await agent_conversation(init_role, target_role) + all_utt = "" + + for row in convo: + speaker = row[0] + utt = row[1] + all_utt += f"{speaker}: {utt}\n" + + convo_length = math.ceil(int(len(all_utt) / 8) / 30) + + return convo, convo_length + + +async def generate_convo_summary(conv: list[list[str]]) -> str: + conv_summary = await SummarizeConv().run(conv) + return conv_summary + + +async def generate_new_decomp_schedule( + role: "STRole", inserted_act: str, inserted_act_dur: int, start_hour: int, end_hour: int +): + # Step 1: Setting up the core variables for the function. + #

is the role whose schedule we are editing right now. + scratch = role.rc.scratch + # indicates the number of minutes that have passed today. + today_min_pass = int(scratch.curr_time.hour) * 60 + int(scratch.curr_time.minute) + 1 + + # Step 2: We need to create and . + main_act_dur = [] + truncated_act_dur = [] + dur_sum = 0 # duration sum + count = 0 # enumerate count + truncated_fin = False + + logger.debug(f"DEBUG::: {scratch.name}") + for act, dur in scratch.f_daily_schedule: + if (dur_sum >= start_hour * 60) and (dur_sum < end_hour * 60): + main_act_dur += [[act, dur]] + if dur_sum <= today_min_pass: + truncated_act_dur += [[act, dur]] + elif dur_sum > today_min_pass and not truncated_fin: + # We need to insert that last act, duration list like this one: + # e.g., ['wakes up and completes her morning routine (wakes up...)', 2] + truncated_act_dur += [[scratch.f_daily_schedule[count][0], dur_sum - today_min_pass]] + truncated_act_dur[-1][-1] -= ( + dur_sum - today_min_pass + ) # DEC 7 DEBUG;.. is the +1 the right thing to do??? + # DEC 7 DEBUG;.. is the +1 the right thing to do??? + # truncated_act_dur[-1][-1] -= (dur_sum - today_min_pass + 1) + logger.debug(f"DEBUG::: {truncated_act_dur}") + + # DEC 7 DEBUG;.. is the +1 the right thing to do??? + # truncated_act_dur[-1][-1] -= (dur_sum - today_min_pass) + truncated_fin = True + dur_sum += dur + count += 1 + + main_act_dur = main_act_dur + + x = ( + truncated_act_dur[-1][0].split("(")[0].strip() + + " (on the way to " + + truncated_act_dur[-1][0].split("(")[-1][:-1] + + ")" + ) + truncated_act_dur[-1][0] = x + + if "(" in truncated_act_dur[-1][0]: + inserted_act = truncated_act_dur[-1][0].split("(")[0].strip() + " (" + inserted_act + ")" + + # To do inserted_act_dur+1 below is an important decision but I'm not sure + # if I understand the full extent of its implications. Might want to + # revisit. 
async def _long_term_planning(role: "STRole", new_day: Union[str, bool]):
    """
    Formulates the role's daily long-term plan if it is the start of a new
    day. This basically has two components: first, we create the wake-up hour,
    and second, we create the hourly schedule based on it.
    INPUT
      role: The role whose plan is being created.
      new_day: Indicates whether the current time signals a "First day",
               "New day", or False (for neither). This is important because we
               create the roles' long term planning on the new day.
    """
    scratch = role.scratch

    # We start by creating the wake up hour for the role.
    wake_up_hour = int(await WakeUp().run(role))
    logger.info(f"Role: {role.name} long_term_planning, wake_up_hour: {wake_up_hour}")

    # When it is a new day, we start by creating the daily_req of the role.
    # Note that the daily_req is a list of strings that describe the role's
    # day in broad strokes.
    if new_day == "First day":
        # Bootstrapping: at the start of the generation there is no previous
        # day's daily requirement, so a fresh one is generated.
        scratch.daily_req = await GenDailySchedule().run(role, wake_up_hour)
        logger.info(f"Role: {role.name} daily requirements: {scratch.daily_req}")
    elif new_day == "New day":
        revise_identity(role)
        # TODO: a new daily_req should be created here; for now the previous
        # day's daily_req is carried over unchanged.

    # Based on the daily_req, we create an hourly schedule for the role,
    # which is a list of todo items with a time duration (in minutes) that
    # add up to 24 hours.
    scratch.f_daily_schedule = await GenHourlySchedule().run(role, wake_up_hour)
    logger.info(f"Role: {role.name} f_daily_schedule: {scratch.f_daily_schedule}")
    # Keep a shallow copy of the schedule as originally generated.
    scratch.f_daily_schedule_hourly_org = scratch.f_daily_schedule[:]

    # Added March 4 -- adding plan to the memory.
    today = scratch.curr_time.strftime("%A %B %d")
    thought = f"This is {scratch.name}'s plan for {today}:"
    thought += "".join(f" {req}," for req in scratch.daily_req)
    # Swap the final character (the trailing comma when there are items) for a full stop.
    thought = thought[:-1] + "."

    created = scratch.curr_time
    expiration = scratch.curr_time + datetime.timedelta(days=30)
    s, p, o = (scratch.name, "plan", today)
    keywords = {"plan"}
    thought_poignancy = 5
    thought_embedding_pair = (thought, get_embedding(thought))
    role.a_mem.add_thought(
        created, expiration, s, p, o, thought, keywords, thought_poignancy, thought_embedding_pair, None
    )
True if we need to decompose, False otherwise. + """ + if "sleep" not in act_desp and "bed" not in act_desp: + return True + elif "sleeping" in act_desp or "asleep" in act_desp or "in bed" in act_desp: + return False + elif "sleep" in act_desp or "bed" in act_desp: + if act_dura > 60: + return False + return True + + # The goal of this function is to get us the action associated with + # . As a part of this, we may need to decompose some large + # chunk actions. + # Importantly, we try to decompose at least two hours worth of schedule at + # any given point. + curr_index = role.scratch.get_f_daily_schedule_index() + curr_index_60 = role.scratch.get_f_daily_schedule_index(advance=60) + + logger.info(f"f_daily_schedule: {role.scratch.f_daily_schedule}") + # * Decompose * + # During the first hour of the day, we need to decompose two hours + # sequence. We do that here. + if curr_index == 0: + # This portion is invoked if it is the first hour of the day. + act_desp, act_dura = role.scratch.f_daily_schedule[curr_index] + if act_dura >= 60: + # We decompose if the next action is longer than an hour, and fits the + # criteria described in determine_decomp. + if determine_decomp(act_desp, act_dura): + role.scratch.f_daily_schedule[curr_index : curr_index + 1] = await TaskDecomp().run( + role, act_desp, act_dura + ) + if curr_index_60 + 1 < len(role.scratch.f_daily_schedule): + act_desp, act_dura = role.scratch.f_daily_schedule[curr_index_60 + 1] + if act_dura >= 60: + if determine_decomp(act_desp, act_dura): + role.scratch.f_daily_schedule[curr_index_60 + 1 : curr_index_60 + 2] = await TaskDecomp().run( + role, act_desp, act_dura + ) + + if curr_index_60 < len(role.scratch.f_daily_schedule): + # If it is not the first hour of the day, this is always invoked (it is + # also invoked during the first hour of the day -- to double up so we can + # decompose two hours in one go). Of course, we need to have something to + # decompose as well, so we check for that too. 
+ if role.scratch.curr_time.hour < 23: + # And we don't want to decompose after 11 pm. + act_desp, act_dura = role.scratch.f_daily_schedule[curr_index_60] + if act_dura >= 60: + if determine_decomp(act_desp, act_dura): + role.scratch.f_daily_schedule[curr_index_60 : curr_index_60 + 1] = await TaskDecomp().run( + role, act_desp, act_dura + ) + # * End of Decompose * + + # Generate an instance from the action description and duration. By + # this point, we assume that all the relevant actions are decomposed and + # ready in f_daily_schedule. + logger.debug("DEBUG LJSDLFSKJF") + for i in role.scratch.f_daily_schedule: + logger.debug(i) + logger.debug(curr_index) + logger.debug(len(role.scratch.f_daily_schedule)) + logger.debug(role.scratch.name) + + # 1440 + x_emergency = 0 + for i in role.scratch.f_daily_schedule: + x_emergency += i[1] + + if 1440 - x_emergency > 0: + logger.info(f"x_emergency__AAA: {x_emergency}") + role.scratch.f_daily_schedule += [["sleeping", 1440 - x_emergency]] + + act_desp, act_dura = role.scratch.f_daily_schedule[curr_index] + + new_action_details = await GenActionDetails().run(role, act_desp, act_dura) + # Adding the action to role's queue. + role.scratch.add_new_action(**new_action_details) + + +def revise_identity(role: "STRole"): + p_name = role.scratch.name + + focal_points = [ + f"{p_name}'s plan for {role.scratch.get_str_curr_date_str()}.", + f"Important recent events for {p_name}'s life.", + ] + retrieved = new_agent_retrieve(role, focal_points) + + statements = "[Statements]\n" + for key, val in retrieved.items(): + for i in val: + statements += f"{i.created.strftime('%A %B %d -- %H:%M %p')}: {i.embedding_key}\n" + + plan_prompt = statements + "\n" + plan_prompt += f"Given the statements above, is there anything that {p_name} should remember as they plan for" + plan_prompt += f" *{role.scratch.curr_time.strftime('%A %B %d')}*? 
" + plan_prompt += "If there is any scheduling information, be as specific as possible (include date, time, and location if stated in the statement)\n\n" + plan_prompt += f"Write the response from {p_name}'s perspective." + plan_note = LLM().ask(plan_prompt) + + thought_prompt = statements + "\n" + thought_prompt += ( + f"Given the statements above, how might we summarize {p_name}'s feelings about their days up to now?\n\n" + ) + thought_prompt += f"Write the response from {p_name}'s perspective." + thought_note = LLM().ask(thought_prompt) + + currently_prompt = ( + f"{p_name}'s status from {(role.scratch.curr_time - datetime.timedelta(days=1)).strftime('%A %B %d')}:\n" + ) + currently_prompt += f"{role.scratch.currently}\n\n" + currently_prompt += f"{p_name}'s thoughts at the end of {(role.scratch.curr_time - datetime.timedelta(days=1)).strftime('%A %B %d')}:\n" + currently_prompt += (plan_note + thought_note).replace("\n", "") + "\n\n" + currently_prompt += f"It is now {role.scratch.curr_time.strftime('%A %B %d')}. Given the above, write {p_name}'s status for {role.scratch.curr_time.strftime('%A %B %d')} that reflects {p_name}'s thoughts at the end of {(role.scratch.curr_time - datetime.timedelta(days=1)).strftime('%A %B %d')}. Write this in third-person talking about {p_name}." + currently_prompt += "If there is any scheduling information, be as specific as possible (include date, time, and location if stated in the statement).\n\n" + currently_prompt += "Follow this format below:\nStatus: " + new_currently = LLM().ask(currently_prompt) + + role.scratch.currently = new_currently + + daily_req_prompt = role.scratch.get_str_iss() + "\n" + daily_req_prompt += f"Today is {role.scratch.curr_time.strftime('%A %B %d')}. Here is {role.scratch.name}'s plan today in broad-strokes (with the time of the day. 
e.g., have a lunch at 12:00 pm, watch TV from 7 to 8 pm).\n\n" + daily_req_prompt += "Follow this format (the list should have 4~6 items but no more):\n" + daily_req_prompt += "1. wake up and complete the morning routine at

# The docs/*.md files were originally written in Markdown, so JSON is converted back to Markdown here.
def json_to_markdown(data, depth=2):
    """
    Convert a JSON object to Markdown with headings for keys and lists for arrays, supporting nested objects.

    Args:
        data: JSON object (dictionary) or value.
        depth (int): Current depth level for Markdown headings.

    Returns:
        str: Markdown representation of the JSON data. A non-dictionary value is
        returned as ``str(data)``. An empty array produces only its heading.
    """
    if not isinstance(data, dict):
        # Handle non-dictionary JSON data
        return str(data)

    parts = []
    for key, value in data.items():
        heading = "#" * depth + f" {key}\n\n"
        if isinstance(value, list):
            # Handle JSON arrays: one bullet per item. An empty array gets no
            # bullet at all (previously it produced a stray empty "- " line).
            parts.append(heading)
            if value:
                parts.append("- " + "\n- ".join(str(item) for item in value) + "\n\n")
        elif isinstance(value, dict):
            # Handle nested JSON objects one heading level deeper
            parts.append(heading)
            parts.append(json_to_markdown(value, depth + 1))
        else:
            # Handle other values
            parts.append(f"{heading}{value}\n\n")

    return "".join(parts)
async def mermaid_to_file(engine, mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int:
    """Render Mermaid code to image files with the selected engine.

    The Mermaid source is always written to ``<output_file_without_suffix>.mmd`` first.

    :param engine: "nodejs" (runs the configured ``mmdc`` command and produces pdf/svg/png),
        "playwright", "pyppeteer" or "ink" (delegated to the matching ``mmdc_*`` module),
        or "none" (no rendering). Any other value is logged as unsupported.
    :param mermaid_code: mermaid code
    :param output_file_without_suffix: output filename without suffix
    :param width: output width passed to the renderer
    :param height: output height passed to the renderer
    :return: 0 if succeed, -1 if failed
    """
    # Write the Mermaid code to a temporary file
    dir_name = os.path.dirname(output_file_without_suffix)
    if dir_name:
        # exist_ok avoids failing when a concurrent call creates the directory first.
        os.makedirs(dir_name, exist_ok=True)
    tmp = Path(f"{output_file_without_suffix}.mmd")
    await awrite(filename=tmp, data=mermaid_code)

    if engine == "nodejs":
        if check_cmd_exists(config.mermaid.path) != 0:
            logger.warning(
                "RUN `npm install -g @mermaid-js/mermaid-cli` to install mmdc,"
                "or consider changing engine to `playwright`, `pyppeteer`, or `ink`."
            )
            return -1

        status = 0
        for suffix in ["pdf", "svg", "png"]:
            output_file = f"{output_file_without_suffix}.{suffix}"
            # Call the `mmdc` command to convert the Mermaid code to the target format
            logger.info(f"Generating {output_file}..")

            commands = [config.mermaid.path]
            if config.mermaid.puppeteer_config:
                commands += ["-p", config.mermaid.puppeteer_config]
            commands += ["-i", str(tmp), "-o", output_file, "-w", str(width), "-H", str(height)]

            # NOTE(review): the command is handed to the shell as a space-joined string, so
            # paths containing spaces or shell metacharacters will not survive; consider
            # asyncio.create_subprocess_exec(*commands) once the configured path is known
            # to be a single executable.
            process = await asyncio.create_subprocess_shell(
                " ".join(commands), stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
            )

            stdout, stderr = await process.communicate()
            if stdout:
                logger.info(stdout.decode())
            if stderr:
                logger.warning(stderr.decode())
            if process.returncode != 0:
                # Report the failure, but still attempt the remaining formats.
                status = -1
        return status

    if engine == "playwright":
        from metagpt_yusin.utils.mmdc_playwright import mermaid_to_file as render

        return await render(mermaid_code, output_file_without_suffix, width, height)
    if engine == "pyppeteer":
        from metagpt_yusin.utils.mmdc_pyppeteer import mermaid_to_file as render

        return await render(mermaid_code, output_file_without_suffix, width, height)
    if engine == "ink":
        from metagpt_yusin.utils.mmdc_ink import mermaid_to_file as render

        return await render(mermaid_code, output_file_without_suffix)

    if engine != "none":
        logger.warning(f"Unsupported mermaid engine: {engine}")
    return 0
async def mermaid_to_file(mermaid_code, output_file_without_suffix):
    """Render Mermaid code through the mermaid.ink web service.

    Writes ``<output_file_without_suffix>.svg`` and ``<output_file_without_suffix>.png``.

    :param mermaid_code: mermaid code
    :param output_file_without_suffix: output filename without suffix
    :return: 0 if succeed, -1 if failed (non-200 response or network error)
    """
    # The encoded diagram becomes a URL path segment, so use the URL-safe alphabet:
    # standard base64 may contain "/" and "+", which would corrupt the path.
    encoded_string = base64.urlsafe_b64encode(mermaid_code.encode()).decode()

    # One session is reused for both downloads.
    async with ClientSession() as session:
        for suffix in ["svg", "png"]:
            output_file = f"{output_file_without_suffix}.{suffix}"
            path_type = "svg" if suffix == "svg" else "img"
            url = f"https://mermaid.ink/{path_type}/{encoded_string}"
            try:
                async with session.get(url) as response:
                    if response.status != 200:
                        logger.error(f"Failed to generate {output_file}")
                        return -1
                    content = await response.content.read()
            except ClientError as e:
                logger.error(f"network error: {e}")
                return -1

            with open(output_file, "wb") as f:
                f.write(content)
            logger.info(f"Generating {output_file}..")
    return 0
+ """ + suffixes = ["png", "svg", "pdf"] + __dirname = os.path.dirname(os.path.abspath(__file__)) + + async with async_playwright() as p: + browser = await p.chromium.launch() + device_scale_factor = 1.0 + context = await browser.new_context( + viewport={"width": width, "height": height}, + device_scale_factor=device_scale_factor, + ) + page = await context.new_page() + + async def console_message(msg): + logger.info(msg.text) + + page.on("console", console_message) + + try: + await page.set_viewport_size({"width": width, "height": height}) + + mermaid_html_path = os.path.abspath(os.path.join(__dirname, "index.html")) + mermaid_html_url = urljoin("file:", mermaid_html_path) + await page.goto(mermaid_html_url) + await page.wait_for_load_state("networkidle") + + await page.wait_for_selector("div#container", state="attached") + # mermaid_config = {} + background_color = "#ffffff" + # my_css = "" + await page.evaluate(f'document.body.style.background = "{background_color}";') + + # metadata = await page.evaluate( + # """async ([definition, mermaidConfig, myCSS, backgroundColor]) => { + # const { mermaid, zenuml } = globalThis; + # await mermaid.registerExternalDiagrams([zenuml]); + # mermaid.initialize({ startOnLoad: false, ...mermaidConfig }); + # const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container')); + # document.getElementById('container').innerHTML = svg; + # const svgElement = document.querySelector('svg'); + # svgElement.style.backgroundColor = backgroundColor; + # + # if (myCSS) { + # const style = document.createElementNS('http://www.w3.org/2000/svg', 'style'); + # style.appendChild(document.createTextNode(myCSS)); + # svgElement.appendChild(style); + # } + # + # }""", + # [mermaid_code, mermaid_config, my_css, background_color], + # ) + + if "svg" in suffixes: + svg_xml = await page.evaluate( + """() => { + const svg = document.querySelector('svg'); + const xmlSerializer = new XMLSerializer(); + return 
async def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height=2048) -> int:
    """
    Capture the Mermaid page shipped next to this module (``index.html``) as PNG, SVG and PDF files.

    Args:
        mermaid_code (str): The Mermaid code to convert.
        output_file_without_suffix (str): The output file name without the file extension.
        width (int, optional): The width of the output image in pixels. Defaults to 2048.
        height (int, optional): The height of the output image in pixels. Defaults to 2048.

    Returns:
        int: Returns 0 if the conversion and saving were successful, -1 otherwise
        (``mermaid.pyppeteer_path`` not configured, or any error while rendering).
    """
    suffixes = ["png", "svg", "pdf"]
    base_dir = os.path.dirname(os.path.abspath(__file__))

    if not config.mermaid.pyppeteer_path:
        logger.error("Please set the var mermaid.pyppeteer_path in the config2.yaml.")
        return -1

    browser = await launch(
        headless=True,
        executablePath=config.mermaid.pyppeteer_path,
        args=["--disable-extensions", "--no-sandbox"],
    )
    device_scale_factor = 1.0

    async def console_message(msg):
        # Forward the page's console output to the project logger.
        logger.info(msg.text)

    try:
        # The page is created inside the try block so the browser is always closed,
        # even when opening the page fails.
        page = await browser.newPage()
        page.on("console", console_message)

        await page.setViewport(viewport={"width": width, "height": height, "deviceScaleFactor": device_scale_factor})

        mermaid_html_path = os.path.abspath(os.path.join(base_dir, "index.html"))
        mermaid_html_url = urljoin("file:", mermaid_html_path)
        await page.goto(mermaid_html_url)

        await page.querySelector("div#container")
        # mermaid_config = {}
        background_color = "#ffffff"
        # my_css = ""
        await page.evaluate(f'document.body.style.background = "{background_color}";')

        # NOTE(review): the in-page render call below is disabled, so `mermaid_code` is
        # not used by this function as written -- confirm whether that is intentional.
        # metadata = await page.evaluate(
        #     """async ([definition, mermaidConfig, myCSS, backgroundColor]) => {
        #     const { mermaid, zenuml } = globalThis;
        #     await mermaid.registerExternalDiagrams([zenuml]);
        #     mermaid.initialize({ startOnLoad: false, ...mermaidConfig });
        #     const { svg } = await mermaid.render('my-svg', definition, document.getElementById('container'));
        #     document.getElementById('container').innerHTML = svg;
        #     const svgElement = document.querySelector('svg');
        #     svgElement.style.backgroundColor = backgroundColor;
        #
        #     if (myCSS) {
        #         const style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
        #         style.appendChild(document.createTextNode(myCSS));
        #         svgElement.appendChild(style);
        #     }
        # }""",
        #     [mermaid_code, mermaid_config, my_css, background_color],
        # )

        if "svg" in suffixes:
            svg_xml = await page.evaluate(
                """() => {
                const svg = document.querySelector('svg');
                const xmlSerializer = new XMLSerializer();
                return xmlSerializer.serializeToString(svg);
            }"""
            )
            logger.info(f"Generating {output_file_without_suffix}.svg..")
            with open(f"{output_file_without_suffix}.svg", "wb") as f:
                f.write(svg_xml.encode("utf-8"))

        if "png" in suffixes:
            # Bounding box of the rendered <svg>, used to crop the screenshot.
            clip = await page.evaluate(
                """() => {
                const svg = document.querySelector('svg');
                const rect = svg.getBoundingClientRect();
                return {
                    x: Math.floor(rect.left),
                    y: Math.floor(rect.top),
                    width: Math.ceil(rect.width),
                    height: Math.ceil(rect.height)
                };
            }"""
            )
            # Resize the viewport so the whole clip rectangle is inside it.
            await page.setViewport(
                {
                    "width": clip["x"] + clip["width"],
                    "height": clip["y"] + clip["height"],
                    "deviceScaleFactor": device_scale_factor,
                }
            )
            screenshot = await page.screenshot(clip=clip, omit_background=True, scale="device")
            logger.info(f"Generating {output_file_without_suffix}.png..")
            with open(f"{output_file_without_suffix}.png", "wb") as f:
                f.write(screenshot)

        if "pdf" in suffixes:
            pdf_data = await page.pdf(scale=device_scale_factor)
            logger.info(f"Generating {output_file_without_suffix}.pdf..")
            with open(f"{output_file_without_suffix}.pdf", "wb") as f:
                f.write(pdf_data)
        return 0
    except Exception as e:
        logger.error(e)
        return -1
    finally:
        await browser.close()
import re
from typing import Tuple


def remove_spaces(text):
    """Collapse every run of whitespace in *text* into one space and trim both ends."""
    return re.sub(r"\s+", " ", text).strip()


class DocstringParser:
    """Common interface for docstring parsers; subclasses override ``parse``."""

    @staticmethod
    def parse(docstring: str) -> Tuple[str, str]:
        """Parse the docstring and return the overall description and the parameter description.

        Args:
            docstring (str): The docstring to be parsed.

        Returns:
            Tuple[str, str]: A tuple of (overall description, parameter description)
        """


class reSTDocstringParser(DocstringParser):
    """A parser for reStructuredText (reST) docstring"""


class GoogleDocstringParser(DocstringParser):
    """A parser for Google-style docstring"""

    @staticmethod
    def parse(docstring: str) -> Tuple[str, str]:
        """Split a Google-style docstring at its first ``Args:`` marker.

        Args:
            docstring (str): The docstring to be parsed; whitespace is normalized first.

        Returns:
            Tuple[str, str]: ``(overall description, parameter description)``. The parameter
            description starts with ``"Args:"`` and is ``""`` when the marker is absent.
            Both are ``""`` for an empty docstring.
        """
        if not docstring:
            return "", ""

        docstring = remove_spaces(docstring)

        # Split on the first "Args:" only, so later occurrences of the marker stay in
        # the parameter description instead of breaking a two-way unpacking.
        overall_desc, marker, remainder = docstring.partition("Args:")
        param_desc = marker + remainder

        return overall_desc, param_desc
not None else "" + return self._title + + def get_links(self) -> Generator[str, None, None]: + for i in self.soup.find_all("a", href=True): + url = i["href"] + result = urlparse(url) + if not result.scheme and result.path: + yield urljoin(self.url, url) + elif url.startswith(("http://", "https://")): + yield urljoin(self.url, url) + + +def get_html_content(page: str, base: str): + soup = _get_soup(page) + + return soup.get_text(strip=True) + + +def _get_soup(page: str): + soup = BeautifulSoup(page, "html.parser") + # https://stackoverflow.com/questions/1936466/how-to-scrape-only-visible-webpage-text-with-beautifulsoup + for s in soup(["style", "script", "[document]", "head", "title"]): + s.extract() + + return soup diff --git a/notebook_dir/metagpt_yusin/utils/project_repo.py b/notebook_dir/metagpt_yusin/utils/project_repo.py new file mode 100644 index 0000000000000000000000000000000000000000..7b5415970b9ca822bce433ba3e7cca8f7c7e8f82 --- /dev/null +++ b/notebook_dir/metagpt_yusin/utils/project_repo.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/1/8 +@Author : mashenquan +@File : project_repo.py +@Desc : Wrapper for GitRepository and FileRepository of project. 
+ Implementation of Chapter 4.6 of https://deepwisdom.feishu.cn/wiki/CUK4wImd7id9WlkQBNscIe9cnqh +""" +from __future__ import annotations + +from pathlib import Path + +from metagpt_yusin.const import ( + CLASS_VIEW_FILE_REPO, + CODE_PLAN_AND_CHANGE_FILE_REPO, + CODE_PLAN_AND_CHANGE_PDF_FILE_REPO, + CODE_SUMMARIES_FILE_REPO, + CODE_SUMMARIES_PDF_FILE_REPO, + COMPETITIVE_ANALYSIS_FILE_REPO, + DATA_API_DESIGN_FILE_REPO, + DOCS_FILE_REPO, + GRAPH_REPO_FILE_REPO, + PRD_PDF_FILE_REPO, + PRDS_FILE_REPO, + REQUIREMENT_FILENAME, + RESOURCES_FILE_REPO, + SD_OUTPUT_FILE_REPO, + SEQ_FLOW_FILE_REPO, + SYSTEM_DESIGN_FILE_REPO, + SYSTEM_DESIGN_PDF_FILE_REPO, + TASK_FILE_REPO, + TASK_PDF_FILE_REPO, + TEST_CODES_FILE_REPO, + TEST_OUTPUTS_FILE_REPO, + VISUAL_GRAPH_REPO_FILE_REPO, +) +from metagpt_yusin.utils.file_repository import FileRepository +from metagpt_yusin.utils.git_repository import GitRepository + + +class DocFileRepositories(FileRepository): + prd: FileRepository + system_design: FileRepository + task: FileRepository + code_summary: FileRepository + graph_repo: FileRepository + class_view: FileRepository + code_plan_and_change: FileRepository + + def __init__(self, git_repo): + super().__init__(git_repo=git_repo, relative_path=DOCS_FILE_REPO) + + self.prd = git_repo.new_file_repository(relative_path=PRDS_FILE_REPO) + self.system_design = git_repo.new_file_repository(relative_path=SYSTEM_DESIGN_FILE_REPO) + self.task = git_repo.new_file_repository(relative_path=TASK_FILE_REPO) + self.code_summary = git_repo.new_file_repository(relative_path=CODE_SUMMARIES_FILE_REPO) + self.graph_repo = git_repo.new_file_repository(relative_path=GRAPH_REPO_FILE_REPO) + self.class_view = git_repo.new_file_repository(relative_path=CLASS_VIEW_FILE_REPO) + self.code_plan_and_change = git_repo.new_file_repository(relative_path=CODE_PLAN_AND_CHANGE_FILE_REPO) + + +class ResourceFileRepositories(FileRepository): + competitive_analysis: FileRepository + data_api_design: FileRepository + seq_flow: 
class ProjectRepo(FileRepository):
    """Top-level view of a project backed by a single `GitRepository`.

    Exposes the docs, resources, tests and test-output repositories, and the
    source-code repository once a source path has been set via `with_src_path`.
    """

    def __init__(self, root: str | Path | GitRepository):
        """Create the project view.

        Args:
            root: A filesystem path (``str`` or ``Path``) or an existing ``GitRepository``.

        Raises:
            ValueError: If ``root`` is none of the supported types.
        """
        if isinstance(root, (str, Path)):
            git_repo_ = GitRepository(local_path=Path(root))
        elif isinstance(root, GitRepository):
            git_repo_ = root
        else:
            raise ValueError("Invalid root")
        super().__init__(git_repo=git_repo_, relative_path=Path("."))
        self._git_repo = git_repo_
        self.docs = DocFileRepositories(self._git_repo)
        self.resources = ResourceFileRepositories(self._git_repo)
        self.tests = self._git_repo.new_file_repository(relative_path=TEST_CODES_FILE_REPO)
        self.test_outputs = self._git_repo.new_file_repository(relative_path=TEST_OUTPUTS_FILE_REPO)
        self._srcs_path = None
        # Probe `<workdir>/<workdir name>`; when that folder exists this also records it as the source path.
        self.code_files_exists()

    def __str__(self):
        repo_str = f"ProjectRepo({self._git_repo.workdir})"
        docs_str = f"Docs({self.docs.all_files})"
        # `srcs` raises until a source path is configured; printing a repo must never fail.
        srcs_files = self.srcs.all_files if self._srcs_path else []
        srcs_str = f"Srcs({srcs_files})"
        return f"{repo_str}\n{docs_str}\n{srcs_str}"

    @property
    async def requirement(self):
        """Awaitable property: the requirement document fetched from the docs repository."""
        return await self.docs.get(filename=REQUIREMENT_FILENAME)

    @property
    def git_repo(self) -> GitRepository:
        """The underlying git repository."""
        return self._git_repo

    @property
    def workdir(self) -> Path:
        """The git working directory as a ``Path``."""
        return Path(self.git_repo.workdir)

    @property
    def srcs(self) -> FileRepository:
        """File repository for the source path.

        Raises:
            ValueError: If no source path has been set yet.
        """
        if not self._srcs_path:
            raise ValueError("Call with_srcs first.")
        return self._git_repo.new_file_repository(self._srcs_path)

    def code_files_exists(self) -> bool:
        """Return True when `<workdir>/<workdir name>` exists and contains files.

        Side effect: when that folder exists it becomes the source path (via `with_src_path`).
        """
        git_workdir = self.git_repo.workdir
        src_workdir = git_workdir / git_workdir.name
        if not src_workdir.exists():
            return False
        return bool(self.with_src_path(path=src_workdir).srcs.all_files)

    def with_src_path(self, path: str | Path) -> ProjectRepo:
        """Record `path` as the source path and return ``self`` for chaining.

        The path is stored relative to the working directory when it lies inside it,
        otherwise it is stored as given.
        """
        try:
            self._srcs_path = Path(path).relative_to(self.workdir)
        except ValueError:
            # `path` is not under the working directory; keep it unchanged.
            self._srcs_path = Path(path)
        return self

    @property
    def src_relative_path(self) -> Path | None:
        """The recorded source path, or ``None`` when none has been set."""
        return self._srcs_path
def has_decorator(node: DocstringNode, name: str) -> bool:
    """Tell whether `node` carries a decorator whose name equals `name`.

    A decorator matches when its expression has a `value` equal to `name`, or when it
    has a `func` whose `value` equals `name` (e.g. a decorator written as a call).
    Nodes without a `decorators` attribute never match.

    Args:
        node: The node to inspect.
        name: The decorator name to look for.

    Returns:
        True if a matching decorator is found, False otherwise.
    """
    if not hasattr(node, "decorators"):
        return False

    for entry in node.decorators:
        target = entry.decorator
        # Direct form: the decorator expression itself exposes the name.
        if hasattr(target, "value") and target.value == name:
            return True
        # Call form: the name lives on the called object.
        if hasattr(target, "func"):
            callee = target.func
            if hasattr(callee, "value") and callee.value == name:
                return True
    return False
class DocstringTransformer(cst.CSTTransformer):
    """A transformer class for replacing docstrings in a CST.

    Attributes:
        stack: A list to keep track of the current path in the CST.
        docstrings: A dictionary mapping paths in the CST to their corresponding docstrings.
    """

    def __init__(
        self,
        docstrings: dict[tuple[str, ...], cst.SimpleStatementLine],
    ):
        self.stack: list[str] = []
        self.docstrings = docstrings

    def visit_Module(self, node: cst.Module) -> bool | None:
        # The module is pushed as an empty name so every key starts with "".
        self.stack.append("")

    def leave_Module(self, original_node: Module, updated_node: Module) -> Module:
        return self._leave(original_node, updated_node)

    def visit_ClassDef(self, node: cst.ClassDef) -> bool | None:
        self.stack.append(node.name.value)

    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.CSTNode:
        return self._leave(original_node, updated_node)

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None:
        self.stack.append(node.name.value)

    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.CSTNode:
        return self._leave(original_node, updated_node)

    def _leave(self, original_node: DocstringNode, updated_node: DocstringNode) -> DocstringNode:
        """Replace or insert the docstring stored for the node that is being left.

        Args:
            original_node: The node as it was before any child was transformed.
            updated_node: The node after its children were transformed.

        Returns:
            `updated_node` unchanged when it is decorated with `overload` or when no
            docstring is stored for its path; otherwise a copy whose first body
            statement is the stored docstring.
        """
        # Capture the path before popping: the key includes the node being left.
        key = tuple(self.stack)
        self.stack.pop()

        if has_decorator(updated_node, "overload"):
            return updated_node

        statement = self.docstrings.get(key)
        if not statement:
            return updated_node

        # Whether a docstring already exists is decided on the original node.
        original_statement = get_docstring_statement(original_node)

        if isinstance(updated_node, cst.Module):
            # A module keeps its statements directly in `body`.
            body = updated_node.body
            if original_statement:
                # Swap the existing docstring (first statement) for the new one.
                return updated_node.with_changes(body=(statement, *body[1:]))
            else:
                # No docstring yet: put the new one first, followed by an EmptyLine node.
                updated_node = updated_node.with_changes(body=(statement, cst.EmptyLine(), *body))
                return updated_node

        # Classes and functions keep their statements one level deeper, in `body.body`.
        body = updated_node.body.body[1:] if original_statement else updated_node.body.body
        return updated_node.with_changes(body=updated_node.body.with_changes(body=(statement, *body)))
def load_history(save_dir: str = ""):
    """
    Load plan and code execution history from the specified save directory.

    Reads `plan.json` and `history_nb/code.ipynb` below `save_dir`.

    Args:
        save_dir (str): The directory from which to load the history.

    Returns:
        Tuple: A tuple containing the loaded plan and notebook.
    """

    plan_path = Path(save_dir) / "plan.json"
    nb_path = Path(save_dir) / "history_nb" / "code.ipynb"
    plan = read_json_file(plan_path)
    # Use a context manager so the notebook file handle is closed even if parsing fails.
    with open(nb_path, "r", encoding="utf-8") as nb_file:
        nb = nbformat.read(nb_file, as_version=nbformat.NO_CONVERT)
    return plan, nb
def check_methods(C, *methods):
    """Check whether class `C` provides every named method (adapted from `_collections_abc`).

    For each name, the first class in `C.__mro__` that defines it decides the outcome:
    the name counts as missing if no class defines it or if that definition is `None`.

    Useful for implicit interfaces, where `isinstance` should succeed without inheritance.

    Returns:
        True when every method is provided, otherwise `NotImplemented`.
    """
    lineage = C.__mro__
    for wanted in methods:
        # First class along the MRO whose own namespace contains the name.
        owner = next((klass for klass in lineage if wanted in klass.__dict__), None)
        if owner is None or owner.__dict__[wanted] is None:
            return NotImplemented
    return True
def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") -> str:
    """
    Make sure both halves of a bracketed key pair are present around the content.

    req_key format
        1. `[req_key]`, and its pair `[/req_key]`
        2. `[/req_key]`, and its pair `[req_key]`

    A missing opening key is prepended. A missing closing key is appended, either to the
    whole output or to the output cut after its last `}` / `]`. Keys that are not wrapped
    in square brackets leave the output untouched.
    """
    slash = "/"  # special char
    if not (req_key.startswith("[") and req_key.endswith("]")):
        return output

    if slash in req_key:
        # `[/req_key]` -> `[req_key]`
        opening, closing = req_key.replace(slash, ""), req_key
    else:
        # `[req_key]` -> `[/req_key]`
        opening, closing = req_key, f"{req_key[0]}{slash}{req_key[1:]}"

    if opening not in output:
        output = opening + "\n" + output
    if closing in output:
        return output

    trimmed = output.strip()
    ends_on_opening = trimmed.endswith(opening)
    if trimmed.endswith("}") or (trimmed.endswith("]") and not ends_on_opening):
        # avoid [req_key]xx[req_key] case to append [/req_key]
        return output + "\n" + closing

    # Keep everything from the last opening key up to the last `}` or `]` after it.
    # The opening key is guaranteed to be present at this point.
    tail = output[output.rfind(opening) :]
    cut = max(tail.rfind("}"), tail.rindex("]"))
    candidate = tail[: cut + 1]
    if candidate and not ends_on_opening:
        return candidate + "\n" + closing

    return output
def _repair_llm_raw_output(output: str, req_key: str, repair_type: RepairType = None) -> str:
    """Run one repair strategy, or every strategy except JSON, over `output` for one key.

    Args:
        output: The raw text to repair.
        req_key: The required key the repairs are aimed at.
        repair_type: The single strategy to run; when falsy, all strategies but
            `RepairType.JSON` run in enum order.

    Returns:
        The repaired text.
    """
    if repair_type:
        selected = [repair_type]
    else:
        # JSON repair is excluded from the default pass and only runs when requested.
        selected = [kind for kind in RepairType if kind not in [RepairType.JSON]]

    handlers = (
        (RepairType.CS, lambda text: repair_case_sensitivity(text, req_key)),
        (RepairType.RKPM, lambda text: repair_required_key_pair_missing(text, req_key)),
        (RepairType.SCM, lambda text: repair_special_character_missing(text, req_key)),
        (RepairType.JSON, lambda text: repair_json_format(text)),
    )
    for kind in selected:
        for candidate, handler in handlers:
            if kind == candidate:
                output = handler(output)
                break
    return output
xxx.JSONDecodeError: Expecting property name enclosed in double quotes: line 14 column 1 (char 266) + """ + pattern = r"line ([0-9]+) column ([0-9]+)" + + matches = re.findall(pattern, error, re.DOTALL) + if len(matches) > 0: + line_no = int(matches[0][0]) - 1 + col_no = int(matches[0][1]) - 1 + + # due to CustomDecoder can handle `"": ''` or `'': ""`, so convert `"""` -> `"`, `'''` -> `'` + output = output.replace('"""', '"').replace("'''", '"') + arr = output.split("\n") + rline = arr[line_no] # raw line + line = arr[line_no].strip() + # different general problems + if line.endswith("],"): + # problem, redundant char `]` + new_line = line.replace("]", "") + elif line.endswith("},") and not output.endswith("},"): + # problem, redundant char `}` + new_line = line.replace("}", "") + elif line.endswith("},") and output.endswith("},"): + new_line = line[:-1] + elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line: + # problem, `"""` or `'''` without `,` + new_line = f",{line}" + elif col_no - 1 >= 0 and rline[col_no - 1] in ['"', "'"]: + # backslash problem like \" in the output + char = rline[col_no - 1] + nearest_char_idx = rline[col_no:].find(char) + new_line = ( + rline[: col_no - 1] + + "\\" + + rline[col_no - 1 : col_no + nearest_char_idx] + + "\\" + + rline[col_no + nearest_char_idx :] + ) + elif '",' not in line and "," not in line and '"' not in line: + new_line = f'{line}",' + elif not line.endswith(","): + # problem, miss char `,` at the end. 
def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["RetryCallState"], None]:
    """Build a tenacity `after` hook that repairs the failed JSON text for the next attempt.

    Args:
        logger: Logger used to report each failed attempt.

    Returns:
        A callback taking the `RetryCallState` of the attempt that just finished.
    """

    def run_and_passon(retry_state: RetryCallState) -> None:
        """
        Log a failed attempt, repair its `output` argument and store it in `retry_state.kwargs`.

        RetryCallState example
            {
                "start_time":143.098322024,
                "retry_object":"<Retrying object ...>",
                "fn":"<function ...>",
                "args":"(\"tag:[/CONTENT]\",)",  # function input args
                "kwargs":{},                     # function input kwargs
                "attempt_number":1,              # retry number
                "outcome":"<Future ...>",  # type(outcome.result()) = "str", type(outcome.exception()) = "class"
                "outcome_timestamp":143.098416904,
                "idle_for":0,
                "next_action":"None"
            }
        """
        if not retry_state.outcome.failed:
            return

        exp_str = str(retry_state.outcome.exception())

        fix_str = "try to fix it, " if config.repair_llm_output else ""
        logger.warning(
            f"parse json from content inside [CONTENT][/CONTENT] failed at retry "
            f"{retry_state.attempt_number}, {fix_str}exp: {exp_str}"
        )

        if retry_state.args:
            # # can't be used as args=retry_state.args
            func_param_output = retry_state.args[0]
        elif retry_state.kwargs:
            func_param_output = retry_state.kwargs.get("output", "")
        else:
            # The call carried no arguments at all, so there is no text to repair.
            # Returning here lets the original exception surface instead of an UnboundLocalError.
            return

        repaired_output = repair_invalid_json(func_param_output, exp_str)
        retry_state.kwargs["output"] = repaired_output

    return run_and_passon
def extract_content_from_output(content: str, right_key: str = "[/CONTENT]"):
    """Pull the text between `[CONTENT]` and `[/CONTENT]` out of `content`.

    Text that does not start with `{` after the first extraction gets a second extraction
    pass (with `right_key` appended when it is missing). Text that does start with `{` is
    cut at the first `right_key` it still contains.

    Args:
        content: The raw text.
        right_key: The closing marker.

    Returns:
        The extracted, stripped text.
    """
    # TODO construct the extract pattern with the `right_key`
    pattern = r"\[CONTENT\]([\s\S]*)\[/CONTENT\]"

    def first_capture(text: str) -> str:
        # First non-empty capture, or the text itself when nothing was captured.
        captured = next((hit for hit in re.findall(pattern, text, re.DOTALL) if hit), text)
        return captured.strip()

    source = content
    extracted = first_capture(source)

    if extracted.startswith("{"):
        cut = extracted.find(right_key)
        if cut != -1:
            extracted = extracted[:cut]
        return extracted.strip()

    # TODO find a more general pattern
    # for `[CONTENT]xxx[CONTENT]xxxx[/CONTENT] situation
    logger.warning(f"extract_content try another pattern: {pattern}")
    if right_key not in extracted:
        source = extracted + "\n" + right_key
    return first_capture(source)
0 else "-1" + return state diff --git a/notebook_dir/metagpt_yusin/utils/repo_to_markdown.py b/notebook_dir/metagpt_yusin/utils/repo_to_markdown.py new file mode 100644 index 0000000000000000000000000000000000000000..2c6b2441a490d4e487c88ebc1bcdc574764e911e --- /dev/null +++ b/notebook_dir/metagpt_yusin/utils/repo_to_markdown.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +This file provides functionality to convert a local repository into a markdown representation. +""" +from __future__ import annotations + +import mimetypes +from pathlib import Path + +from gitignore_parser import parse_gitignore + +from metagpt_yusin.logs import logger +from metagpt_yusin.utils.common import aread, awrite, get_markdown_codeblock_type, list_files +from metagpt_yusin.utils.tree import tree + + +async def repo_to_markdown(repo_path: str | Path, output: str | Path = None, gitignore: str | Path = None) -> str: + """ + Convert a local repository into a markdown representation. + + This function takes a path to a local repository and generates a markdown representation of the repository structure, + including directory trees and file listings. + + Args: + repo_path (str | Path): The path to the local repository. + output (str | Path, optional): The path to save the generated markdown file. Defaults to None. + gitignore (str | Path, optional): The path to the .gitignore file. Defaults to None. + + Returns: + str: The markdown representation of the repository. 
async def _write_file(filename: Path, repo_path: Path) -> str:
    """Render one file as a markdown section.

    The section starts with a heading holding the path relative to `repo_path`. Files whose
    guessed MIME type is not `text/*` (or cannot be guessed) get the heading only; text
    files get their content in a fenced code block.

    Args:
        filename: Path of the file to render.
        repo_path: Repository root used to compute the relative path.

    Returns:
        The markdown section for the file.
    """
    relative_path = filename.relative_to(repo_path)
    markdown = f"## {relative_path}\n"

    mime_type, _ = mimetypes.guess_type(filename.name)
    # `guess_type` yields None for unknown types (e.g. files without an extension);
    # treat those like non-text files instead of failing on the `in` test.
    if not mime_type or "text/" not in mime_type:
        logger.info(f"Ignore content: {filename}")
        markdown += "\n---\n\n"
        return markdown
    content = await aread(filename, encoding="utf-8")
    # Escape fences and rules so the file content cannot break the surrounding markdown.
    content = content.replace("```", "\\`\\`\\`").replace("---", "\\-\\-\\-")
    code_block_type = get_markdown_codeblock_type(filename.name)
    markdown += f"```{code_block_type}\n{content}\n```\n---\n\n"
    return markdown
class S3:
    """A class for interacting with Amazon S3 storage."""

    def __init__(self, config: S3Config):
        """Create the session and keep the client arguments built from `config`.

        Args:
            config: Supplies `access_key`, `secret_key`, `endpoint` and (for `cache`) `bucket`.
        """
        self.session = aioboto3.Session()
        self.config = config
        # Keyword arguments passed to `session.client(...)` by every method below.
        self.auth_config = {
            "service_name": "s3",
            "aws_access_key_id": config.access_key,
            "aws_secret_access_key": config.secret_key,
            "endpoint_url": config.endpoint,
        }

    async def upload_file(
        self,
        bucket: str,
        local_path: str,
        object_name: str,
    ) -> None:
        """Upload a file from the local path to the specified path of the storage bucket specified in s3.

        Args:
            bucket: The name of the S3 storage bucket.
            local_path: The local file path, including the file name.
            object_name: The complete path of the uploaded file to be stored in S3, including the file name.

        Raises:
            Exception: If an error occurs during the upload process, an exception is raised.
        """
        try:
            async with self.session.client(**self.auth_config) as client:
                async with aiofiles.open(local_path, mode="rb") as reader:
                    # The whole file is read into memory before it is sent.
                    body = await reader.read()
                    await client.put_object(Body=body, Bucket=bucket, Key=object_name)
                    logger.info(f"Successfully uploaded the file to path {object_name} in bucket {bucket} of s3.")
        except Exception as e:
            logger.error(f"Failed to upload the file to path {object_name} in bucket {bucket} of s3: {e}")
            raise e

    async def get_object_url(
        self,
        bucket: str,
        object_name: str,
    ) -> str:
        """Get the URL for a downloadable or preview file stored in the specified S3 bucket.

        Args:
            bucket: The name of the S3 storage bucket.
            object_name: The complete path of the file stored in S3, including the file name.

        Returns:
            The URL for the downloadable or preview file.

        Raises:
            Exception: If an error occurs while retrieving the URL, an exception is raised.
        """
        try:
            async with self.session.client(**self.auth_config) as client:
                file = await client.get_object(Bucket=bucket, Key=object_name)
                # NOTE(review): relies on the response body exposing a `.url` attribute — confirm against the client library.
                return str(file["Body"].url)
        except Exception as e:
            logger.error(f"Failed to get the url for a downloadable or preview file: {e}")
            raise e

    async def get_object(
        self,
        bucket: str,
        object_name: str,
    ) -> bytes:
        """Get the binary data of a file stored in the specified S3 bucket.

        Args:
            bucket: The name of the S3 storage bucket.
            object_name: The complete path of the file stored in S3, including the file name.

        Returns:
            The binary data of the requested file.

        Raises:
            Exception: If an error occurs while retrieving the file data, an exception is raised.
        """
        try:
            async with self.session.client(**self.auth_config) as client:
                s3_object = await client.get_object(Bucket=bucket, Key=object_name)
                return await s3_object["Body"].read()
        except Exception as e:
            logger.error(f"Failed to get the binary data of the file: {e}")
            raise e

    async def download_file(
        self, bucket: str, object_name: str, local_path: str, chunk_size: Optional[int] = 128 * 1024
    ) -> None:
        """Download an S3 object to a local file.

        Args:
            bucket: The name of the S3 storage bucket.
            object_name: The complete path of the file stored in S3, including the file name.
            local_path: The local file path where the S3 object will be downloaded.
            chunk_size: The size of data chunks to read and write at a time. Default is 128 KB.

        Raises:
            Exception: If an error occurs during the download process, an exception is raised.
        """
        try:
            async with self.session.client(**self.auth_config) as client:
                s3_object = await client.get_object(Bucket=bucket, Key=object_name)
                stream = s3_object["Body"]
                async with aiofiles.open(local_path, mode="wb") as writer:
                    # Copy chunk by chunk until the stream returns no more data.
                    while True:
                        file_data = await stream.read(chunk_size)
                        if not file_data:
                            break
                        await writer.write(file_data)
        except Exception as e:
            logger.error(f"Failed to download the file from S3: {e}")
            raise e

    async def cache(self, data: str, file_ext: str, format: str = "") -> Optional[str]:
        """Save data to remote S3 and return url

        The data is first written to a temporary file next to this module, uploaded under a
        random object name, and the temporary file is removed afterwards.

        Args:
            data: The payload; base64 text when `format` equals `BASE64_FORMAT`, plain text otherwise.
            file_ext: Suffix appended to the generated object name.
            format: Encoding of `data`; only `BASE64_FORMAT` is treated specially.

        Returns:
            The object URL, or None when any step fails (the error is logged, not raised).
        """
        object_name = uuid.uuid4().hex + file_ext
        # The temporary file lives in the directory of this source file.
        path = Path(__file__).parent
        pathname = path / object_name
        try:
            async with aiofiles.open(str(pathname), mode="wb") as file:
                data = base64.b64decode(data) if format == BASE64_FORMAT else data.encode(encoding="utf-8")
                await file.write(data)

            bucket = self.config.bucket
            # Object key is "<bucket>/<object_name>", or "system/<object_name>" when no bucket is configured.
            object_pathname = self.config.bucket or "system"
            object_pathname += f"/{object_name}"
            object_pathname = os.path.normpath(object_pathname)
            await self.upload_file(bucket=bucket, local_path=str(pathname), object_name=object_pathname)
            pathname.unlink(missing_ok=True)

            return await self.get_object_url(bucket=bucket, object_name=object_pathname)
        except Exception as e:
            logger.exception(f"{e}, stack:{traceback.format_exc()}")
            # Best-effort cleanup of the temporary file on failure.
            pathname.unlink(missing_ok=True)
            return None
str = "py") -> None: + """ + Save code files to a specified path. + + Args: + - name (str): The name of the folder to save the files. + - code_context (str): The code content. + - file_format (str, optional): The file format. Supports 'py' (Python file), 'json' (JSON file), and 'ipynb' (Jupyter Notebook file). Default is 'py'. + + + Returns: + - None + """ + # Create the folder path if it doesn't exist + os.makedirs(name=DATA_PATH / "output" / f"{name}", exist_ok=True) + + # Choose to save as a Python file or a JSON file based on the file format + file_path = DATA_PATH / "output" / f"{name}/code.{file_format}" + if file_format == "py": + file_path.write_text(code_context + "\n\n", encoding="utf-8") + elif file_format == "json": + # Parse the code content as JSON and save + data = {"code": code_context} + write_json_file(file_path, data, encoding="utf-8", indent=2) + elif file_format == "ipynb": + nbformat.write(code_context, file_path) + else: + raise ValueError("Unsupported file format. Please choose 'py', 'json', or 'ipynb'.") diff --git a/notebook_dir/metagpt_yusin/utils/serialize.py b/notebook_dir/metagpt_yusin/utils/serialize.py new file mode 100644 index 0000000000000000000000000000000000000000..4034198a3da237dc047079f97df84e29e8bfcae2 --- /dev/null +++ b/notebook_dir/metagpt_yusin/utils/serialize.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the implement of serialization and deserialization + +import copy +import pickle + +from metagpt_yusin.utils.common import import_class + + +def actionoutout_schema_to_mapping(schema: dict) -> dict: + """ + directly traverse the `properties` in the first level. 
def actionoutput_str_to_mapping(mapping: dict) -> dict:
    """Rebuild a field mapping from the string form produced by `actionoutput_mapping_to_str`.

    Args:
        mapping: Maps field name to the ``str()`` of its ``(type, default)`` tuple,
            e.g. ``"(list[str], Ellipsis)"``.

    Returns:
        A new dict mapping each field name to the reconstructed ``(type, ...)`` tuple.
    """
    # ``str((str, ...))`` is "(<class 'str'>, Ellipsis)", which is not valid Python
    # source and therefore cannot go through eval. The truncated "(, Ellipsis)"
    # spelling is still accepted for data written against the previous sentinel.
    str_field_reprs = ("(<class 'str'>, Ellipsis)", "(, Ellipsis)")

    new_mapping = {}
    for key, value in mapping.items():
        if value in str_field_reprs:
            new_mapping[key] = (str, ...)
        else:
            # NOTE(security): eval executes arbitrary code; only feed this function
            # mappings that were serialized by trusted code.
            new_mapping[key] = eval(value)  # `"'(list[str], Ellipsis)"` to `(list[str], ...)`
    return new_mapping
class StreamPipe:
    """Pass messages through a multiprocessing pipe and format them as stream chunks."""

    # NOTE(review): the pipe is created once at class-definition time, so every
    # instance shares the same two connection ends -- confirm this is intended.
    parent_conn, child_conn = Pipe()
    finish: bool = False

    # Template for one "chat.completion.chunk" payload. It is shared by all
    # instances and must be treated as read-only.
    format_data = {
        "id": "chatcmpl-96bVnBOOyPFZZxEoTIGbdpFcVEnur",
        "object": "chat.completion.chunk",
        "created": 1711361191,
        "model": "gpt-3.5-turbo-0125",
        "system_fingerprint": "fp_3bc1b5746c",
        "choices": [
            {
                "index": 0,
                "delta": {
                    "role": "assistant",
                    "content": "content",
                },
                "logprobs": None,
                "finish_reason": None,
            }
        ],
    }

    def set_message(self, msg):
        """Send `msg` into the pipe through the parent connection."""
        self.parent_conn.send(msg)

    def get_message(self, timeout: int = 3):
        """Return the next message from the child connection.

        Args:
            timeout: Seconds to wait for a message to become available.

        Returns:
            The received message, or None if nothing arrived within `timeout`.
        """
        if self.child_conn.poll(timeout):
            return self.child_conn.recv()
        return None

    def msg2stream(self, msg):
        """Wrap `msg` in a chunk payload and return it as a UTF-8 encoded ``data:`` line.

        Args:
            msg: The text placed in ``choices[0].delta.content``.

        Returns:
            bytes: ``b"data: <json>\\n"`` with the creation time set to now.
        """
        # Local import keeps the block self-contained.
        import copy

        # Work on a private copy: mutating the shared class-level template would
        # leak one call's content into every other instance and concurrent call.
        payload = copy.deepcopy(self.format_data)
        payload["created"] = int(time.time())
        payload["choices"][0]["delta"]["content"] = msg
        return f"data: {json.dumps(payload, ensure_ascii=False)}\n".encode("utf-8")
def split_paragraph(paragraph: str, sep: str = ".,", count: int = 2) -> list[str]:
    """Split a paragraph into `count` parts, preferring sentence boundaries.

    Args:
        paragraph: The paragraph to split.
        sep: Candidate separator characters, tried one at a time in order.
        count: The number of parts to split the paragraph into.

    Returns:
        A list of `count` strings. The first separator that cuts the paragraph
        into more than one piece decides the boundaries; if none does, the raw
        text is divided by character count instead.
    """
    for separator in sep:
        pieces = list(_split_text_with_ends(paragraph, separator))
        if len(pieces) > 1:
            # Distribute the pieces over `count` groups and glue each group back together.
            return ["".join(group) for group in _split_by_count(pieces, count)]
    # No separator produced a split: fall back to slicing the string itself.
    return list(_split_by_count(paragraph, count))
def _split_text_with_ends(text: str, sep: str = "."):
    """Yield consecutive pieces of `text`, each ending right after an occurrence of `sep`.

    The separator stays attached to the piece it terminates. Text after the last
    separator is yielded as a final piece when it is non-empty. Characters are
    compared one by one, so `sep` only matches when it is a single character.
    """
    piece_start = 0
    for position, char in enumerate(text):
        if char == sep:
            yield text[piece_start : position + 1]
            piece_start = position + 1
    if piece_start < len(text):
        yield text[piece_start:]
0.03, "completion": 0.06}, + "gpt-4-32k": {"prompt": 0.06, "completion": 0.12}, + "gpt-4-32k-0314": {"prompt": 0.06, "completion": 0.12}, + "gpt-4-0613": {"prompt": 0.06, "completion": 0.12}, + "gpt-4-turbo-preview": {"prompt": 0.01, "completion": 0.03}, + "gpt-4-0125-preview": {"prompt": 0.01, "completion": 0.03}, + "gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03}, + "gpt-4-vision-preview": {"prompt": 0.01, "completion": 0.03}, # TODO add extra image price calculator + "gpt-4-1106-vision-preview": {"prompt": 0.01, "completion": 0.03}, + "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0}, + "glm-3-turbo": {"prompt": 0.0007, "completion": 0.0007}, # 128k version, prompt + completion tokens=0.005¥/k-tokens + "glm-4": {"prompt": 0.014, "completion": 0.014}, # 128k version, prompt + completion tokens=0.1¥/k-tokens + "gemini-pro": {"prompt": 0.00025, "completion": 0.0005}, + "moonshot-v1-8k": {"prompt": 0.012, "completion": 0.012}, # prompt + completion tokens=0.012¥/k-tokens + "moonshot-v1-32k": {"prompt": 0.024, "completion": 0.024}, + "moonshot-v1-128k": {"prompt": 0.06, "completion": 0.06}, + "open-mistral-7b": {"prompt": 0.00025, "completion": 0.00025}, + "open-mixtral-8x7b": {"prompt": 0.0007, "completion": 0.0007}, + "mistral-small-latest": {"prompt": 0.002, "completion": 0.006}, + "mistral-medium-latest": {"prompt": 0.0027, "completion": 0.0081}, + "mistral-large-latest": {"prompt": 0.008, "completion": 0.024}, + "claude-instant-1.2": {"prompt": 0.0008, "completion": 0.0024}, + "claude-2.0": {"prompt": 0.008, "completion": 0.024}, + "claude-2.1": {"prompt": 0.008, "completion": 0.024}, + "claude-3-sonnet-20240229": {"prompt": 0.003, "completion": 0.015}, + "claude-3-opus-20240229": {"prompt": 0.015, "completion": 0.075}, + "yi-34b-chat-0205": {"prompt": 0.0003, "completion": 0.0003}, + "yi-34b-chat-200k": {"prompt": 0.0017, "completion": 0.0017}, +} + + +""" +QianFan Token Price 
https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9 +Due to QianFan has multi price strategies, we unify `Tokens post-payment` as a statistical method. +""" +QIANFAN_MODEL_TOKEN_COSTS = { + "ERNIE-Bot-4": {"prompt": 0.017, "completion": 0.017}, + "ERNIE-Bot-8k": {"prompt": 0.0034, "completion": 0.0067}, + "ERNIE-Bot": {"prompt": 0.0017, "completion": 0.0017}, + "ERNIE-Bot-turbo": {"prompt": 0.0011, "completion": 0.0011}, + "EB-turbo-AppBuilder": {"prompt": 0.0011, "completion": 0.0011}, + "ERNIE-Speed": {"prompt": 0.00056, "completion": 0.0011}, + "BLOOMZ-7B": {"prompt": 0.00056, "completion": 0.00056}, + "Llama-2-7B-Chat": {"prompt": 0.00056, "completion": 0.00056}, + "Llama-2-13B-Chat": {"prompt": 0.00084, "completion": 0.00084}, + "Llama-2-70B-Chat": {"prompt": 0.0049, "completion": 0.0049}, + "ChatGLM2-6B-32K": {"prompt": 0.00056, "completion": 0.00056}, + "AquilaChat-7B": {"prompt": 0.00056, "completion": 0.00056}, + "Mixtral-8x7B-Instruct": {"prompt": 0.0049, "completion": 0.0049}, + "SQLCoder-7B": {"prompt": 0.00056, "completion": 0.00056}, + "CodeLlama-7B-Instruct": {"prompt": 0.00056, "completion": 0.00056}, + "XuanYuan-70B-Chat-4bit": {"prompt": 0.0049, "completion": 0.0049}, + "Qianfan-BLOOMZ-7B-compressed": {"prompt": 0.00056, "completion": 0.00056}, + "Qianfan-Chinese-Llama-2-7B": {"prompt": 0.00056, "completion": 0.00056}, + "Qianfan-Chinese-Llama-2-13B": {"prompt": 0.00084, "completion": 0.00084}, + "ChatLaw": {"prompt": 0.0011, "completion": 0.0011}, + "Yi-34B-Chat": {"prompt": 0.0, "completion": 0.0}, +} + +QIANFAN_ENDPOINT_TOKEN_COSTS = { + "completions_pro": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-4"], + "ernie_bot_8k": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-8k"], + "completions": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot"], + "eb-instant": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-turbo"], + "ai_apaas": QIANFAN_MODEL_TOKEN_COSTS["EB-turbo-AppBuilder"], + "ernie_speed": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Speed"], + 
"bloomz_7b1": QIANFAN_MODEL_TOKEN_COSTS["BLOOMZ-7B"], + "llama_2_7b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-7B-Chat"], + "llama_2_13b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-13B-Chat"], + "llama_2_70b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-70B-Chat"], + "chatglm2_6b_32k": QIANFAN_MODEL_TOKEN_COSTS["ChatGLM2-6B-32K"], + "aquilachat_7b": QIANFAN_MODEL_TOKEN_COSTS["AquilaChat-7B"], + "mixtral_8x7b_instruct": QIANFAN_MODEL_TOKEN_COSTS["Mixtral-8x7B-Instruct"], + "sqlcoder_7b": QIANFAN_MODEL_TOKEN_COSTS["SQLCoder-7B"], + "codellama_7b_instruct": QIANFAN_MODEL_TOKEN_COSTS["CodeLlama-7B-Instruct"], + "xuanyuan_70b_chat": QIANFAN_MODEL_TOKEN_COSTS["XuanYuan-70B-Chat-4bit"], + "qianfan_bloomz_7b_compressed": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-BLOOMZ-7B-compressed"], + "qianfan_chinese_llama_2_7b": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-7B"], + "qianfan_chinese_llama_2_13b": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-13B"], + "chatlaw": QIANFAN_MODEL_TOKEN_COSTS["ChatLaw"], + "yi_34b_chat": QIANFAN_MODEL_TOKEN_COSTS["Yi-34B-Chat"], +} + +""" +DashScope Token price https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing +Different model has different detail page. Attention, some model are free for a limited time. 
+""" +DASHSCOPE_TOKEN_COSTS = { + "qwen-turbo": {"prompt": 0.0011, "completion": 0.0011}, + "qwen-plus": {"prompt": 0.0028, "completion": 0.0028}, + "qwen-max": {"prompt": 0.0, "completion": 0.0}, + "qwen-max-1201": {"prompt": 0.0, "completion": 0.0}, + "qwen-max-longcontext": {"prompt": 0.0, "completion": 0.0}, + "llama2-7b-chat-v2": {"prompt": 0.0, "completion": 0.0}, + "llama2-13b-chat-v2": {"prompt": 0.0, "completion": 0.0}, + "qwen-72b-chat": {"prompt": 0.0, "completion": 0.0}, + "qwen-14b-chat": {"prompt": 0.0011, "completion": 0.0011}, + "qwen-7b-chat": {"prompt": 0.00084, "completion": 0.00084}, + "qwen-1.8b-chat": {"prompt": 0.0, "completion": 0.0}, + "baichuan2-13b-chat-v1": {"prompt": 0.0011, "completion": 0.0011}, + "baichuan2-7b-chat-v1": {"prompt": 0.00084, "completion": 0.00084}, + "baichuan-7b-v1": {"prompt": 0.0, "completion": 0.0}, + "chatglm-6b-v2": {"prompt": 0.0011, "completion": 0.0011}, + "chatglm3-6b": {"prompt": 0.0, "completion": 0.0}, + "ziya-llama-13b-v1": {"prompt": 0.0, "completion": 0.0}, # no price page, judge it as free + "dolly-12b-v2": {"prompt": 0.0, "completion": 0.0}, + "belle-llama-13b-2m-v1": {"prompt": 0.0, "completion": 0.0}, + "moss-moon-003-sft-v1": {"prompt": 0.0, "completion": 0.0}, + "chatyuan-large-v2": {"prompt": 0.0, "completion": 0.0}, + "billa-7b-sft-v1": {"prompt": 0.0, "completion": 0.0}, +} + + +FIREWORKS_GRADE_TOKEN_COSTS = { + "-1": {"prompt": 0.0, "completion": 0.0}, # abnormal condition + "16": {"prompt": 0.2, "completion": 0.8}, # 16 means model size <= 16B; 0.2 means $0.2/1M tokens + "80": {"prompt": 0.7, "completion": 2.8}, # 80 means 16B < model size <= 80B + "mixtral-8x7b": {"prompt": 0.4, "completion": 1.6}, +} + +# https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo +TOKEN_MAX = { + "gpt-4-0125-preview": 128000, + "gpt-4-turbo-preview": 128000, + "gpt-4-1106-preview": 128000, + "gpt-4-vision-preview": 128000, + "gpt-4-1106-vision-preview": 128000, + "gpt-4": 8192, + "gpt-4-0613": 8192, + 
def count_message_tokens(messages, model="gpt-3.5-turbo-0125"):
    """Return the number of tokens used by a list of messages.

    Args:
        messages: A list of chat message dicts. A value may be a list of content
            parts (gpt-4v style), in which case every part whose ``type`` is
            ``"text"`` is counted and all other parts are skipped.
        model: The model name used to pick the tokenizer and per-message overhead.

    Returns:
        The token count, including per-message overhead and the 3 tokens that
        prime the assistant reply.

    Raises:
        NotImplementedError: If the per-message overhead for `model` is unknown.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        print("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")
    if model in {
        "gpt-3.5-turbo-0613",
        "gpt-3.5-turbo-16k-0613",
        "gpt-35-turbo",
        "gpt-35-turbo-16k",
        "gpt-3.5-turbo-16k",
        "gpt-3.5-turbo-1106",
        "gpt-3.5-turbo-0125",
        "gpt-4-0314",
        "gpt-4-32k-0314",
        "gpt-4-0613",
        "gpt-4-32k-0613",
        "gpt-4-turbo-preview",
        "gpt-4-0125-preview",
        "gpt-4-1106-preview",
        "gpt-4-vision-preview",
        "gpt-4-1106-vision-preview",
    }:
        tokens_per_message = 3  # every reply is primed with <|start|>assistant<|message|>
        tokens_per_name = 1
    elif model == "gpt-3.5-turbo-0301":
        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
        tokens_per_name = -1  # if there's a name, the role is omitted
    elif model == "gpt-3.5-turbo":
        print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0125.")
        return count_message_tokens(messages, model="gpt-3.5-turbo-0125")
    elif model == "gpt-4":
        print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
        return count_message_tokens(messages, model="gpt-4-0613")
    elif model == "open-llm-model":
        # For self-hosted open_llm api, they include lots of different models.
        # The message tokens calculation is inaccurate. It's a reference result.
        tokens_per_message = 0  # ignore conversation message template prefix
        tokens_per_name = 0
    else:
        raise NotImplementedError(
            f"num_tokens_from_messages() is not implemented for model {model}. "
            f"See https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken "
            f"for information on how messages are converted to tokens."
        )
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            if isinstance(value, list):
                # Multi-part content (gpt-4v): count every text part. Non-text
                # parts are skipped rather than handed to the tokenizer, which
                # only accepts strings.
                num_tokens += sum(
                    len(encoding.encode(item.get("text", "")))
                    for item in value
                    if isinstance(item, dict) and item.get("type") == "text"
                )
            else:
                num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens
def get_max_completion_tokens(messages: list[dict], model: str, default: int) -> int:
    """Calculate the maximum number of completion tokens for a given model and list of messages.

    Args:
        messages: A list of messages.
        model: The model name.
        default: The value returned when `model` has no entry in `TOKEN_MAX`.

    Returns:
        The model's context size minus the tokens used by `messages`, minus one;
        or `default` for a model without a known context size.
    """
    if model not in TOKEN_MAX:
        return default
    try:
        # Count with the requested model's own per-message overhead.
        used_tokens = count_message_tokens(messages, model)
    except NotImplementedError:
        # TOKEN_MAX lists models that count_message_tokens does not support;
        # fall back to its default model, as was done for every model before.
        used_tokens = count_message_tokens(messages)
    return TOKEN_MAX[model] - used_tokens - 1
def _list_children(root: Path, git_ignore_rules: Callable) -> Dict[str, Dict]:
    """Build a nested dict describing everything below `root`.

    Args:
        root: The directory whose entries are listed.
        git_ignore_rules: Optional predicate called with each entry's path as a
            string; entries for which it returns a truthy value are left out.

    Returns:
        A dict mapping each entry name to ``{}`` for a file, or to the same kind
        of dict for a sub-directory. An entry that cannot be inspected or
        descended into is recorded as ``{}``.
    """
    listing = {}
    for entry in root.iterdir():
        if git_ignore_rules and git_ignore_rules(str(entry)):
            continue
        try:
            subtree = {} if entry.is_file() else _list_children(root=entry, git_ignore_rules=git_ignore_rules)
        except (FileNotFoundError, PermissionError, OSError):
            # Unreadable or vanished entries are kept in the listing as leaves.
            subtree = {}
        listing[entry.name] = subtree
    return listing
def _execute_tree(root: Path, gitignore: str | Path) -> str:
    """Run the external `tree` command on `root` and return its standard output.

    Args:
        root: The directory passed to `tree`.
        gitignore: Optional gitignore file, forwarded through `--gitfile` when set.

    Returns:
        The text that `tree` wrote to standard output.

    Raises:
        subprocess.CalledProcessError: If `tree` exits with a non-zero status.
    """
    args = ["--gitfile", str(gitignore)] if gitignore else []
    # check=True already raises CalledProcessError on a non-zero exit status, so
    # no separate return-code check or re-raise is needed.
    result = subprocess.run(["tree"] + args + [str(root)], capture_output=True, text=True, check=True)
    return result.stdout
class VisualDiGraphRepo(VisualGraphRepo):
    """Implementation of VisualGraphRepo for DiGraph graph repository.

    This class extends VisualGraphRepo to provide specific functionality for a graph repository using DiGraph.
    """

    @classmethod
    async def load_from(cls, filename: str | Path) -> VisualDiGraphRepo:
        """Load a VisualDiGraphRepo instance from a file.

        Args:
            filename: Location of the stored graph; it is converted to ``str`` before being passed to
                ``DiGraphRepository.load_from``.

        Returns:
            A new instance wrapping the loaded graph repository.
        """
        graph_db = await DiGraphRepository.load_from(str(filename))
        return cls(graph_db=graph_db)

    async def get_mermaid_class_view(self) -> str:
        """Returns the Markdown Mermaid class diagram text for every class in the graph repository.

        Returns:
            A string that starts with ``classDiagram`` followed by the Mermaid text of each class view.
        """
        rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
        parts = ["classDiagram\n"]
        for r in rows:
            class_view = await self._get_class_view(ns_class_name=r.subject)
            parts.append(class_view.get_mermaid())
        return "".join(parts)

    async def _get_class_view(self, ns_class_name: str) -> _VisualClassView:
        """Builds the class view for the specified class from its rows in the graph repository.

        Args:
            ns_class_name: Namespace-qualified class name used as the subject of the query.

        Returns:
            A ``_VisualClassView`` holding the UML view (when a class-view row exists) and the refined names
            of the related classes, grouped by relationship kind.
        """
        rows = await self.graph_db.select(subject=ns_class_name)
        class_view = _VisualClassView(package=ns_class_name)
        # Each relationship predicate maps to the list that collects the names found for it.
        relations = {
            GraphKeyword.IS + GENERALIZATION + GraphKeyword.OF: class_view.generalizations,
            GraphKeyword.IS + COMPOSITION + GraphKeyword.OF: class_view.compositions,
            GraphKeyword.IS + AGGREGATION + GraphKeyword.OF: class_view.aggregations,
        }
        for r in rows:
            if r.predicate == GraphKeyword.HAS_CLASS_VIEW:
                class_view.uml = UMLClassView.model_validate_json(r.object_)
                continue
            target = relations.get(r.predicate)
            if target is None:
                continue  # predicate is not relevant to the class diagram
            name = self._refine_name(split_namespace(r.object_)[-1])
            if name:  # _refine_name returns "" for names that must not appear in the diagram
                target.append(name)
        return class_view

    async def get_mermaid_sequence_views(self) -> List[tuple[str, str]]:
        """Returns all Markdown sequence diagrams with their corresponding graph repository keys.

        Returns:
            ``(subject, object_)`` pairs, one per row whose predicate is ``GraphKeyword.HAS_SEQUENCE_VIEW``.
        """
        rows = await self.graph_db.select(predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
        return [(r.subject, r.object_) for r in rows]

    @staticmethod
    def _refine_name(name: str) -> str:
        """Removes impurity content from the given name.

        Leading and trailing quotes, backslashes and parentheses are stripped. If what remains is one of a
        fixed list of built-in names (``int``, ``float``, ``bool``, ``str``, ``list``, ``tuple``, ``set``,
        ``dict``, ``None``) an empty string is returned. For a dotted name only the last component is kept.

        Example:
            >>> VisualDiGraphRepo._refine_name("int")
            ''

            >>> VisualDiGraphRepo._refine_name('"Class1"')
            'Class1'

            >>> VisualDiGraphRepo._refine_name("pkg.Class1")
            'Class1'
        """
        name = re.sub(r'^[\'"\\\(\)]+|[\'"\\\(\)]+$', "", name)
        if name in ("int", "float", "bool", "str", "list", "tuple", "set", "dict", "None"):
            return ""
        # split() returns the whole string as its only element when there is no dot.
        return name.split(".")[-1]

    async def get_mermaid_sequence_view_versions(self) -> List[tuple[str, str]]:
        """Returns all versioned Markdown sequence diagrams with their corresponding graph repository keys.

        Returns:
            ``(subject, object_)`` pairs, one per row whose predicate is ``GraphKeyword.HAS_SEQUENCE_VIEW_VER``.
        """
        rows = await self.graph_db.select(predicate=GraphKeyword.HAS_SEQUENCE_VIEW_VER)
        return [(r.subject, r.object_) for r in rows]
class YamlModelWithoutDefault(YamlModel):
    """YamlModel without default values"""

    @model_validator(mode="before")
    @classmethod
    def check_not_default_config(cls, values):
        """Check if there is any default config in config2.yaml

        Rejects input that still contains a placeholder, i.e. a string containing ``"YOUR"``.

        Args:
            values: The raw input handed to the model before field validation.

        Returns:
            ``values`` unchanged when no placeholder is found.

        Raises:
            ValueError: If a placeholder string is found.
        """
        if isinstance(values, dict):
            # Iterating a dict yields only its keys, so the values have to be inspected explicitly.
            # Keys are still checked so that input rejected before is still rejected.
            # Non-string entries (numbers, nested structures, None) are skipped: `in` is undefined
            # or means something else for them.
            candidates = [v for v in (*values, *values.values()) if isinstance(v, str)]
        else:
            candidates = values
        if any("YOUR" in v for v in candidates):
            raise ValueError("Please set your config in config2.yaml")
        return values