yxccai commited on
Commit
ec9daa1
·
verified ·
1 Parent(s): bf37ace

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -71
app.py CHANGED
@@ -1,71 +1,153 @@
1
- import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
- import torch
4
-
5
- # 加载模型和tokenizer
6
- model_name = "yxccai/text-style-converter"
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
8
- model = AutoModelForCausalLM.from_pretrained(
9
- model_name,
10
- torch_dtype=torch.float16,
11
- device_map="auto"
12
- )
13
-
14
- def convert_text_style(input_text):
15
- """文本风格转换函数"""
16
- if not input_text.strip():
17
- return "请输入要转换的文本"
18
-
19
- prompt = f"""以下是一个文本风格转换任务,请将书面化、技术性的输入文本转换为自然、口语化的表达方式。
20
-
21
- ### 输入文本:
22
- {input_text}
23
-
24
- ### 输出文本:
25
- """
26
-
27
- inputs = tokenizer(prompt, return_tensors="pt")
28
-
29
- with torch.no_grad():
30
- outputs = model.generate(
31
- inputs.input_ids,
32
- attention_mask=inputs.attention_mask,
33
- max_new_tokens=500,
34
- temperature=0.7,
35
- do_sample=True,
36
- pad_token_id=tokenizer.eos_token_id
37
- )
38
-
39
- full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
40
-
41
- # 提取生成的部分
42
- if "### 输出文本:" in full_response:
43
- response = full_response.split("### 输出文本:")[-1].strip()
44
- else:
45
- response = full_response
46
-
47
- return response
48
-
49
- # 创建Gradio接口
50
- iface = gr.Interface(
51
- fn=convert_text_style,
52
- inputs=gr.Textbox(
53
- label="输入文本",
54
- placeholder="请输入需要转换为口语化的书面文本...",
55
- lines=5
56
- ),
57
- outputs=gr.Textbox(
58
- label="输出文本",
59
- lines=5
60
- ),
61
- title="中文文本风格转换API",
62
- description="将书面化、技术性文本转换为自然、口语化表达",
63
- examples=[
64
- ["乙醇的检测方法包括以下几项: 1. 酸碱度检查:取20ml乙醇加20ml水,加2滴酚酞指示剂应无色。"],
65
- ["本品为薄膜衣片,除去包衣后显橙红色至暗红色。"]
66
- ]
67
- )
68
-
69
- # 启动应用
70
- if __name__ == "__main__":
71
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import gc
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ import os
6
+
7
+ # 清理内存
8
+ torch.cuda.empty_cache()
9
+ gc.collect()
10
+
11
+ # 设置环境变量
12
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
13
+
14
+ # 模型名称
15
+ model_name = "您的用户名/text-style-converter"
16
+
17
+ # 全局变量存储模型
18
+ tokenizer = None
19
+ model = None
20
+
21
+ def load_model():
22
+ """延迟加载模型"""
23
+ global tokenizer, model
24
+
25
+ if tokenizer is None or model is None:
26
+ try:
27
+ print("正在加载tokenizer...")
28
+ tokenizer = AutoTokenizer.from_pretrained(
29
+ model_name,
30
+ trust_remote_code=True,
31
+ use_fast=False # 使用慢速tokenizer减少内存
32
+ )
33
+
34
+ print("正在加载模型...")
35
+ model = AutoModelForCausalLM.from_pretrained(
36
+ model_name,
37
+ torch_dtype=torch.float16, # 使用半精度
38
+ device_map="cpu", # 强制使用CPU
39
+ low_cpu_mem_usage=True, # 启用低内存模式
40
+ trust_remote_code=True,
41
+ load_in_8bit=False, # 在CPU上不使用量化
42
+ offload_folder="./offload", # 设置offload文件夹
43
+ )
44
+
45
+ # 设置pad_token
46
+ if tokenizer.pad_token is None:
47
+ tokenizer.pad_token = tokenizer.eos_token
48
+
49
+ print("模型加载完成!")
50
+
51
+ except Exception as e:
52
+ print(f"模型加载失败: {str(e)}")
53
+ return False
54
+
55
+ return True
56
+
57
+ def convert_text_style(input_text):
58
+ """文本风格转换函数"""
59
+ if not input_text.strip():
60
+ return "请输入要转换的文本"
61
+
62
+ # 检查模型是否加载
63
+ if not load_model():
64
+ return "模型加载失败,请稍后重试"
65
+
66
+ try:
67
+ prompt = f"""以下是一个文本风格转换任务,请将书面化、技术性的输入文本转换为自然、口语化的表达方式。
68
+
69
+ ### 输入文本:
70
+ {input_text}
71
+
72
+ ### 输出文本:
73
+ """
74
+
75
+ # 编码输入
76
+ inputs = tokenizer(
77
+ prompt,
78
+ return_tensors="pt",
79
+ max_length=1024, # 限制输入长度
80
+ truncation=True,
81
+ padding=True
82
+ )
83
+
84
+ # 生成回答
85
+ with torch.no_grad(): # 不计算梯度节省内存
86
+ outputs = model.generate(
87
+ inputs.input_ids,
88
+ attention_mask=inputs.attention_mask,
89
+ max_new_tokens=300, # 减少生成长度
90
+ temperature=0.7,
91
+ do_sample=True,
92
+ pad_token_id=tokenizer.eos_token_id,
93
+ eos_token_id=tokenizer.eos_token_id,
94
+ num_return_sequences=1,
95
+ no_repeat_ngram_size=2
96
+ )
97
+
98
+ # 解码输出
99
+ full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
100
+
101
+ # 提取生成的部分
102
+ if "### 输出文本:" in full_response:
103
+ response = full_response.split("### 输出文本:")[-1].strip()
104
+ else:
105
+ response = full_response[len(prompt):].strip()
106
+
107
+ # 清理内存
108
+ del inputs, outputs
109
+ torch.cuda.empty_cache()
110
+ gc.collect()
111
+
112
+ return response if response else "抱歉,未能生成有效回答"
113
+
114
+ except Exception as e:
115
+ return f"生成过程中出现错误: {str(e)}"
116
+
117
+ # 创建Gradio接口
118
+ def create_interface():
119
+ iface = gr.Interface(
120
+ fn=convert_text_style,
121
+ inputs=gr.Textbox(
122
+ label="输入文本",
123
+ placeholder="请输入需要转换为口语化的书面文本...",
124
+ lines=3
125
+ ),
126
+ outputs=gr.Textbox(
127
+ label="输出文本",
128
+ lines=3
129
+ ),
130
+ title="中文文本风格转换API",
131
+ description="将书面化、技术性文本转换为自然、口语化表达",
132
+ examples=[
133
+ ["乙醇的检测方法包括酸碱度检查。"],
134
+ ["本品为薄膜衣片,除去包衣后显橙红色。"]
135
+ ],
136
+ cache_examples=False, # 不缓存示例
137
+ allow_flagging="never" # 禁用标记功能
138
+ )
139
+
140
+ return iface
141
+
142
+ # 启动应用
143
+ if __name__ == "__main__":
144
+ print("正在启动应用...")
145
+ iface = create_interface()
146
+ iface.launch(
147
+ server_name="0.0.0.0",
148
+ server_port=7860,
149
+ share=False,
150
+ debug=False,
151
+ enable_queue=True,
152
+ max_threads=1 # 限制线程数
153
+ )