import torch

# Add the project root directory to the Python path
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

from src.podcast_transcribe.llm.llm_gemma_mlx import GemmaMLXChatCompletion
from src.podcast_transcribe.llm.llm_gemma_transfomers import GemmaTransformersChatCompletion


if __name__ == "__main__":
    # Example usage:
    print("Initializing LLM chat completion...")
    try:
        # model_name = "mlx-community/gemma-3-12b-it-4bit-DWQ"
        model_name = "google/gemma-3-4b-it"
        device = "cuda"

        # gemma_chat = GemmaMLXChatCompletion(model_name="mlx-community/gemma-3-12b-it-4bit-DWQ")
        # Alternatively, if you have a smaller, faster model, you can try it, e.g. "mlx-community/gemma-2b-it-8bit"
        if model_name.startswith("mlx-community"):
            gemma_chat = GemmaMLXChatCompletion(model_name=model_name)
        else:
            # If the device is mps or cuda, use float32 for better stability
            gemma_chat = GemmaTransformersChatCompletion(model_name=model_name, device=device)

        print("\n--- Example 1: Simple user query ---")
        messages_example1 = [
            {"role": "user", "content": "Hello, who are you?"}
        ]
        response1 = gemma_chat.create(messages=messages_example1, max_tokens=50)
        print("Response 1:")
        print(f"  Assistant: {response1['choices'][0]['message']['content']}")
        print(f"  Usage: {response1['usage']}")

        print("\n--- Example 2: Conversation with history ---")
        messages_example2 = [
            {"role": "user", "content": "What is the capital of France?"},
            {"role": "assistant", "content": "The capital of France is Paris."},
            {"role": "user", "content": "Can you tell me an interesting fact about it?"}
        ]
        response2 = gemma_chat.create(messages=messages_example2, max_tokens=100, temperature=0.8)
        print("Response 2:")
        print(f"  Assistant: {response2['choices'][0]['message']['content']}")
        print(f"  Usage: {response2['usage']}")

        print("\n--- Example 3: System prompt (experimental, effect depends on the model's fine-tuning) ---")
        messages_example3 = [
            {"role": "system", "content": "You are a poetic assistant, skilled at explaining complex programming concepts in creative ways."},
            {"role": "user", "content": "Explain the concept of recursion in programming."}
        ]
        response3 = gemma_chat.create(messages=messages_example3, max_tokens=150)
        print("Response 3:")
        print(f"  Assistant: {response3['choices'][0]['message']['content']}")
        print(f"  Usage: {response3['usage']}")

        print("\n--- Example 4: Forcing a short response with max_tokens ---")
        messages_example4 = [
            {"role": "user", "content": "Tell me a very long story about a brave knight."}
        ]
        response4 = gemma_chat.create(messages=messages_example4, max_tokens=20)  # very short
        print("Response 4:")
        print(f"  Assistant: {response4['choices'][0]['message']['content']}")
        print(f"  Usage: {response4['usage']}")
        if response4['usage']['completion_tokens'] >= 20:
            print("  Note: the completion may have been truncated due to max_tokens.")

    except Exception as e:
        print(f"An error occurred during example usage: {e}")
        import traceback
        traceback.print_exc()