Morgan Funtowicz committed
Commit a8540ed · 1 Parent(s): 90c13c1

misc(config): add proper way to detect if cpu may support bfloat16

Files changed (1)
  1. handler.py +16 -2
handler.py CHANGED

@@ -11,9 +11,23 @@ from torch.backends.mkldnn import VERBOSE_ON_CREATION, VERBOSE_OFF
 from sentence_transformers import SentenceTransformer
 
 # Not used for now
-ENABLE_QUANTIZATION = bool(os.environ.get("HFENDPOINT_ENABLE_QUANTIZATION", "0"))
 SUPPORTED_AMP_DTYPES = {torch.float32, torch.bfloat16}
 
+
+def has_bf16_support() -> bool:
+    """
+    Helper to detect whether the hardware supports bfloat16.
+
+    Note:
+        Intel libraries, such as oneDNN, provide emulation for bfloat16 even if the underlying hardware does not support it.
+        This means a CPU with AVX512 will still work, though not with the performance one could expect from a CPU with AVX512_BF16.
+        Also, AMX_BF16 is implicitly assumed true when AVX512_BF16 is true (as is the case on Intel Sapphire Rapids).
+
+    :return: True if the hardware supports (or can emulate) bfloat16, False otherwise
+    """
+    return torch.cpu._is_avx512_bf16_supported() or torch.cpu._is_avx512_supported()
+
+
 def get_usage(tokens: Union[Sized, Sequence[Sized]], is_batched: bool) -> Usage:
     """
     Compute the number of processed tokens and return as Usage object matching OpenAI
@@ -39,7 +53,7 @@ class SentenceTransformerHandler(Handler):
         self._allocate_model()
 
     def _allocate_model(self):
-        dtype = torch.bfloat16 if torch.cpu._is_avx512_bf16_supported() else torch.float32
+        dtype = torch.bfloat16 if has_bf16_support() else torch.float32
         model = SentenceTransformer(self._config.model_id, device="cpu", model_kwargs={"torch_dtype": dtype})
 
         if platform.machine() == "x86_64":
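
For context, a minimal standalone sketch of the detection logic this commit adds, runnable outside the handler. The torch.cpu._is_avx512* probes are private PyTorch helpers (present in recent 2.x releases); the AttributeError guard is an assumption added here for older builds and is not part of the commit itself.

import torch

def has_bf16_support() -> bool:
    # Same logic as the commit: native AVX512_BF16 support, or plain
    # AVX512 where oneDNN can emulate bfloat16 at reduced throughput.
    # The try/except is an assumption for illustration: these probes
    # are private torch APIs and may be absent on older builds.
    try:
        return torch.cpu._is_avx512_bf16_supported() or torch.cpu._is_avx512_supported()
    except AttributeError:
        return False

dtype = torch.bfloat16 if has_bf16_support() else torch.float32
print(f"Selected dtype: {dtype}")

On a Sapphire Rapids machine this would print torch.bfloat16 via the native AVX512_BF16 path; on an AVX512-only CPU it still selects torch.bfloat16, relying on oneDNN emulation as the docstring notes.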