boryasbora commited on
Commit
5cbe372
·
verified ·
1 Parent(s): 361560e

Update huggingface_llm.py

Browse files
Files changed (1) hide show
  1. huggingface_llm.py +4 -2
huggingface_llm.py CHANGED
@@ -9,12 +9,14 @@ class HuggingFaceLLM(LLM):
9
  model_id: str = Field(..., description="Hugging Face model ID")
10
  model: Any = Field(default=None, exclude=True)
11
  tokenizer: Any = Field(default=None, exclude=True)
12
- device: str = Field(default="cuda" if torch.cuda.is_available() else "cpu")
13
  temperature: float = Field(default=0.7, description="Sampling temperature")
14
  max_tokens: int = Field(default=256, description="Maximum number of tokens to generate")
 
15
 
16
  def __init__(self, **kwargs):
17
  super().__init__(**kwargs)
 
 
18
  self._load_model()
19
 
20
  def _load_model(self):
@@ -48,4 +50,4 @@ class HuggingFaceLLM(LLM):
48
 
49
  @property
50
  def _identifying_params(self) -> Dict[str, Any]:
51
- return {"model_id": self.model_id, "temperature": self.temperature, "max_tokens": self.max_tokens}
 
9
  model_id: str = Field(..., description="Hugging Face model ID")
10
  model: Any = Field(default=None, exclude=True)
11
  tokenizer: Any = Field(default=None, exclude=True)
 
12
  temperature: float = Field(default=0.7, description="Sampling temperature")
13
  max_tokens: int = Field(default=256, description="Maximum number of tokens to generate")
14
+ device: str = Field(default=None, description="Device to run the model on")
15
 
16
  def __init__(self, **kwargs):
17
  super().__init__(**kwargs)
18
+ if self.device is None:
19
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
20
  self._load_model()
21
 
22
  def _load_model(self):
 
50
 
51
  @property
52
  def _identifying_params(self) -> Dict[str, Any]:
53
+ return {"model_id": self.model_id, "temperature": self.temperature, "max_tokens": self.max_tokens, "device": self.device}