Spaces:
Runtime error
Runtime error
File size: 4,185 Bytes
28a47b6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
from pathlib import Path
from typing import Any, Dict, List, Type
from pydantic import BaseModel, BaseSettings, Extra
import os
def json_config_settings_source(settings: "BaseSettings") -> Dict[str, Any]:
    """Load setting values from ``<cwd>/.suspicionagent/config.json``.

    Shaped as a pydantic v1 settings source (one positional ``settings``
    argument, returns a mapping), which is why ``settings`` is accepted but
    never read.

    Args:
        settings: The settings object being built; unused.

    Returns:
        Whatever ``util.load_json`` returns for the config file (presumably a
        dict of setting values -- not verified here).

    Exits:
        Calls ``sys.exit(-1)`` after printing a hint to stderr when the config
        file is missing or is not a regular file.
    """
    import sys

    # Resolved against the current working directory, so the program must be
    # started from the directory that contains ``.suspicionagent/``.
    config_file = Path.cwd() / ".suspicionagent" / "config.json"

    # is_file() also rejects a directory that happens to be named config.json,
    # which exists() would accept and then fail to load. A missing parent
    # directory implies a missing file, so one check covers both cases.
    if not config_file.is_file():
        print("[Error] Please config suspicionagent", file=sys.stderr)
        print(f"[Error] Expected a config file at: {config_file}", file=sys.stderr)
        sys.exit(-1)

    # Deferred until the file is known to exist, so a missing config is always
    # reported with the message above rather than as an ImportError.
    from util import load_json

    return load_json(config_file)
class LLMSettings(BaseModel):
    """Settings describing the chat/completion model backend.

    Only ``type`` is declared. Because ``extra`` is set to allow, any other
    keyword passed to the constructor (the presets in this module pass e.g.
    ``model``, ``max_tokens`` and ``temperature``) is kept on the model.
    """

    # Backend identifier; "chatopenai" unless a preset overrides it.
    type: str = "chatopenai"

    class Config:
        # Accept and keep undeclared fields instead of rejecting them.
        extra = Extra.allow
class EmbeddingSettings(BaseModel):
    """Settings describing the embedding model backend.

    Only ``type`` is declared; ``extra`` is set to allow, so additional
    constructor keywords are kept on the model rather than rejected.
    """

    # Embedding backend identifier; "openaiembeddings" by default.
    type: str = "openaiembeddings"

    class Config:
        # Accept and keep undeclared fields instead of rejecting them.
        extra = Extra.allow
class ModelSettings(BaseModel):
    """One model preset: a name plus its LLM and embedding settings.

    Preset subclasses override ``type``, ``llm`` and ``embedding``; undeclared
    fields are kept because ``extra`` is set to allow.
    """

    # Preset identifier; empty on this base class.
    type: str = ""
    # Default LLM / embedding configurations (each class's own defaults).
    llm: LLMSettings = LLMSettings()
    embedding: EmbeddingSettings = EmbeddingSettings()

    class Config:
        # Accept and keep undeclared fields instead of rejecting them.
        extra = Extra.allow
class Settings(BaseSettings):
    """Root application settings.

    Values are gathered from the sources returned by
    ``Config.customise_sources``: constructor arguments, then environment
    variables carrying the ``skyagi_`` prefix, then secret files. Undeclared
    keys are kept because ``extra`` is set to allow.
    """

    # Free-form name for this settings profile.
    name: str = "default"
    # Model preset in use; the base ModelSettings defaults unless overridden.
    model: ModelSettings = ModelSettings()

    class Config:
        # Environment variables are looked up with this prefix.
        env_prefix = "skyagi_"
        # Encoding used if an env file is read (no env_file is set here).
        env_file_encoding = "utf-8"
        extra = Extra.allow

        @classmethod
        def customise_sources(cls, init_settings, env_settings, file_secret_settings):
            # Earlier entries take priority over later ones. The JSON loader
            # ``json_config_settings_source`` is currently left out of the chain;
            # insert it between init_settings and env_settings to enable it.
            ordered_sources = [init_settings, env_settings, file_secret_settings]
            return tuple(ordered_sources)
# ---------------------------------------------------------------------------- #
# Preset configurations #
# ---------------------------------------------------------------------------- #
class OpenAIGPT4Settings(ModelSettings):
    """Preset for OpenAI ``gpt-4-0613`` via the ``chatopenai`` backend."""

    # NOTE: GPT4 is in waitlist
    # Overrides are annotated explicitly: pydantic v1 only infers the type of
    # an un-annotated override from its parent, and pydantic v2 rejects it.
    type: str = "openai-gpt-4-0613"
    llm: LLMSettings = LLMSettings(
        type="chatopenai",
        model="gpt-4-0613",
        max_tokens=3000,
        temperature=0.1,
        request_timeout=120,
    )
    embedding: EmbeddingSettings = EmbeddingSettings(type="openaiembeddings")
class OpenAIGPT432kSettings(ModelSettings):
    """Preset for OpenAI ``gpt-4-32k-0613`` via the ``chatopenai`` backend."""

    # NOTE: GPT4 is in waitlist
    # Overrides are annotated explicitly: pydantic v1 only infers the type of
    # an un-annotated override from its parent, and pydantic v2 rejects it.
    type: str = "openai-gpt-4-32k-0613"
    llm: LLMSettings = LLMSettings(
        type="chatopenai",
        model="gpt-4-32k-0613",
        max_tokens=2500,
    )
    embedding: EmbeddingSettings = EmbeddingSettings(type="openaiembeddings")
class OpenAIGPT3_5TurboSettings(ModelSettings):
    """Preset registered as ``openai-gpt-3.5-turbo`` via the ``chatopenai`` backend."""

    # Overrides are annotated explicitly: pydantic v1 only infers the type of
    # an un-annotated override from its parent, and pydantic v2 rejects it.
    type: str = "openai-gpt-3.5-turbo"
    # NOTE(review): the preset key says "gpt-3.5-turbo" but the model actually
    # requested is the 16k-context 0613 snapshot -- confirm this is intended.
    llm: LLMSettings = LLMSettings(
        type="chatopenai",
        model="gpt-3.5-turbo-16k-0613",
        max_tokens=2500,
    )
    embedding: EmbeddingSettings = EmbeddingSettings(type="openaiembeddings")
class OpenAIGPT3_5TextDavinci003Settings(ModelSettings):
    """Preset for OpenAI ``text-davinci-003`` via the ``openai`` (completion) backend."""

    # Overrides are annotated explicitly: pydantic v1 only infers the type of
    # an un-annotated override from its parent, and pydantic v2 rejects it.
    type: str = "openai-gpt-3.5-text-davinci-003"
    # NOTE(review): this preset passes ``model_name`` where the chat presets
    # pass ``model``; presumably the consumer accepts both -- verify.
    llm: LLMSettings = LLMSettings(
        type="openai",
        model_name="text-davinci-003",
        max_tokens=2500,
    )
    embedding: EmbeddingSettings = EmbeddingSettings(type="openaiembeddings")
# class Llama2_70b_Settings(ModelSettings):
# from transformers import LlamaForCausalLM, LlamaTokenizer
# type = "llama2-70b"
# tokenizer = LlamaTokenizer.from_pretrained("/groups/gcb50389/pretrained/llama2-HF/Llama-2-70b-hf")
# llm = LlamaForCausalLM.from_pretrained("/groups/gcb50389/pretrained/llama2-HF/Llama-2-70b-hf")
# embedding = EmbeddingSettings(type="openaiembeddings")
# ------------------------- Model settings registry ------------------------ #
# Maps a preset's ``type`` string to the ModelSettings subclass implementing
# it. Each key mirrors the ``type`` default declared on its class; keep the
# two in sync when adding a preset.
model_setting_type_to_cls_dict: Dict[str, Type[ModelSettings]] = dict(
    [
        ("openai-gpt-4-0613", OpenAIGPT4Settings),
        ("openai-gpt-4-32k-0613", OpenAIGPT432kSettings),
        ("openai-gpt-3.5-turbo", OpenAIGPT3_5TurboSettings),
        ("openai-gpt-3.5-text-davinci-003", OpenAIGPT3_5TextDavinci003Settings),
        # ("llama2-70b", Llama2_70b_Settings),
    ]
)
def load_model_setting(type: str) -> ModelSettings:
    """Build a fresh instance of the preset registered under ``type``.

    Args:
        type: A key of ``model_setting_type_to_cls_dict``.

    Returns:
        A new instance of the matching ModelSettings subclass.

    Raises:
        ValueError: If no preset is registered under ``type``.
    """
    preset_cls = model_setting_type_to_cls_dict.get(type)
    if preset_cls is None:
        raise ValueError(f"Loading {type} setting not supported")
    return preset_cls()
def get_all_model_settings() -> List[str]:
    """Return the ``type`` keys of every registered model-settings preset.

    These are exactly the strings ``load_model_setting`` accepts, in registry
    order. A new list is built on each call, so callers may modify it freely.
    """
    return list(model_setting_type_to_cls_dict)
|