Morgan Funtowicz committed
Commit 90c13c1 · 1 Parent(s): 1e47a2c

misc(embeddings): clean up

Files changed (2)
  1. handler.py +25 -17
  2. requirements.txt +2 -0
handler.py CHANGED
@@ -1,16 +1,18 @@
+import os
 import platform
 from typing import Union, Sequence, Sized
 
 import torch
-from loguru import logger
 from hfendpoints.openai import Context, run
 from hfendpoints.openai.embeddings import Embedding, EmbeddingEndpoint, EmbeddingRequest, EmbeddingResponse, Usage
-from sentence_transformers import SentenceTransformer
-
 from hfendpoints import EndpointConfig, Handler, __version__
+from loguru import logger
 from torch.backends.mkldnn import VERBOSE_ON_CREATION, VERBOSE_OFF
+from sentence_transformers import SentenceTransformer
 
-
+# Not used for now
+ENABLE_QUANTIZATION = bool(os.environ.get("HFENDPOINT_ENABLE_QUANTIZATION", "0"))
+SUPPORTED_AMP_DTYPES = {torch.float32, torch.bfloat16}
 
 def get_usage(tokens: Union[Sized, Sequence[Sized]], is_batched: bool) -> Usage:
     """
@@ -26,33 +28,39 @@ def get_usage(tokens: Union[Sized, Sequence[Sized]], is_batched: bool) -> Usage:
 
     return Usage(prompt_tokens=num_tokens, total_tokens=num_tokens)
 
-
 class SentenceTransformerHandler(Handler):
-    __slots__ = ("_config", "_model", "_model_config", "_model_name")
+    __slots__ = ("_config", "_dtype", "_model", "_model_name", "_use_amp")
 
     def __init__(self, config: EndpointConfig):
         self._config = config
-        self._model = SentenceTransformer(config.model_id, device="cpu", model_kwargs={"torch_dtype": "bfloat16"})
-        self._model_config = self._model._modules['0'].auto_model.config
+        self._dtype = torch.float32
         self._model_name = config.model_id
 
-        self._post_init()
+        self._allocate_model()
 
-    def _post_init(self):
-        self._model = self._model.to(memory_format=torch.channels_last)
+    def _allocate_model(self):
+        dtype = torch.bfloat16 if torch.cpu._is_avx512_bf16_supported() else torch.float32
+        model = SentenceTransformer(self._config.model_id, device="cpu", model_kwargs={"torch_dtype": dtype})
 
-        if "Intel" in platform.processor():
+        if platform.machine() == "x86_64":
             import intel_extension_for_pytorch as ipex
+            logger.info(f"x64 platform detected: {platform.processor()}")
+
             with torch.inference_mode():
-                self._model = self._model.eval()
-                self._model = ipex.optimize(self._model, dtype=torch.float32, weights_prepack=False)
-                self._model = torch.compile(self._model, backend="ipex")
+                model = model.eval()
+                model = model.to(memory_format=torch.channels_last)
+                model = ipex.optimize(model, dtype=dtype, weights_prepack=False, graph_mode=True, concat_linear=True)
+                model = torch.compile(model, dynamic=True, backend="ipex")
         else:
-            self._model = torch.compile(self._model)
+            model = torch.compile(model)
+
+        self._model = model
+        self._dtype = dtype
+        self._use_amp = dtype in SUPPORTED_AMP_DTYPES
 
     async def __call__(self, request: EmbeddingRequest, ctx: Context) -> EmbeddingResponse:
         with torch.backends.mkldnn.verbose(VERBOSE_ON_CREATION if self._config.is_debug else VERBOSE_OFF):
-            with torch.inference_mode(), torch.amp.autocast("cpu", dtype=torch.float32):
+            with torch.inference_mode(), torch.amp.autocast("cpu", dtype=self._dtype, enabled=self._use_amp):
                 tokens = self._model.tokenize(request.input)
                 vectors = self._model.encode(request.input)
 
 
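Note: the heart of this change is picking bfloat16 only when the CPU advertises AVX512-BF16 support, then routing that one dtype through both ipex.optimize and autocast. A minimal standalone sketch of the same pattern, assuming a placeholder checkpoint and copying the private torch.cpu._is_avx512_bf16_supported() probe from the diff above (it is not a documented public PyTorch API):

    import torch
    from sentence_transformers import SentenceTransformer

    # Probe copied from the commit; private API, may vary across torch builds.
    has_bf16 = torch.cpu._is_avx512_bf16_supported()
    dtype = torch.bfloat16 if has_bf16 else torch.float32

    # Placeholder model id; any sentence-transformers checkpoint works here.
    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device="cpu")

    # Same inference pattern as the handler's __call__: no autograd, CPU autocast.
    # autocast is enabled only for bf16 here, since CPU autocast has nothing to
    # lower to when the target dtype is already float32.
    with torch.inference_mode(), torch.amp.autocast("cpu", dtype=dtype, enabled=has_bf16):
        vectors = model.encode(["an example sentence to embed"])

    print(vectors.shape)  # (1, hidden_size)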
requirements.txt CHANGED
@@ -1,3 +1,5 @@
+# auto-round>=0.5.0
 loguru>=0.7.3
 sentence-transformers
+# transformers>=4.51
 torch>=2.5.0
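
Note: the commented-out auto-round and transformers pins line up with the dormant ENABLE_QUANTIZATION flag added in handler.py. One caveat for when it is wired up: bool() over an environment string is truthy for any non-empty value, so bool(os.environ.get(..., "0")) evaluates to True even at the "0" default. A short sketch of a stricter parse (the env_flag helper is hypothetical, not part of the codebase):

    import os

    def env_flag(name: str, default: str = "0") -> bool:
        # Hypothetical helper: only explicit truthy strings count as True;
        # plain bool("0") would be True because the string is non-empty.
        return os.environ.get(name, default).strip().lower() in {"1", "true", "yes", "on"}

    ENABLE_QUANTIZATION = env_flag("HFENDPOINT_ENABLE_QUANTIZATION")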