import sys
from pathlib import Path

import torch

# Add the project root directory to the Python path so the src package can be imported
sys.path.insert(0, str(Path(__file__).parent.parent))

from src.podcast_transcribe.llm.llm_gemma_mlx import GemmaMLXChatCompletion
from src.podcast_transcribe.llm.llm_gemma_transfomers import GemmaTransformersChatCompletion


if __name__ == "__main__":
    # Example usage:
    print("Initializing LLM chat completion...")
    try:
        # model_name = "mlx-community/gemma-3-12b-it-4bit-DWQ"
        model_name = "google/gemma-3-4b-it"
        device = "cuda"
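        # Note: device is only passed to the Transformers backend below; the comment
        # further down also mentions mps, so adjust this to your hardware.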

        # gemma_chat = GemmaMLXChatCompletion(model_name="mlx-community/gemma-3-12b-it-4bit-DWQ")
        # Alternatively, if you have a smaller, faster model, you can try it, e.g.: "mlx-community/gemma-2b-it-8bit"
        if model_name.startswith("mlx-community"):
            gemma_chat = GemmaMLXChatCompletion(model_name=model_name)
        else:
            # If the device is mps or cuda, use float32 for better stability
            gemma_chat = GemmaTransformersChatCompletion(model_name=model_name, device=device)

        print("\n--- 示例 1: 简单用户查询 ---")
        messages_example1 = [
            {"role": "user", "content": "你好,你是谁?"}
        ]
        response1 = gemma_chat.create(messages=messages_example1, max_tokens=50)
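        # create() returns an OpenAI-style completion dict: the reply text lives under
        # choices[0]['message']['content'] and token counts under 'usage'.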
        print("响应 1:")
        print(f"  助手: {response1['choices'][0]['message']['content']}")
        print(f"  用量: {response1['usage']}")

        print("\n--- 示例 2: 带历史记录的对话 ---")
        messages_example2 = [
            {"role": "user", "content": "法国的首都是哪里?"},
            {"role": "assistant", "content": "法国的首都是巴黎。"},
            {"role": "user", "content": "你能告诉我一个关于它的有趣事实吗?"}
        ]
        response2 = gemma_chat.create(messages=messages_example2, max_tokens=100, temperature=0.8)
        print("响应 2:")
        print(f"  助手: {response2['choices'][0]['message']['content']}")
        print(f"  用量: {response2['usage']}")
        
        print("\n--- 示例 3: 系统提示 (实验性,效果取决于模型微调) ---")
        messages_example3 = [
            {"role": "system", "content": "你是一位富有诗意的助手,擅长用富有创意的方式解释复杂的编程概念。"},
            {"role": "user", "content": "解释一下编程中递归的概念。"}
        ]
        response3 = gemma_chat.create(messages=messages_example3, max_tokens=150)
        print("响应 3:")
        print(f"  助手: {response3['choices'][0]['message']['content']}")
        print(f"  用量: {response3['usage']}")

        print("\n--- 示例 4: 使用 max_tokens 强制缩短响应 ---")
        messages_example4 = [
            {"role": "user", "content": "给我讲一个关于勇敢骑士的很长的故事。"}
        ]
        response4 = gemma_chat.create(messages=messages_example4, max_tokens=20) # 非常短
        print("响应 4:")
        print(f"  助手: {response4['choices'][0]['message']['content']}")
        print(f"  用量: {response4['usage']}")
        if response4['usage']['completion_tokens'] >= 20:
             print("  注意:由于 max_tokens,补全可能已被截断。")


    except Exception as e:
        print(f"示例用法期间发生错误: {e}")
        import traceback
        traceback.print_exc()