CodCodingCode committed
Commit d27aed9 · verified · 1 parent: f13e9e9

Update handler.py

Files changed (1): handler.py (+56 -9)
handler.py CHANGED
@@ -1,18 +1,54 @@
 from typing import Dict, List, Any
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
-
+import os
 
 class EndpointHandler():
     def __init__(self, path=""):
-        # Load model and tokenizer
-        self.tokenizer = AutoTokenizer.from_pretrained(path)
+        # Look for checkpoint-100 folder
+        checkpoint_path = None
+
+        if not path or path == "/repository":
+            base_path = "."
+        else:
+            base_path = path
+
+        # Check different possible locations
+        possible_paths = [
+            os.path.join(base_path, "checkpoint-100"),
+            os.path.join(".", "checkpoint-100"),
+            os.path.join("/repository", "checkpoint-100"),
+            "checkpoint-100"
+        ]
+
+        for check_path in possible_paths:
+            if os.path.exists(check_path) and os.path.isdir(check_path):
+                # Verify it contains model files
+                files = os.listdir(check_path)
+                if any(f in files for f in ['config.json', 'pytorch_model.bin', 'model.safetensors']):
+                    checkpoint_path = check_path
+                    break
+
+        if checkpoint_path is None:
+            print(f"Available files in base path: {os.listdir(base_path) if os.path.exists(base_path) else 'Path does not exist'}")
+            raise ValueError("Could not find checkpoint-100 folder with model files")
+
+        print(f"Loading model from: {checkpoint_path}")
+        print(f"Files in checkpoint: {os.listdir(checkpoint_path)}")
+
+        # Load model and tokenizer from checkpoint-100
+        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, trust_remote_code=True)
         self.model = AutoModelForCausalLM.from_pretrained(
-            path,
+            checkpoint_path,
             device_map="auto",
             torch_dtype=torch.bfloat16,
+            trust_remote_code=True,
         )
 
+        # Set pad token if not exists
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         """
         data args:
@@ -21,20 +57,31 @@ class EndpointHandler():
         Return:
             A :obj:`list` | `dict`: will be serialized and returned
         """
-
         # Get the input text
         inputs = data.pop("inputs", data)
         parameters = data.pop("parameters", {})
 
+        # Handle string input directly
+        if isinstance(inputs, str):
+            input_text = inputs
+        else:
+            input_text = str(inputs)
+
         # Set default parameters
         max_new_tokens = parameters.get("max_new_tokens", 1000)
         temperature = parameters.get("temperature", 0.1)
         do_sample = parameters.get("do_sample", True)
         top_p = parameters.get("top_p", 0.9)
-        return_full_text = parameters.get("return_full_text", True)
+        return_full_text = parameters.get("return_full_text", False)
 
         # Tokenize the input
-        input_ids = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
+        input_ids = self.tokenizer(
+            input_text,
+            return_tensors="pt",
+            padding=True,
+            truncation=True,
+            max_length=2048
+        ).to(self.model.device)
 
         # Generate text
         with torch.no_grad():
@@ -44,7 +91,7 @@ class EndpointHandler():
             temperature=temperature,
             do_sample=do_sample,
             top_p=top_p,
-            pad_token_id=self.tokenizer.eos_token_id,
+            pad_token_id=self.tokenizer.pad_token_id,
             eos_token_id=self.tokenizer.eos_token_id,
         )
 
@@ -55,5 +102,5 @@ class EndpointHandler():
         # Only return the newly generated part
         new_tokens = generated_ids[0][input_ids["input_ids"].shape[1]:]
         generated_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
-
+
         return [{"generated_text": generated_text}]
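
For reference, a minimal local smoke test of the updated handler could look like the sketch below. It is not part of the commit: it assumes the script runs from the repository root, so that the checkpoint-100 folder the handler searches for is present, and the payload follows the inputs/parameters shape that __call__ already expects. Note that the handler reads return_full_text from parameters but, as committed, always returns only the newly generated tokens.

# Hypothetical smoke test, not part of this commit.
# Assumes checkpoint-100 sits next to handler.py in the working directory.
from handler import EndpointHandler

handler = EndpointHandler(path=".")

payload = {
    "inputs": "Explain what a pad token is in one sentence.",
    "parameters": {"max_new_tokens": 64, "temperature": 0.1, "top_p": 0.9},
}

result = handler(payload)           # list with a single dict
print(result[0]["generated_text"])  # only the newly generated text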