FelixPhilip committed
Commit ba26d2b · Parent: 9c2accc

Oracle weight assigning update

Files changed (1): Oracle/SmolLM.py (+16 -7)
Oracle/SmolLM.py CHANGED
@@ -1,29 +1,38 @@
+import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
 class SmolLM:
     def __init__(self, model_path="HuggingFaceTB/SmolLM2-1.7B-Instruct"):
         self.available = True
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
         try:
             print(f"[INFO] Loading Oracle tokenizer from {model_path}")
             self.tokenizer = AutoTokenizer.from_pretrained(model_path)
-            print(f"[INFO] Loading Oracle from {model_path}")
-            self.model = AutoModelForCausalLM.from_pretrained(model_path)
+            print(f"[INFO] Loading Oracle from {model_path} on {self.device}")
+            self.model = AutoModelForCausalLM.from_pretrained(model_path).to(self.device)
             print("[INFO] Oracle loaded successfully")
         except Exception as e:
             print(f"[ERROR] Failed to load model '{model_path}': {e}")
             self.available = False
 
-    def predict(self, prompt,max_length=512,max_new_tokens=150):
+    def predict(self, prompt, max_length=512, max_new_tokens=150):
         if not self.available:
             print("[WARN] Oracle unavailable, returning default weight 0.5")
             return "0.5"
         try:
-            print(f"[INFO] Generating response for prompt: {prompt[:100]}...", flush=True)
-            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_length)
-            outputs = self.model.generate(**inputs, max_length=inputs["input_ids"].shape[1]+max_new_tokens,num_return_sequences=1)
+            # Use chat template as per documentation
+            messages = [{"role": "user", "content": prompt}]
+            inputs = self.tokenizer.apply_chat_template(messages, return_tensors="pt").to(self.device)
+            outputs = self.model.generate(
+                inputs,
+                max_new_tokens=max_new_tokens,
+                temperature=0.7,
+                top_p=0.9,
+                do_sample=True
+            )
             response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
             print(f"[INFO] Generated response: {response[:100]}...", flush=True)
             return response
         except Exception as e:
             print(f"[ERROR] Oracle has failed: {e}")
-            return "0.5"
+            return "0.5"
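
Two caveats about the new generation path are worth noting. First, apply_chat_template is called without add_generation_prompt=True; for instruct checkpoints such as SmolLM2 this flag appends the assistant header tokens so the model answers the user turn rather than continuing it. Second, decoding outputs[0] in full echoes the chat-formatted prompt at the start of the response, so any downstream parsing of the weight has to skip past it. (The max_length parameter also survives in the new signature but is no longer used, since the chat-template path does no truncation.) Below is a minimal sketch of both fixes, not part of the commit; the example prompt is illustrative and the surrounding variable names are assumptions, not code from this repo:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_path = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path).to(device)

# Illustrative prompt; the real Oracle prompt is built elsewhere in the repo.
messages = [{"role": "user", "content": "Assign this contribution a weight between 0 and 1."}]

# add_generation_prompt=True appends the assistant header so the model
# generates a reply instead of extending the user message.
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(device)

outputs = model.generate(
    inputs,
    max_new_tokens=150,
    temperature=0.7,
    top_p=0.9,
    do_sample=True,
)

# Slice off the prompt tokens so only the newly generated text is decoded.
response = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
print(response)

Whether sampling (temperature=0.7, top_p=0.9) is the right choice for a numeric weight is debatable; greedy decoding (do_sample=False) would make the Oracle's scores reproducible across runs.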