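# Spaces: Sleeping
# NOTE(review): the line above looks like a Hugging Face Spaces status banner
# captured when this file was copied from the web page; it is not part of the module.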
from typing import Any, List, Optional, Dict

import torch
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from pydantic import Field, PrivateAttr
from transformers import AutoModelForCausalLM, AutoTokenizer
class HuggingFaceLLM(LLM):
    """LangChain ``LLM`` backed by a locally loaded Hugging Face causal LM.

    The tokenizer and model identified by ``model_id`` are loaded with
    ``transformers`` while the instance is being constructed. Each call
    generates a continuation of the prompt and returns only the newly
    generated text.
    """

    model_id: str = Field(..., description="Hugging Face model ID")
    temperature: float = Field(default=0.7, description="Sampling temperature")
    max_tokens: int = Field(default=256, description="Maximum number of tokens to generate")
    device: str = Field(default="cpu", description="Device to run the model on")

    # Populated by _load_model() during __init__; declared as private
    # attributes so they are not treated as model fields.
    _model: Optional[Any] = PrivateAttr(default=None)
    _tokenizer: Optional[Any] = PrivateAttr(default=None)

    def __init__(self, **kwargs):
        """Validate the fields, resolve the device and load the model.

        The resolved device is always either ``"cuda"`` or ``"cpu"``: CUDA is
        chosen only when it is available *and* ``device`` was set to something
        other than ``"cpu"``. Because ``device`` defaults to ``"cpu"``, running
        on the GPU requires passing ``device`` explicitly, and values such as
        ``"cuda:1"`` collapse to plain ``"cuda"``.
        """
        super().__init__(**kwargs)
        self.device = "cuda" if torch.cuda.is_available() and self.device != "cpu" else "cpu"
        self._load_model()

    def _load_model(self):
        """Load the tokenizer and model for ``model_id`` and move the model to ``device``."""
        self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self._model = AutoModelForCausalLM.from_pretrained(self.model_id)
        self._model = self._model.to(torch.device(self.device))

    @property
    def _llm_type(self) -> str:
        """Return the identifier LangChain uses for this LLM implementation.

        Exposed as a property because the LangChain ``LLM`` base class
        declares ``_llm_type`` as one and reads it as a plain attribute.
        """
        return "custom_huggingface"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion for ``prompt``.

        Args:
            prompt: Text to continue.
            stop: Optional stop sequences. The completion is cut immediately
                before the earliest occurrence of any of them.
            run_manager: Accepted for interface compatibility; not used.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The generated continuation only (prompt excluded), with
            surrounding whitespace stripped.
        """
        input_ids = self._tokenizer.encode(prompt, return_tensors="pt").to(self.device)

        # A non-positive temperature is rejected by transformers when
        # sampling, so treat it as a request for greedy decoding instead.
        use_sampling = self.temperature > 0
        generate_kwargs: Dict[str, Any] = {
            "max_new_tokens": self.max_tokens,
            "do_sample": use_sampling,
            # Reuse EOS as the pad id so generate() always has one
            # (presumably for tokenizers that define no pad token).
            "pad_token_id": self._tokenizer.eos_token_id,
            # The prompt is a single unpadded sequence, so every position is a
            # real token. Passing the mask explicitly avoids generate() having
            # to guess it, which is ambiguous when pad id == EOS id.
            "attention_mask": torch.ones_like(input_ids),
        }
        if use_sampling:
            generate_kwargs["temperature"] = self.temperature

        with torch.no_grad():
            output = self._model.generate(input_ids, **generate_kwargs)

        # For a causal LM the returned sequence starts with the prompt tokens.
        # Drop them by token count before decoding: slicing the decoded string
        # by len(prompt) breaks whenever decoding does not reproduce the
        # prompt text exactly.
        prompt_length = input_ids.shape[-1]
        completion = self._tokenizer.decode(
            output[0][prompt_length:], skip_special_tokens=True
        )

        if stop:
            # Empty stop strings are skipped: find("") is always 0 and would
            # wipe out the whole completion.
            hits = [completion.find(sequence) for sequence in stop if sequence]
            hits = [index for index in hits if index != -1]
            if hits:
                completion = completion[: min(hits)]

        return completion.strip()

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return the parameters that identify this LLM configuration.

        Exposed as a property because the LangChain ``LLM`` base class
        declares ``_identifying_params`` as one and reads it as a mapping.
        """
        return {
            "model_id": self.model_id,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "device": self.device,
        }

    def __setattr__(self, name, value):
        """Set ``_model`` / ``_tokenizer`` directly; defer everything else to the base class."""
        if name in ("_model", "_tokenizer"):
            # Bypass the base-class attribute handling for the two private
            # handles and store them straight on the instance.
            object.__setattr__(self, name, value)
        else:
            super().__setattr__(name, value)