Commit ae20fe2
Parent(s): 089bc3b
Add support for disabling PyTorch compilation to resolve compatibility issues in Gradio Spaces, and configure logging consistently across multiple files.
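All three touched modules apply the same guard at import time, before any model code runs. As a standalone illustration, the pattern boils down to the sketch below (the `torch._dynamo` config flags are real PyTorch 2.x settings; the environment variable names are taken verbatim from this commit):

    import os

    # Set before anything triggers compilation (env var names as used in this commit)
    os.environ["PYTORCH_DISABLE_DYNAMO"] = "1"
    os.environ["TORCH_COMPILE_DISABLE"] = "1"

    try:
        import torch._dynamo
        torch._dynamo.config.disable = True          # skip dynamo graph capture
        torch._dynamo.config.suppress_errors = True  # fall back to eager on dynamo errors
    except ImportError:
        pass  # older PyTorch builds without dynamo: nothing to disable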
src/podcast_transcribe/llm/llm_base.py
CHANGED
@@ -1,9 +1,29 @@
+"""
+LLM base class definitions
+Provides the abstract base class and the Transformers implementation for chat completion
+"""
+
+import logging
 import time
 import uuid
 import torch
 from typing import List, Dict, Optional, Union, Literal
 from abc import ABC, abstractmethod
+import os
+
+# Disable PyTorch compilation to avoid compatibility issues in Gradio Spaces
+os.environ["PYTORCH_DISABLE_DYNAMO"] = "1"
 
+# If torch._dynamo is available, disable it
+try:
+    import torch._dynamo
+    torch._dynamo.config.disable = True
+    torch._dynamo.config.suppress_errors = True
+except ImportError:
+    pass
+
+# Configure logging
+logger = logging.getLogger("llm")
 
 class BaseChatCompletion(ABC):
     """Base class for Gemma chat completions, containing shared functionality"""
@@ -308,6 +328,16 @@ class TransformersBaseChatCompletion(BaseChatCompletion):
         except ImportError:
             raise ImportError("Please install the transformers library first: pip install transformers")
 
+        # Make sure compilation is disabled
+        os.environ["PYTORCH_DISABLE_DYNAMO"] = "1"
+        os.environ["TORCH_COMPILE_DISABLE"] = "1"
+        try:
+            import torch._dynamo
+            torch._dynamo.config.disable = True
+            torch._dynamo.config.suppress_errors = True
+        except (ImportError, AttributeError):
+            pass
+
         print(f"Loading model: {self.model_name}")
         print(f"Target device: {self.device}")
         print(f"Device map: {self.device_map}")
@@ -372,6 +402,13 @@ class TransformersBaseChatCompletion(BaseChatCompletion):
     ) -> str:
         """Generate a response using transformers"""
 
+        # Extra compilation-disabling measures to make sure this works in Gradio Spaces
+        try:
+            import torch._dynamo
+            torch._dynamo.config.disable = True
+        except (ImportError, AttributeError):
+            pass
+
         # Encode the prompt
         inputs = self.tokenizer.encode(prompt_str, return_tensors="pt")
 
@@ -488,6 +525,29 @@ class TransformersBaseChatCompletion(BaseChatCompletion):
             print(f"Generation finished, output length: {len(generated_tokens)} tokens")
             return generated_text
 
+        except torch._dynamo.exc.BackendCompilerFailed as e:
+            print(f"PyTorch compiler error, disabling compilation and retrying: {e}")
+            # Force-disable compilation and retry
+            try:
+                torch._dynamo.reset()
+                torch._dynamo.config.disable = True
+                os.environ["PYTORCH_DISABLE_DYNAMO"] = "1"
+
+                with torch.no_grad():
+                    outputs = self.model.generate(
+                        inputs,
+                        **generation_config
+                    )
+
+                generated_tokens = outputs[0][len(inputs[0]):]
+                generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+                print(f"Generation finished after disabling compilation, output length: {len(generated_tokens)} tokens")
+                return generated_text
+
+            except Exception as retry_e:
+                print(f"Still failed after disabling compilation: {retry_e}")
+                raise e
         except RuntimeError as e:
             if "CUDA error" in str(e):
                 print(f"CUDA error, trying CPU inference instead: {e}")
@@ -517,10 +577,42 @@ class TransformersBaseChatCompletion(BaseChatCompletion):
             else:
                 raise e
         except Exception as e:
-            print(f"Error while generating response: {e}")
-            import traceback
-            traceback.print_exc()
-            raise
+            # Handle other compiler-related errors
+            if "BackendCompilerFailed" in str(e) or "dynamo" in str(e).lower() or "inductor" in str(e).lower():
+                print(f"Compiler-related error detected, trying to disable compilation entirely: {e}")
+                try:
+                    # Force-disable all compilation features
+                    os.environ["PYTORCH_DISABLE_DYNAMO"] = "1"
+                    os.environ["TORCH_COMPILE_DISABLE"] = "1"
+
+                    # Reset the compilation state if possible
+                    try:
+                        torch._dynamo.reset()
+                        torch._dynamo.config.disable = True
+                        torch._dynamo.config.suppress_errors = True
+                    except:
+                        pass
+
+                    with torch.no_grad():
+                        outputs = self.model.generate(
+                            inputs,
+                            **generation_config
+                        )
+
+                    generated_tokens = outputs[0][len(inputs[0]):]
+                    generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+                    print(f"Generation finished with compilation fully disabled, output length: {len(generated_tokens)} tokens")
+                    return generated_text
+
+                except Exception as final_e:
+                    print(f"All retries failed: {final_e}")
+                    raise e
+            else:
+                print(f"Error while generating response: {e}")
+                import traceback
+                traceback.print_exc()
+                raise
 
     def get_model_info(self) -> Dict[str, Union[str, bool, int]]:
         """Get model information"""
src/podcast_transcribe/llm/llm_gemma_transfomers.py
CHANGED
@@ -1,6 +1,20 @@
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 from typing import List, Dict, Optional, Union, Literal
+import os
+
+# Disable PyTorch compilation to avoid compatibility issues in Gradio Spaces
+os.environ["PYTORCH_DISABLE_DYNAMO"] = "1"
+os.environ["TORCH_COMPILE_DISABLE"] = "1"
+
+# If torch._dynamo is available, disable it
+try:
+    import torch._dynamo
+    torch._dynamo.config.disable = True
+    torch._dynamo.config.suppress_errors = True
+except ImportError:
+    pass
+
 from .llm_base import TransformersBaseChatCompletion
 
 
src/podcast_transcribe/llm/llm_router.py
CHANGED
@@ -6,6 +6,19 @@ LLM model invocation router
 import logging
 import torch
 from typing import Dict, Any, Optional, List, Union
+import os
+
+# Disable PyTorch compilation to avoid compatibility issues in Gradio Spaces
+os.environ["PYTORCH_DISABLE_DYNAMO"] = "1"
+os.environ["TORCH_COMPILE_DISABLE"] = "1"
+
+# If torch._dynamo is available, disable it
+try:
+    import torch._dynamo
+    torch._dynamo.config.disable = True
+    torch._dynamo.config.suppress_errors = True
+except ImportError:
+    pass
 
 import spaces
 from .llm_base import BaseChatCompletion