konieshadow commited on
Commit
089bc3b
·
1 Parent(s): 5a751c2

移除推理完成接口及相关逻辑,调整GPU持续时间为180秒,简化代码结构以提高可维护性。

Browse files
src/podcast_transcribe/llm/llm_router.py CHANGED
@@ -226,76 +226,6 @@ class LLMRouter:
226
  logger.error(f"使用provider '{provider}' 进行聊天完成失败: {str(e)}", exc_info=True)
227
  raise RuntimeError(f"聊天完成失败: {str(e)}")
228
 
229
- def reasoning_completion(
230
- self,
231
- messages: List[Dict[str, str]],
232
- provider: str = "gemma-transformers",
233
- temperature: float = 0.3,
234
- max_tokens: int = 2048,
235
- top_p: float = 0.9,
236
- model: Optional[str] = None,
237
- extract_reasoning_steps: bool = True,
238
- **kwargs
239
- ) -> Dict[str, Any]:
240
- """
241
- 专门用于推理任务的聊天完成接口
242
-
243
- 参数:
244
- messages: 消息列表,每个消息包含role和content
245
- provider: LLM提供者名称,默认使用gemma-transformers
246
- temperature: 温度参数(推理任务建议使用较低值)
247
- max_tokens: 最大生成token数
248
- top_p: nucleus采样参数
249
- model: 可选的模型名称
250
- extract_reasoning_steps: 是否提取推理步骤
251
- **kwargs: 其他参数
252
-
253
- 返回:
254
- 包含推理步骤的响应字典
255
- """
256
- logger.info(f"使用provider '{provider}' 进行推理完成,消息数量: {len(messages)}")
257
-
258
- # 确保使用支持推理的provider
259
- if provider not in ["gemma-transformers"]:
260
- logger.warning(f"Provider '{provider}' 可能不支持推理功能,建议使用 'gemma-transformers'")
261
-
262
- try:
263
- # 如果提供了model参数,添加到kwargs中
264
- if model is not None:
265
- kwargs["model_name"] = model
266
-
267
- # 获取或创建LLM实例
268
- llm_instance = self._get_or_create_instance(provider, **kwargs)
269
-
270
- # 检查实例是否支持推理完成
271
- if hasattr(llm_instance, 'reasoning_completion'):
272
- result = llm_instance.reasoning_completion(
273
- messages=messages,
274
- temperature=temperature,
275
- max_tokens=max_tokens,
276
- top_p=top_p,
277
- extract_reasoning_steps=extract_reasoning_steps,
278
- **kwargs
279
- )
280
- else:
281
- # 回退到普通聊天完成
282
- logger.warning(f"Provider '{provider}' 不支持推理完成,回退到普通聊天完成")
283
- result = llm_instance.create(
284
- messages=messages,
285
- temperature=temperature,
286
- max_tokens=max_tokens,
287
- top_p=top_p,
288
- model=model,
289
- **kwargs
290
- )
291
-
292
- logger.info(f"推理完成成功,使用tokens: {result.get('usage', {}).get('total_tokens', 'unknown')}")
293
- return result
294
-
295
- except Exception as e:
296
- logger.error(f"使用provider '{provider}' 进行推理完成失败: {str(e)}", exc_info=True)
297
- raise RuntimeError(f"推理完成失败: {str(e)}")
298
-
299
  def get_model_info(self, provider: str, **kwargs) -> Dict[str, Any]:
300
  """
301
  获取模型信息
@@ -356,7 +286,7 @@ class LLMRouter:
356
  # 创建全局路由器实例
357
  _router = LLMRouter()
358
 
359
- @spaces.GPU(duration=60)
360
  def chat_completion(
361
  messages: List[Dict[str, str]],
362
  provider: str = "gemma-transformers",
@@ -432,72 +362,6 @@ def chat_completion(
432
  **params
433
  )
434
 
435
- @spaces.GPU(duration=60)
436
- def reasoning_completion(
437
- messages: List[Dict[str, str]],
438
- provider: str = "gemma-transformers",
439
- temperature: float = 0.3,
440
- max_tokens: int = 2048,
441
- top_p: float = 0.9,
442
- model: Optional[str] = None,
443
- device: Optional[str] = None,
444
- device_map: Optional[str] = None,
445
- extract_reasoning_steps: bool = True,
446
- **kwargs
447
- ) -> Dict[str, Any]:
448
- """
449
- 专门用于推理任务的聊天完成接口函数
450
-
451
- 参数:
452
- messages: 消息列表,每个消息包含role和content字段
453
- provider: LLM提供者,默认使用gemma-transformers
454
- temperature: 温度参数(推理任务建议使用较低值)
455
- max_tokens: 最大生成token数
456
- top_p: nucleus采样参数
457
- model: 模型名称,如果不指定则使用默认模型
458
- device: 推理设备
459
- device_map: 设备映射配置
460
- extract_reasoning_steps: 是否提取推理步骤
461
- **kwargs: 其他参数
462
-
463
- 返回:
464
- 包含推理步骤的响应字典
465
-
466
- 示例:
467
- # 数学推理任务
468
- response = reasoning_completion(
469
- messages=[{"role": "user", "content": "解这个方程:3x + 7 = 22"}],
470
- provider="gemma-transformers",
471
- extract_reasoning_steps=True
472
- )
473
-
474
- # 逻辑推理任务
475
- response = reasoning_completion(
476
- messages=[{"role": "user", "content": "如果所有的猫都是动物,而小花是一只猫,那么小花是什么?"}],
477
- provider="gemma-transformers",
478
- temperature=0.2
479
- )
480
- """
481
- # 准备参数
482
- params = kwargs.copy()
483
- if model is not None:
484
- params["model_name"] = model
485
- if device is not None:
486
- params["device"] = device
487
- if device_map:
488
- params["device_map"] = device_map
489
-
490
- return _router.reasoning_completion(
491
- messages=messages,
492
- provider=provider,
493
- temperature=temperature,
494
- max_tokens=max_tokens,
495
- top_p=top_p,
496
- model=model,
497
- extract_reasoning_steps=extract_reasoning_steps,
498
- **params
499
- )
500
-
501
 
502
  def get_model_info(provider: str = "gemma-mlx", **kwargs) -> Dict[str, Any]:
503
  """
 
226
  logger.error(f"使用provider '{provider}' 进行聊天完成失败: {str(e)}", exc_info=True)
227
  raise RuntimeError(f"聊天完成失败: {str(e)}")
228
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  def get_model_info(self, provider: str, **kwargs) -> Dict[str, Any]:
230
  """
231
  获取模型信息
 
286
  # 创建全局路由器实例
287
  _router = LLMRouter()
288
 
289
+ @spaces.GPU(duration=180)
290
  def chat_completion(
291
  messages: List[Dict[str, str]],
292
  provider: str = "gemma-transformers",
 
362
  **params
363
  )
364
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
365
 
366
  def get_model_info(provider: str = "gemma-mlx", **kwargs) -> Dict[str, Any]:
367
  """