mcamargo00 commited on
Commit
a8f4e5d
·
verified ·
1 Parent(s): 62c79c5

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -5
app.py CHANGED
@@ -367,18 +367,21 @@ def load_model():
367
  base_phi_model = "microsoft/Phi-4-mini-instruct"
368
 
369
  # T4 does fp16 (not bf16)
370
- DTYPE = torch.float16
371
  quantization_config = BitsAndBytesConfig(
372
  load_in_4bit=True,
373
  bnb_4bit_quant_type="nf4",
374
  bnb_4bit_compute_dtype=DTYPE,
375
  )
 
 
376
  classifier_backbone_base = AutoModelForCausalLM.from_pretrained(
377
  base_phi_model,
378
  quantization_config=quantization_config,
379
- device_map={"": 0}, # single-GPU
380
- trust_remote_code=False, # <-- avoid remote LossKwargs import
381
- attn_implementation="sdpa",
 
382
  )
383
 
384
  classifier_tokenizer = AutoTokenizer.from_pretrained(
@@ -401,7 +404,7 @@ def load_model():
401
  classifier_model.classifier.load_state_dict(torch.load(classifier_head_path, map_location=device))
402
 
403
  classifier_model.to(device)
404
- classifier_model = classifier_model.to(torch.bfloat16)
405
 
406
  classifier_model.eval() # Set model to evaluation mode
407
 
 
367
  base_phi_model = "microsoft/Phi-4-mini-instruct"
368
 
369
  # T4 does fp16 (not bf16)
370
+ DTYPE = torch.float32
371
  quantization_config = BitsAndBytesConfig(
372
  load_in_4bit=True,
373
  bnb_4bit_quant_type="nf4",
374
  bnb_4bit_compute_dtype=DTYPE,
375
  )
376
+
377
+
378
  classifier_backbone_base = AutoModelForCausalLM.from_pretrained(
379
  base_phi_model,
380
  quantization_config=quantization_config,
381
+ device_map={"": 0},
382
+ trust_remote_code=False, # keep this if you switched it earlier
383
+ # safest with eager attention when mixing kernels:
384
+ attn_implementation="eager",
385
  )
386
 
387
  classifier_tokenizer = AutoTokenizer.from_pretrained(
 
404
  classifier_model.classifier.load_state_dict(torch.load(classifier_head_path, map_location=device))
405
 
406
  classifier_model.to(device)
407
+ classifier_model = classifier_model.to(device=DEVICE, dtype=torch.float32)
408
 
409
  classifier_model.eval() # Set model to evaluation mode
410