text-style-api / app.py
yxccai's picture
Update app.py
ec9daa1 verified
raw
history blame
4.64 kB
import gradio as gr
import torch
import gc
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
# 清理内存
torch.cuda.empty_cache()
gc.collect()
# 设置环境变量
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
# 模型名称
model_name = "您的用户名/text-style-converter"
# 全局变量存储模型
tokenizer = None
model = None
def load_model():
"""延迟加载模型"""
global tokenizer, model
if tokenizer is None or model is None:
try:
print("正在加载tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
model_name,
trust_remote_code=True,
use_fast=False # 使用慢速tokenizer减少内存
)
print("正在加载模型...")
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16, # 使用半精度
device_map="cpu", # 强制使用CPU
low_cpu_mem_usage=True, # 启用低内存模式
trust_remote_code=True,
load_in_8bit=False, # 在CPU上不使用量化
offload_folder="./offload", # 设置offload文件夹
)
# 设置pad_token
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
print("模型加载完成!")
except Exception as e:
print(f"模型加载失败: {str(e)}")
return False
return True
def convert_text_style(input_text):
"""文本风格转换函数"""
if not input_text.strip():
return "请输入要转换的文本"
# 检查模型是否加载
if not load_model():
return "模型加载失败,请稍后重试"
try:
prompt = f"""以下是一个文本风格转换任务,请将书面化、技术性的输入文本转换为自然、口语化的表达方式。
### 输入文本:
{input_text}
### 输出文本:
"""
# 编码输入
inputs = tokenizer(
prompt,
return_tensors="pt",
max_length=1024, # 限制输入长度
truncation=True,
padding=True
)
# 生成回答
with torch.no_grad(): # 不计算梯度节省内存
outputs = model.generate(
inputs.input_ids,
attention_mask=inputs.attention_mask,
max_new_tokens=300, # 减少生成长度
temperature=0.7,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
num_return_sequences=1,
no_repeat_ngram_size=2
)
# 解码输出
full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
# 提取生成的部分
if "### 输出文本:" in full_response:
response = full_response.split("### 输出文本:")[-1].strip()
else:
response = full_response[len(prompt):].strip()
# 清理内存
del inputs, outputs
torch.cuda.empty_cache()
gc.collect()
return response if response else "抱歉,未能生成有效回答"
except Exception as e:
return f"生成过程中出现错误: {str(e)}"
# 创建Gradio接口
def create_interface():
iface = gr.Interface(
fn=convert_text_style,
inputs=gr.Textbox(
label="输入文本",
placeholder="请输入需要转换为口语化的书面文本...",
lines=3
),
outputs=gr.Textbox(
label="输出文本",
lines=3
),
title="中文文本风格转换API",
description="将书面化、技术性文本转换为自然、口语化表达",
examples=[
["乙醇的检测方法包括酸碱度检查。"],
["本品为薄膜衣片,除去包衣后显橙红色。"]
],
cache_examples=False, # 不缓存示例
allow_flagging="never" # 禁用标记功能
)
return iface
# 启动应用
if __name__ == "__main__":
print("正在启动应用...")
iface = create_interface()
iface.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
debug=False,
enable_queue=True,
max_threads=1 # 限制线程数
)