konieshadow committed
Commit ae20fe2 · 1 Parent(s): 089bc3b

Add support for disabling PyTorch compilation to work around compatibility issues in Gradio Spaces, and unify the logging configuration across several files.

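For reference, the effect this commit achieves by toggling torch._dynamo.config can also be had through the public compiler API on recent PyTorch (2.1+). A minimal sketch, not part of this commit; the generate_step helper is hypothetical:

import torch

# Public-API alternative to setting torch._dynamo.config.disable directly
# (torch.compiler.disable is available on PyTorch 2.1+). Functions wrapped
# this way always run eagerly, even if torch.compile is enabled elsewhere.
@torch.compiler.disable
def generate_step(model, inputs, **generation_config):
    with torch.no_grad():
        return model.generate(inputs, **generation_config)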
src/podcast_transcribe/llm/llm_base.py CHANGED
@@ -1,9 +1,29 @@
+"""
+LLM base class definitions
+Provides the abstract chat-completion base class and the Transformers implementation
+"""
+
+import logging
 import time
 import uuid
 import torch
 from typing import List, Dict, Optional, Union, Literal
 from abc import ABC, abstractmethod
+import os
+
+# Disable PyTorch compilation to avoid compatibility issues in Gradio Spaces
+os.environ["PYTORCH_DISABLE_DYNAMO"] = "1"
 
+# If torch._dynamo is available, disable it
+try:
+    import torch._dynamo
+    torch._dynamo.config.disable = True
+    torch._dynamo.config.suppress_errors = True
+except ImportError:
+    pass
+
+# Configure logging
+logger = logging.getLogger("llm")
 
 class BaseChatCompletion(ABC):
     """Base class for Gemma chat completions, containing shared functionality"""
@@ -308,6 +328,16 @@ class TransformersBaseChatCompletion(BaseChatCompletion):
         except ImportError:
             raise ImportError("Please install the transformers library first: pip install transformers")
 
+        # Make sure compilation stays disabled
+        os.environ["PYTORCH_DISABLE_DYNAMO"] = "1"
+        os.environ["TORCH_COMPILE_DISABLE"] = "1"
+        try:
+            import torch._dynamo
+            torch._dynamo.config.disable = True
+            torch._dynamo.config.suppress_errors = True
+        except (ImportError, AttributeError):
+            pass
+
         print(f"Loading model: {self.model_name}")
         print(f"Target device: {self.device}")
         print(f"Device map: {self.device_map}")
@@ -372,6 +402,13 @@ class TransformersBaseChatCompletion(BaseChatCompletion):
     ) -> str:
         """Generate a response using transformers"""
 
+        # Extra compilation-disabling measure to keep things working in Gradio Spaces
+        try:
+            import torch._dynamo
+            torch._dynamo.config.disable = True
+        except (ImportError, AttributeError):
+            pass
+
         # Encode the prompt
         inputs = self.tokenizer.encode(prompt_str, return_tensors="pt")
 
@@ -488,6 +525,29 @@ class TransformersBaseChatCompletion(BaseChatCompletion):
             print(f"Generation finished, output length: {len(generated_tokens)} tokens")
             return generated_text
 
+        except torch._dynamo.exc.BackendCompilerFailed as e:
+            print(f"PyTorch compiler error, retrying with compilation disabled: {e}")
+            # Force-disable compilation and retry
+            try:
+                torch._dynamo.reset()
+                torch._dynamo.config.disable = True
+                os.environ["PYTORCH_DISABLE_DYNAMO"] = "1"
+
+                with torch.no_grad():
+                    outputs = self.model.generate(
+                        inputs,
+                        **generation_config
+                    )
+
+                generated_tokens = outputs[0][len(inputs[0]):]
+                generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+                print(f"Generation finished after disabling compilation, output length: {len(generated_tokens)} tokens")
+                return generated_text
+
+            except Exception as retry_e:
+                print(f"Still failing after compilation was disabled: {retry_e}")
+                raise e
         except RuntimeError as e:
             if "CUDA error" in str(e):
                 print(f"CUDA error, trying CPU inference instead: {e}")
@@ -517,10 +577,42 @@ class TransformersBaseChatCompletion(BaseChatCompletion):
             else:
                 raise e
         except Exception as e:
-            print(f"Error while generating response: {e}")
-            import traceback
-            traceback.print_exc()
-            raise
+            # Handle other compiler-related errors
+            if "BackendCompilerFailed" in str(e) or "dynamo" in str(e).lower() or "inductor" in str(e).lower():
+                print(f"Compiler-related error detected, trying to disable compilation entirely: {e}")
+                try:
+                    # Force-disable all compilation features
+                    os.environ["PYTORCH_DISABLE_DYNAMO"] = "1"
+                    os.environ["TORCH_COMPILE_DISABLE"] = "1"
+
+                    # Reset the compilation state if possible
+                    try:
+                        torch._dynamo.reset()
+                        torch._dynamo.config.disable = True
+                        torch._dynamo.config.suppress_errors = True
+                    except Exception:
+                        pass
+
+                    with torch.no_grad():
+                        outputs = self.model.generate(
+                            inputs,
+                            **generation_config
+                        )
+
+                    generated_tokens = outputs[0][len(inputs[0]):]
+                    generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+                    print(f"Generation finished after fully disabling compilation, output length: {len(generated_tokens)} tokens")
+                    return generated_text
+
+                except Exception as final_e:
+                    print(f"All retries failed: {final_e}")
+                    raise e
+            else:
+                print(f"Error while generating response: {e}")
+                import traceback
+                traceback.print_exc()
+                raise
 
     def get_model_info(self) -> Dict[str, Union[str, bool, int]]:
         """Get model information"""
src/podcast_transcribe/llm/llm_gemma_transfomers.py CHANGED
@@ -1,6 +1,20 @@
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 from typing import List, Dict, Optional, Union, Literal
+import os
+
+# Disable PyTorch compilation to avoid compatibility issues in Gradio Spaces
+os.environ["PYTORCH_DISABLE_DYNAMO"] = "1"
+os.environ["TORCH_COMPILE_DISABLE"] = "1"
+
+# If torch._dynamo is available, disable it
+try:
+    import torch._dynamo
+    torch._dynamo.config.disable = True
+    torch._dynamo.config.suppress_errors = True
+except ImportError:
+    pass
+
 from .llm_base import TransformersBaseChatCompletion
 
 
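Because these are plain environment variables, they can also be exported once at the process entry point before anything imports torch, instead of being repeated per module. A minimal sketch; the app.py entry point and the import path are assumptions:

# app.py — hypothetical Gradio Spaces entry point: set the flags before
# any module imports torch, so every downstream import sees them.
import os

os.environ["PYTORCH_DISABLE_DYNAMO"] = "1"
os.environ["TORCH_COMPILE_DISABLE"] = "1"

from podcast_transcribe.llm import llm_router  # torch gets imported here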
src/podcast_transcribe/llm/llm_router.py CHANGED
@@ -6,6 +6,19 @@ LLM model invocation router
 import logging
 import torch
 from typing import Dict, Any, Optional, List, Union
+import os
+
+# Disable PyTorch compilation to avoid compatibility issues in Gradio Spaces
+os.environ["PYTORCH_DISABLE_DYNAMO"] = "1"
+os.environ["TORCH_COMPILE_DISABLE"] = "1"
+
+# If torch._dynamo is available, disable it
+try:
+    import torch._dynamo
+    torch._dynamo.config.disable = True
+    torch._dynamo.config.suppress_errors = True
+except ImportError:
+    pass
 
 import spaces
 from .llm_base import BaseChatCompletion
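The commit message also mentions unifying logging: llm_base.py now obtains a module logger via logging.getLogger("llm"). A minimal sketch of how an application could configure that shared logger; the format and level shown are assumptions:

import logging

# Handlers attached to the "llm" logger are seen by every module that
# logs through logging.getLogger("llm"), including llm_base.py.
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s"))

llm_logger = logging.getLogger("llm")
llm_logger.addHandler(handler)
llm_logger.setLevel(logging.INFO)

llm_logger.info("llm logging configured")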