Spaces:
Running
Running
# ──────────────────────────────────────────────────────────────────────────────── | |
# app.py (CPU-only 版:先加载 float32 基座 LLaMA-8B,再叠入 LoRA Adapter) | |
# ──────────────────────────────────────────────────────────────────────────────── | |
import gradio as gr | |
import torch | |
import gc | |
import os | |
from transformers import AutoTokenizer, LlamaForCausalLM | |
from peft import PeftModel | |
# ─────────────────────── 1. 释放可能的显存/内存 ─────────────────────── | |
# 对于 CPU-only,可以留着,也不会报错 | |
torch.cuda.empty_cache() | |
gc.collect() | |
# ─────────────────────── 2. 配置区域 ─────────────────────── | |
# (A)Adapter 仓库 ID:LoRA 权重所在的 Hugging Face Repo | |
# 这个仓库里只有 adapter_model.safetensors + adapter_config.json + tokenizer 文件 | |
ADAPTER_REPO = "yxccai/text-style-converter" | |
# (B)基座模型 ID(去掉了 -bnb-4bit 后缀,改用 float32 版) | |
# 原 adapter_config.json 里提到的 "unsloth/deepseek-r1-distill-llama-8b-unsloth-bnb-4bit" | |
# 在 CPU-only 环境下不能加载 4bit bitsandbytes,所以我们要改为: | |
# "unsloth/deepseek-r1-distill-llama-8b" | |
# 如果您本地没有这个仓库,可以换成“decapoda-research/llama-7b-hf”或其他您能在 CPU 上跑通的模型。 | |
BASE_MODEL_ID = "unsloth/deepseek-r1-distill-llama-8b" | |
# 全局变量:Tokenizer + Model | |
tokenizer = None | |
model = None | |
# ─────────────────────── 3. 加载模型的函数 ─────────────────────── | |
def load_model(): | |
""" | |
CPU-only 逻辑: | |
1. 先从 Adapter 仓库加载 Tokenizer(里面有 tokenizer.json 等文件)。 | |
2. 再用 LlamaForCausalLM 从 float32 版基座模型加载到 CPU。 | |
3. 然后用 PeftModel.from_pretrained(...) 将 LoRA Adapter 权重叠加到基座上。 | |
""" | |
global tokenizer, model | |
# 如果 tokenizer/model 还未加载,则执行加载逻辑 | |
if tokenizer is None or model is None: | |
try: | |
# ── 3.1 加载 Tokenizer ── | |
print("正在加载 Tokenizer(来自 LoRA 仓库)…") | |
tokenizer = AutoTokenizer.from_pretrained( | |
ADAPTER_REPO, | |
trust_remote_code=True, | |
use_fast=False, | |
) | |
# 如果 pad_token 不存在,就用 eos_token 代替 | |
if tokenizer.pad_token is None: | |
tokenizer.pad_token = tokenizer.eos_token | |
# ── 3.2 加载基座模型(LLaMA float32 → CPU) ── | |
print(f"正在加载基座模型:{BASE_MODEL_ID} (float32 → CPU)…") | |
# 注意:这里用 torch_dtype=torch.float32, device_map="cpu"。如果 Model 太大、内存不足,会 OOM。 | |
base_model = LlamaForCausalLM.from_pretrained( | |
BASE_MODEL_ID, | |
torch_dtype=torch.float32, | |
device_map="cpu", | |
low_cpu_mem_usage=True, # 尽量启用低内存占用模式 | |
trust_remote_code=True, | |
) | |
print("→ 基座模型加载完成。(注意检查是否被系统 OOM)") | |
# ── 3.3 用 PeftModel 叠加 LoRA Adapter ── | |
print(f"正在叠加 LoRA Adapter:{ADAPTER_REPO}…") | |
model = PeftModel.from_pretrained( | |
base_model, | |
ADAPTER_REPO, | |
device_map="cpu", # CPU-only 环境 | |
torch_dtype=torch.float32, # 同样使用 float32 | |
) | |
print("→ LoRA Adapter 已叠加成功。") | |
# (可选)不想更新基座所有参数时,把 base_model 的参数都冻结: | |
# model.eval() | |
# for param in model.base_model.parameters(): | |
# param.requires_grad = False | |
except Exception as e: | |
import traceback | |
traceback.print_exc() | |
print(f"模型加载失败: {str(e)}") | |
return False | |
return True | |
# ─────────────────────── 4. 文本生成函数 ─────────────────────── | |
def convert_text_style(input_text: str) -> str: | |
""" | |
输入一句书面化/技术性的中文,让模型把它转换成自然、口语化的表达方式。 | |
""" | |
if not input_text or input_text.strip() == "": | |
return "请输入要转换的文本。" | |
# 确保模型已加载 | |
if not load_model(): | |
return "模型加载失败,请稍后重试。" | |
try: | |
# 拼一个简单的 Prompt | |
prompt = f"""以下是一个文本风格转换任务,请将书面化、技术性的输入文本转换为自然、口语化的表达方式。 | |
### 输入文本: | |
{input_text} | |
### 输出文本: | |
""" | |
# 分词 & 转 torch.Tensor | |
inputs = tokenizer( | |
prompt, | |
return_tensors="pt", | |
max_length=1024, | |
truncation=True, | |
padding=True, | |
) | |
# 全部放到 CPU 上 | |
inputs = {k: v.to("cpu") for k, v in inputs.items()} | |
# 生成 | |
with torch.no_grad(): | |
outputs = model.generate( | |
inputs["input_ids"], | |
attention_mask=inputs["attention_mask"], | |
max_new_tokens=256, | |
temperature=0.7, | |
do_sample=True, | |
pad_token_id=tokenizer.eos_token_id, | |
eos_token_id=tokenizer.eos_token_id, | |
no_repeat_ngram_size=2, | |
num_return_sequences=1, | |
) | |
# 解码并抽取结果 | |
full_text = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
if "### 输出文本:" in full_text: | |
return full_text.split("### 输出文本:")[-1].strip() | |
return full_text[len(prompt) :].strip() | |
except Exception as e: | |
import traceback | |
traceback.print_exc() | |
return f"生成过程中出现错误: {str(e)}" | |
# ─────────────────────── 5. Gradio 界面配置 ─────────────────────── | |
iface = gr.Interface( | |
fn=convert_text_style, | |
inputs=gr.Textbox( | |
label="输入文本", placeholder="请输入需要转换为口语化的书面文本...", lines=3 | |
), | |
outputs=gr.Textbox(label="输出文本", lines=4), | |
title="中文文本风格转换API", | |
description="将书面化、技术性文本转换为自然、口语化表达", | |
examples=[ | |
["乙醇的检测方法包括酸碱度检查。"], | |
["本品为薄膜衣片,除去包衣后显橙红色。"], | |
], | |
cache_examples=False, | |
flagging_mode="never", | |
) | |
if __name__ == "__main__": | |
print("启动 Gradio 应用…") | |
# 纯 CPU 环境下,server_name 可以保持默认 "0.0.0.0",port 也是 7860 | |
iface.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=False) | |