File size: 7,522 Bytes
7460c5d
 
 
 
ec9daa1
 
 
 
7460c5d
 
ec9daa1
7460c5d
 
ec9daa1
 
 
7460c5d
 
 
 
 
ec9daa1
7460c5d
 
 
 
 
 
ec9daa1
7460c5d
ec9daa1
 
 
7460c5d
ec9daa1
7460c5d
 
 
 
 
 
ec9daa1
7460c5d
 
ec9daa1
 
7460c5d
 
ec9daa1
7460c5d
ec9daa1
7460c5d
ec9daa1
7460c5d
ec9daa1
 
7460c5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec9daa1
7460c5d
 
ec9daa1
 
7460c5d
ec9daa1
 
7460c5d
 
 
 
 
 
 
 
 
 
ec9daa1
7460c5d
 
ec9daa1
7460c5d
ec9daa1
 
 
 
 
 
 
7460c5d
 
ec9daa1
7460c5d
ec9daa1
7460c5d
ec9daa1
7460c5d
ec9daa1
7460c5d
 
 
 
 
ec9daa1
7460c5d
 
 
ec9daa1
 
 
 
7460c5d
ec9daa1
 
7460c5d
 
 
 
 
 
 
ec9daa1
7460c5d
 
ec9daa1
 
7460c5d
 
ae39785
 
 
7460c5d
ae39785
7460c5d
ae39785
 
 
 
7460c5d
ae39785
7460c5d
 
ae39785
ec9daa1
 
7460c5d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# ────────────────────────────────────────────────────────────────────────────────
# app.py (CPU-only 版:先加载 float32 基座 LLaMA-8B,再叠入 LoRA Adapter)
# ────────────────────────────────────────────────────────────────────────────────

import gradio as gr
import torch
import gc
import os
from transformers import AutoTokenizer, LlamaForCausalLM
from peft import PeftModel

# ─────────────────────── 1. 释放可能的显存/内存 ───────────────────────
# 对于 CPU-only,可以留着,也不会报错
torch.cuda.empty_cache()
gc.collect()

# ─────────────────────── 2. 配置区域 ───────────────────────

# (A)Adapter 仓库 ID:LoRA 权重所在的 Hugging Face Repo
#     这个仓库里只有 adapter_model.safetensors + adapter_config.json + tokenizer 文件
ADAPTER_REPO = "yxccai/text-style-converter"

# (B)基座模型 ID(去掉了 -bnb-4bit 后缀,改用 float32 版)
#     原 adapter_config.json 里提到的 "unsloth/deepseek-r1-distill-llama-8b-unsloth-bnb-4bit"
#     在 CPU-only 环境下不能加载 4bit bitsandbytes,所以我们要改为:
#     "unsloth/deepseek-r1-distill-llama-8b"
#     如果您本地没有这个仓库,可以换成“decapoda-research/llama-7b-hf”或其他您能在 CPU 上跑通的模型。
BASE_MODEL_ID = "unsloth/deepseek-r1-distill-llama-8b"

# 全局变量:Tokenizer + Model
tokenizer = None
model = None

# ─────────────────────── 3. 加载模型的函数 ───────────────────────
def load_model():
    """
    CPU-only 逻辑: 
    1. 先从 Adapter 仓库加载 Tokenizer(里面有 tokenizer.json 等文件)。
    2. 再用 LlamaForCausalLM 从 float32 版基座模型加载到 CPU。
    3. 然后用 PeftModel.from_pretrained(...) 将 LoRA Adapter 权重叠加到基座上。
    """
    global tokenizer, model

    # 如果 tokenizer/model 还未加载,则执行加载逻辑
    if tokenizer is None or model is None:
        try:
            # ── 3.1 加载 Tokenizer ──
            print("正在加载 Tokenizer(来自 LoRA 仓库)…")
            tokenizer = AutoTokenizer.from_pretrained(
                ADAPTER_REPO,
                trust_remote_code=True,
                use_fast=False,
            )
            # 如果 pad_token 不存在,就用 eos_token 代替
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            # ── 3.2 加载基座模型(LLaMA float32 → CPU) ──
            print(f"正在加载基座模型:{BASE_MODEL_ID} (float32 → CPU)…")
            # 注意:这里用 torch_dtype=torch.float32, device_map="cpu"。如果 Model 太大、内存不足,会 OOM。
            base_model = LlamaForCausalLM.from_pretrained(
                BASE_MODEL_ID,
                torch_dtype=torch.float32,
                device_map="cpu",
                low_cpu_mem_usage=True,       # 尽量启用低内存占用模式
                trust_remote_code=True,
            )
            print("→ 基座模型加载完成。(注意检查是否被系统 OOM)")

            # ── 3.3 用 PeftModel 叠加 LoRA Adapter ──
            print(f"正在叠加 LoRA Adapter:{ADAPTER_REPO}…")
            model = PeftModel.from_pretrained(
                base_model,
                ADAPTER_REPO,
                device_map="cpu",            # CPU-only 环境
                torch_dtype=torch.float32,   # 同样使用 float32
            )
            print("→ LoRA Adapter 已叠加成功。")

            # (可选)不想更新基座所有参数时,把 base_model 的参数都冻结:
            # model.eval()
            # for param in model.base_model.parameters():
            #     param.requires_grad = False

        except Exception as e:
            import traceback
            traceback.print_exc()
            print(f"模型加载失败: {str(e)}")
            return False

    return True


# ─────────────────────── 4. 文本生成函数 ───────────────────────
def convert_text_style(input_text: str) -> str:
    """
    输入一句书面化/技术性的中文,让模型把它转换成自然、口语化的表达方式。
    """
    if not input_text or input_text.strip() == "":
        return "请输入要转换的文本。"

    # 确保模型已加载
    if not load_model():
        return "模型加载失败,请稍后重试。"

    try:
        # 拼一个简单的 Prompt
        prompt = f"""以下是一个文本风格转换任务,请将书面化、技术性的输入文本转换为自然、口语化的表达方式。

### 输入文本:
{input_text}

### 输出文本:
"""

        # 分词 & 转 torch.Tensor
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            max_length=1024,
            truncation=True,
            padding=True,
        )
        # 全部放到 CPU 上
        inputs = {k: v.to("cpu") for k, v in inputs.items()}

        # 生成
        with torch.no_grad():
            outputs = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=256,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                no_repeat_ngram_size=2,
                num_return_sequences=1,
            )

        # 解码并抽取结果
        full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        if "### 输出文本:" in full_text:
            return full_text.split("### 输出文本:")[-1].strip()
        return full_text[len(prompt) :].strip()

    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"生成过程中出现错误: {str(e)}"


# ─────────────────────── 5. Gradio 界面配置 ───────────────────────
iface = gr.Interface(
    fn=convert_text_style,
    inputs=gr.Textbox(
        label="输入文本", placeholder="请输入需要转换为口语化的书面文本...", lines=3
    ),
    outputs=gr.Textbox(label="输出文本", lines=4),
    title="中文文本风格转换API",
    description="将书面化、技术性文本转换为自然、口语化表达",
    examples=[
        ["乙醇的检测方法包括酸碱度检查。"],
        ["本品为薄膜衣片,除去包衣后显橙红色。"],
    ],
    cache_examples=False,
    flagging_mode="never",
)

if __name__ == "__main__":
    print("启动 Gradio 应用…")
    # 纯 CPU 环境下,server_name 可以保持默认 "0.0.0.0",port 也是 7860
    iface.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=False)