Fas1 commited on
Commit
63c5e51
·
verified ·
1 Parent(s): 244294e
Files changed (1) hide show
  1. app.py +11 -8
app.py CHANGED
@@ -76,38 +76,41 @@ pipe = pipeline(
76
  tokenizer=tokenizer,
77
  )
78
 
79
- # Функция классификации (устойчивая к пустому вводу и разным форматам ответа)
80
- # Функция классификации (ручной вызов generate, без pipeline)
81
  def classify(text):
82
  if not text or not str(text).strip():
83
  return "⚠️ Пустой ввод. Введите сообщение."
84
 
85
  prompt = f"### Вопрос:\n{text}\n\n### Класс:"
86
  try:
87
- # Токенизируем вручную и генерируем напрямую через модель
88
  enc = tokenizer(
89
  prompt,
90
  return_tensors="pt",
91
  padding=True,
92
  truncation=True,
 
93
  )
94
- # Явно указываем pad/eos из токенизатора
 
 
 
 
95
  gen_kwargs = dict(
96
  max_new_tokens=16,
97
  do_sample=False,
98
  pad_token_id=tokenizer.pad_token_id,
99
  eos_token_id=tokenizer.eos_token_id,
 
100
  )
101
  with torch.no_grad():
102
- out = model.generate(**enc, **gen_kwargs)
103
- # Отрезаем подсказку и берём только продолжение
104
- gen_only = out[:, enc["input_ids"].shape[1]:]
105
  generated = tokenizer.decode(gen_only[0], skip_special_tokens=True)
106
  label = (generated.strip().split()[0].lower() if generated.strip() else "unknown")
107
  return f"🔍 Класс: **{label}**"
108
  except Exception as e:
109
  import traceback
110
- tb = traceback.format_exc(limit=3)
111
  return f"❌ Ошибка: {str(e)}\n\n<details><summary>trace</summary>\n\n{tb}\n\n</details>"
112
 
113
  # Интерфейс Gradio
 
76
  tokenizer=tokenizer,
77
  )
78
 
79
+ # Функция классификации (ручной вызов generate, явная передача тензоров)
 
80
  def classify(text):
81
  if not text or not str(text).strip():
82
  return "⚠️ Пустой ввод. Введите сообщение."
83
 
84
  prompt = f"### Вопрос:\n{text}\n\n### Класс:"
85
  try:
 
86
  enc = tokenizer(
87
  prompt,
88
  return_tensors="pt",
89
  padding=True,
90
  truncation=True,
91
+ max_length=min(2048, getattr(tokenizer, "model_max_length", 2048) or 2048),
92
  )
93
+ input_ids = enc["input_ids"]
94
+ attention_mask = enc.get("attention_mask")
95
+ if attention_mask is None:
96
+ attention_mask = torch.ones_like(input_ids)
97
+
98
  gen_kwargs = dict(
99
  max_new_tokens=16,
100
  do_sample=False,
101
  pad_token_id=tokenizer.pad_token_id,
102
  eos_token_id=tokenizer.eos_token_id,
103
+ use_cache=True,
104
  )
105
  with torch.no_grad():
106
+ out = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
107
+ gen_only = out[:, input_ids.shape[1]:]
 
108
  generated = tokenizer.decode(gen_only[0], skip_special_tokens=True)
109
  label = (generated.strip().split()[0].lower() if generated.strip() else "unknown")
110
  return f"🔍 Класс: **{label}**"
111
  except Exception as e:
112
  import traceback
113
+ tb = traceback.format_exc(limit=5)
114
  return f"❌ Ошибка: {str(e)}\n\n<details><summary>trace</summary>\n\n{tb}\n\n</details>"
115
 
116
  # Интерфейс Gradio