Boris Shapkin committed on
Commit
0d17078
·
1 Parent(s): 5018bdb

additional files

Browse files
Files changed (1) hide show
  1. huggingface_llm.py +51 -0
huggingface_llm.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.llms.base import LLM
2
+ from langchain.callbacks.manager import CallbackManagerForLLMRun
3
+ from typing import Any, List, Optional, Dict
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+ import torch
6
+ from pydantic import Field
7
+
8
class HuggingFaceLLM(LLM):
    """LangChain ``LLM`` wrapper around a locally loaded Hugging Face causal LM.

    On construction the tokenizer and model identified by ``model_id`` are
    loaded with ``from_pretrained`` and the model is moved to ``device``.
    """

    model_id: str = Field(..., description="Hugging Face model ID")
    # Populated by _load_model(); excluded from the model's exported fields.
    model: Any = Field(default=None, exclude=True)
    tokenizer: Any = Field(default=None, exclude=True)
    # Default is resolved once, when the class body is evaluated.
    device: str = Field(default="cuda" if torch.cuda.is_available() else "cpu")
    temperature: float = Field(default=0.7, description="Sampling temperature")
    max_tokens: int = Field(default=256, description="Maximum number of tokens to generate")

    def __init__(self, **kwargs):
        """Validate the fields via the base class, then load tokenizer and model."""
        super().__init__(**kwargs)
        self._load_model()

    def _load_model(self):
        """Load the tokenizer and model for ``model_id`` and move the model to ``device``."""
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id).to(self.device)

    @property
    def _llm_type(self) -> str:
        """Return the identifier string for this LLM implementation."""
        return "custom_huggingface"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion for ``prompt``.

        Args:
            prompt: Text to continue.
            stop: Optional stop sequences; the completion is cut just before
                the earliest occurrence of any of them.
            run_manager: Accepted for interface compatibility; not used.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The newly generated text (without the prompt), whitespace-stripped.
        """
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        # A single, unpadded prompt attends to every token. Passing the mask
        # explicitly avoids ambiguity when pad_token_id equals eos_token_id.
        attention_mask = torch.ones_like(input_ids)

        generate_kwargs: Dict[str, Any] = {
            "max_new_tokens": self.max_tokens,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        if self.temperature > 0:
            generate_kwargs["do_sample"] = True
            generate_kwargs["temperature"] = self.temperature
        else:
            # Sampling requires a strictly positive temperature; treat a
            # non-positive value as a request for greedy decoding.
            generate_kwargs["do_sample"] = False

        with torch.no_grad():
            output = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                **generate_kwargs,
            )

        # Slice by token count rather than by len(prompt) characters: decoding
        # the prompt tokens is not guaranteed to reproduce the prompt string
        # exactly, which would misalign a character-based slice.
        prompt_length = input_ids.shape[-1]
        completion = self.tokenizer.decode(
            output[0][prompt_length:], skip_special_tokens=True
        )

        if stop:
            hits = [completion.find(sequence) for sequence in stop if sequence]
            cut = min((index for index in hits if index != -1), default=len(completion))
            completion = completion[:cut]

        return completion.strip()

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return the parameters that identify this LLM configuration."""
        return {
            "model_id": self.model_id,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }