bwilkie committed
Commit 55b792c · verified · 1 Parent(s): e4f7d1f

Update myagent.py

Files changed (1)
  1. myagent.py +37 -20
myagent.py CHANGED
@@ -61,26 +61,43 @@ class LocalLlamaModel:
         self.device = model.device if hasattr(model, 'device') else 'cpu'
 
     def generate(self, prompt: str, max_new_tokens=512, **kwargs):
-        # Generate answer using the provided prompt
-        input_ids = self.tokenizer.apply_chat_template(
-            [{"role": "user", "content": prompt}],
-            add_generation_prompt=True,
-            return_tensors="pt",
-            tokenize=True,
-        ).to(self.model.device)
-
-        output = self.model.generate(
-            input_ids,
-            do_sample=True,
-            temperature=0.3,
-            min_p=0.15,
-            repetition_penalty=1.05,
-            max_new_tokens=max_new_tokens,
-        )
-
-        output = self.tokenizer.decode(output[0], skip_special_tokens=False)
-        return output
-
+        try:
+            # Generate answer using the provided prompt - following the recommended pattern
+            input_ids = self.tokenizer.apply_chat_template(
+                [{"role": "user", "content": str(prompt)}],
+                add_generation_prompt=True,
+                return_tensors="pt",
+                tokenize=True,
+            ).to(self.model.device)
+
+            # Generate output - exactly as in recommended code
+            output = self.model.generate(
+                input_ids,
+                do_sample=True,
+                temperature=0.3,
+                min_p=0.15,
+                repetition_penalty=1.05,
+                max_new_tokens=max_new_tokens,
+            )
+
+            # Decode the full output - as in recommended code
+            decoded_output = self.tokenizer.decode(output[0], skip_special_tokens=False)
+
+            # Extract only the assistant's response (after the last <|im_start|>assistant)
+            if "<|im_start|>assistant" in decoded_output:
+                assistant_response = decoded_output.split("<|im_start|>assistant")[-1]
+                # Remove any trailing special tokens
+                assistant_response = assistant_response.replace("<|im_end|>", "").strip()
+                return assistant_response
+            else:
+                # Fallback: return the full decoded output
+                return decoded_output
+
+        except Exception as e:
+            print(f"Error in model generation: {e}")
+            return f"Error generating response: {str(e)}"
+
+
     def __call__(self, prompt: str, max_new_tokens=512, **kwargs):
         """Make the model callable like a function"""
        return self.generate(prompt, max_new_tokens, **kwargs)
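
Note: the committed code isolates the assistant turn by string-splitting the fully decoded output on the "<|im_start|>assistant" marker. A minimal alternative sketch (not part of this commit; it assumes the same input_ids, output, and self.tokenizer objects as in the diff above) is to slice off the prompt tokens before decoding, which avoids depending on a specific chat-template marker:

    # Sketch only (hypothetical alternative, not in the commit): decode just the
    # tokens generated after the prompt and let the tokenizer drop special tokens.
    generated_tokens = output[0][input_ids.shape[-1]:]
    assistant_response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()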