File size: 4,549 Bytes
ec9daa1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae39785
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec9daa1
ae39785
ec9daa1
 
 
 
 
 
ae39785
 
ec9daa1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import gradio as gr
import torch
import gc
from transformers import AutoTokenizer, AutoModelForCausalLM
import os

# 清理内存
torch.cuda.empty_cache()
gc.collect()

# 设置环境变量
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

# 模型名称
model_name = "您的用户名/text-style-converter"

# 全局变量存储模型
tokenizer = None
model = None

def load_model():
    """延迟加载模型"""
    global tokenizer, model
    
    if tokenizer is None or model is None:
        try:
            print("正在加载tokenizer...")
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True,
                use_fast=False  # 使用慢速tokenizer减少内存
            )
            
            print("正在加载模型...")
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,  # 使用半精度
                device_map="cpu",           # 强制使用CPU
                low_cpu_mem_usage=True,     # 启用低内存模式
                trust_remote_code=True,
                load_in_8bit=False,         # 在CPU上不使用量化
                offload_folder="./offload", # 设置offload文件夹
            )
            
            # 设置pad_token
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
                
            print("模型加载完成!")
            
        except Exception as e:
            print(f"模型加载失败: {str(e)}")
            return False
    
    return True

def convert_text_style(input_text):
    """文本风格转换函数"""
    if not input_text.strip():
        return "请输入要转换的文本"
    
    # 检查模型是否加载
    if not load_model():
        return "模型加载失败,请稍后重试"
    
    try:
        prompt = f"""以下是一个文本风格转换任务,请将书面化、技术性的输入文本转换为自然、口语化的表达方式。

### 输入文本:
{input_text}

### 输出文本:
"""
        
        # 编码输入
        inputs = tokenizer(
            prompt, 
            return_tensors="pt",
            max_length=1024,  # 限制输入长度
            truncation=True,
            padding=True
        )
        
        # 生成回答
        with torch.no_grad():  # 不计算梯度节省内存
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=300,  # 减少生成长度
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                num_return_sequences=1,
                no_repeat_ngram_size=2
            )
        
        # 解码输出
        full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # 提取生成的部分
        if "### 输出文本:" in full_response:
            response = full_response.split("### 输出文本:")[-1].strip()
        else:
            response = full_response[len(prompt):].strip()
        
        # 清理内存
        del inputs, outputs
        torch.cuda.empty_cache()
        gc.collect()
        
        return response if response else "抱歉,未能生成有效回答"
        
    except Exception as e:
        return f"生成过程中出现错误: {str(e)}"

# 创建Gradio接口 - 修复版本兼容性问题
iface = gr.Interface(
    fn=convert_text_style,
    inputs=gr.Textbox(
        label="输入文本",
        placeholder="请输入需要转换为口语化的书面文本...",
        lines=3
    ),
    outputs=gr.Textbox(
        label="输出文本",
        lines=3
    ),
    title="中文文本风格转换API",
    description="将书面化、技术性文本转换为自然、口语化表达",
    examples=[
        ["乙醇的检测方法包括酸碱度检查。"],
        ["本品为薄膜衣片,除去包衣后显橙红色。"]
    ],
    cache_examples=False,  # 不缓存示例
    flagging_mode="never"  # 修复:使用flagging_mode替代allow_flagging
)

# 启动应用 - 移除不兼容的参数
if __name__ == "__main__":
    print("正在启动应用...")
    iface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=False
        # 移除了enable_queue和max_threads参数
    )