Spaces:
Running
Running
File size: 4,549 Bytes
ec9daa1 ae39785 ec9daa1 ae39785 ec9daa1 ae39785 ec9daa1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import gradio as gr
import torch
import gc
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
# 清理内存
torch.cuda.empty_cache()
gc.collect()
# 设置环境变量
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
# 模型名称
model_name = "您的用户名/text-style-converter"
# 全局变量存储模型
tokenizer = None
model = None
def load_model():
"""延迟加载模型"""
global tokenizer, model
if tokenizer is None or model is None:
try:
print("正在加载tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
model_name,
trust_remote_code=True,
use_fast=False # 使用慢速tokenizer减少内存
)
print("正在加载模型...")
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16, # 使用半精度
device_map="cpu", # 强制使用CPU
low_cpu_mem_usage=True, # 启用低内存模式
trust_remote_code=True,
load_in_8bit=False, # 在CPU上不使用量化
offload_folder="./offload", # 设置offload文件夹
)
# 设置pad_token
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
print("模型加载完成!")
except Exception as e:
print(f"模型加载失败: {str(e)}")
return False
return True
def convert_text_style(input_text):
"""文本风格转换函数"""
if not input_text.strip():
return "请输入要转换的文本"
# 检查模型是否加载
if not load_model():
return "模型加载失败,请稍后重试"
try:
prompt = f"""以下是一个文本风格转换任务,请将书面化、技术性的输入文本转换为自然、口语化的表达方式。
### 输入文本:
{input_text}
### 输出文本:
"""
# 编码输入
inputs = tokenizer(
prompt,
return_tensors="pt",
max_length=1024, # 限制输入长度
truncation=True,
padding=True
)
# 生成回答
with torch.no_grad(): # 不计算梯度节省内存
outputs = model.generate(
inputs.input_ids,
attention_mask=inputs.attention_mask,
max_new_tokens=300, # 减少生成长度
temperature=0.7,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
num_return_sequences=1,
no_repeat_ngram_size=2
)
# 解码输出
full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
# 提取生成的部分
if "### 输出文本:" in full_response:
response = full_response.split("### 输出文本:")[-1].strip()
else:
response = full_response[len(prompt):].strip()
# 清理内存
del inputs, outputs
torch.cuda.empty_cache()
gc.collect()
return response if response else "抱歉,未能生成有效回答"
except Exception as e:
return f"生成过程中出现错误: {str(e)}"
# 创建Gradio接口 - 修复版本兼容性问题
iface = gr.Interface(
fn=convert_text_style,
inputs=gr.Textbox(
label="输入文本",
placeholder="请输入需要转换为口语化的书面文本...",
lines=3
),
outputs=gr.Textbox(
label="输出文本",
lines=3
),
title="中文文本风格转换API",
description="将书面化、技术性文本转换为自然、口语化表达",
examples=[
["乙醇的检测方法包括酸碱度检查。"],
["本品为薄膜衣片,除去包衣后显橙红色。"]
],
cache_examples=False, # 不缓存示例
flagging_mode="never" # 修复:使用flagging_mode替代allow_flagging
)
# 启动应用 - 移除不兼容的参数
if __name__ == "__main__":
print("正在启动应用...")
iface.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
debug=False
# 移除了enable_queue和max_threads参数
) |