# chatbot_ohw_projects/huggingface_llm.py
# Author: boryasbora. Last commit 9b860ac (verified): "Update huggingface_llm.py". Size: 2.34 kB.
from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from typing import Any, List, Optional, Dict
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from pydantic import Field, PrivateAttr
class HuggingFaceLLM(LLM):
    """LangChain ``LLM`` backed by a local Hugging Face causal language model.

    The tokenizer and model for ``model_id`` are loaded eagerly when the
    instance is constructed, and the model is moved to ``device``. Each call
    runs ``generate`` on the prompt and returns only the newly generated text.
    """

    model_id: str = Field(..., description="Hugging Face model ID")
    temperature: float = Field(default=0.7, description="Sampling temperature")
    max_tokens: int = Field(default=256, description="Maximum number of tokens to generate")
    device: str = Field(default="cpu", description="Device to run the model on")
    # Runtime objects populated by _load_model(). They are declared with
    # PrivateAttr so that they stay out of the pydantic field set.
    _model: Optional[Any] = PrivateAttr(default=None)
    _tokenizer: Optional[Any] = PrivateAttr(default=None)

    def __init__(self, **kwargs):
        """Validate the fields, resolve the target device, then load the model and tokenizer."""
        super().__init__(**kwargs)
        requested = self.device
        # Any device other than "cpu" is treated as a GPU request. It is
        # honoured only when CUDA is available; otherwise we fall back to the
        # CPU. An explicit CUDA index such as "cuda:1" is kept rather than
        # collapsed to the default "cuda" device.
        if requested != "cpu" and torch.cuda.is_available():
            cuda_index = requested[len("cuda:"):] if requested.startswith("cuda:") else ""
            self.device = requested if cuda_index.isdigit() else "cuda"
        else:
            self.device = "cpu"
        self._load_model()

    def _load_model(self):
        """Load the tokenizer and causal-LM weights for ``model_id`` and move the model to ``device``."""
        self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self._model = AutoModelForCausalLM.from_pretrained(self.model_id)
        self._model = self._model.to(torch.device(self.device))

    @property
    def _llm_type(self) -> str:
        """Return the identifier string that LangChain reports for this LLM type."""
        return "custom_huggingface"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion for ``prompt``.

        Args:
            prompt: Text for the model to continue.
            stop: Optional stop strings. The completion is cut at the earliest
                occurrence of any of them.
            run_manager: LangChain callback manager (unused).
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Only the newly generated text (the prompt is excluded), stripped of
            surrounding whitespace and truncated at the first stop string.
        """
        # Calling the tokenizer (rather than ``encode``) also returns the
        # attention mask. Without it, generate() has to infer the mask, which is
        # unreliable when pad_token_id equals eos_token_id, as it does below.
        encoded = self._tokenizer(prompt, return_tensors="pt").to(self.device)
        input_ids = encoded["input_ids"]

        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": self.max_tokens,
            "pad_token_id": self._tokenizer.eos_token_id,
        }
        if self.temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=self.temperature)
        else:
            # Sampling requires a strictly positive temperature, so a
            # temperature <= 0 is treated as a request for greedy
            # (deterministic) decoding.
            gen_kwargs["do_sample"] = False

        with torch.no_grad():
            output = self._model.generate(
                input_ids,
                attention_mask=encoded.get("attention_mask"),
                **gen_kwargs,
            )

        # Decode only the tokens produced after the prompt. Slicing the fully
        # decoded string by len(prompt) is unreliable, because
        # decode(encode(prompt)) need not reproduce the prompt exactly.
        generated = output[0][input_ids.shape[-1]:]
        text = self._tokenizer.decode(generated, skip_special_tokens=True).strip()

        # generate() above is not told about LangChain's stop sequences, so the
        # completion is cut at the earliest one after decoding.
        if stop:
            hits = [pos for pos in (text.find(s) for s in stop if s) if pos != -1]
            if hits:
                text = text[: min(hits)].strip()
        return text

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return the parameters that identify this LLM configuration."""
        return {"model_id": self.model_id, "temperature": self.temperature, "max_tokens": self.max_tokens, "device": self.device}

    def __setattr__(self, name, value):
        """Assign ``_model``/``_tokenizer`` with ``object.__setattr__``; route other names through pydantic.

        NOTE(review): presumably this bypasses pydantic's assignment handling
        for the heavy model/tokenizer objects. Confirm that it is still needed
        with the installed pydantic/LangChain versions.
        """
        if name in ["_model", "_tokenizer"]:
            object.__setattr__(self, name, value)
        else:
            super().__setattr__(name, value)