muryshev's picture
Добавлен токенизатор для корректной обрезки запроса.
2ccde67
raw
history blame
3.18 kB
from pydantic import BaseModel, Field
from typing import Optional, List, Protocol
from abc import ABC, abstractmethod
class LlmPredictParams(BaseModel):
"""
Параметры для предсказания LLM.
"""
system_prompt: Optional[str] = Field(None, description="Системный промпт.")
user_prompt: Optional[str] = Field(None, description="Шаблон промпта для передачи от роли user.")
n_predict: Optional[int] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
top_p: Optional[float] = None
min_p: Optional[float] = None
seed: Optional[int] = None
repeat_penalty: Optional[float] = None
repeat_last_n: Optional[int] = None
retry_if_text_not_present: Optional[str] = None
retry_count: Optional[int] = None
presence_penalty: Optional[float] = None
frequency_penalty: Optional[float] = None
n_keep: Optional[int] = None
cache_prompt: Optional[bool] = None
stop: Optional[List[str]] = None
class LlmParams(BaseModel):
"""
Основные параметры для LLM.
"""
url: str
model: Optional[str] = Field(None, description="Предполагается, что для локального API этот параметр не будет указываться, т.к. будем брать первую модель из списка потому, что модель доступна всего одна. Для deepinfra такой подход не подойдет и модель нужно задавать явно.")
tokenizer: Optional[str] = Field(None, description="При использовании стороннего API, не поддерживающего токенизацию, будет использован AutoTokenizer для модели из этого поля. Используется в случае, если название модели в API не совпадает с оригинальным названием на Huggingface.")
type: Optional[str] = None
default: Optional[bool] = None
template: Optional[str] = None
predict_params: Optional[LlmPredictParams] = None
api_key: Optional[str] = None
context_length: Optional[int] = None
class LlmApiProtocol(Protocol):
async def tokenize(self, prompt: str) -> Optional[dict]:
...
async def detokenize(self, tokens: List[int]) -> Optional[str]:
...
async def trim_sources(self, sources: str, user_request: str, system_prompt: str = None) -> dict:
...
async def predict(self, prompt: str) -> str:
...
class LlmApi:
"""
Базовый клас для работы с API LLM.
"""
params: LlmParams = None
def __init__(self):
self.params = None
def set_params(self, params: LlmParams):
self.params = params
def create_headers(self) -> dict[str, str]:
headers = {"Content-Type": "application/json"}
if self.params.api_key is not None:
headers["Authorization"] = self.params.api_key
return headers