File size: 2,157 Bytes
50ac931
e719f54
 
 
 
d4a97a6
94a08bb
e719f54
 
 
 
 
 
 
 
 
d4a97a6
e719f54
 
 
 
 
 
 
d4a97a6
e719f54
 
 
 
 
 
d4a97a6
 
 
0f9d18a
d4a97a6
 
 
d7abcd1
 
 
 
 
 
 
5c8f89e
d7abcd1
 
 
5c8f89e
d7abcd1
 
 
5c8f89e
d7abcd1
 
 
 
 
 
 
5c8f89e
d7abcd1
 
50ac931
 
e719f54
d4a97a6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import spaces

# Load the model and tokenizer once at module import time; predict() reuses them.
model_name = "Zhihu-ai/Zhi-writing-dsr1-14b"
# trust_remote_code=True permits running custom code shipped with the model
# repo -- NOTE(review): only appropriate if the repo is trusted.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,  # load weights in half precision
    device_map="auto",  # let transformers decide device placement of the weights
    trust_remote_code=True
)
@spaces.GPU()
def predict(message, history):
    """Generate the assistant's reply to ``message`` given the chat ``history``.

    The conversation is flattened into a plain-text transcript of
    ``Human: ...`` / ``Assistant: ...`` lines ending with an open
    ``Assistant:`` cue, which the model then continues.

    Args:
        message: The newest user message.
        history: Previous turns as supplied by ``gr.ChatInterface``. Both the
            "tuples" layout (``[[user, assistant], ...]``) and the "messages"
            layout (``[{"role": "user" | "assistant", "content": ...}, ...]``)
            are accepted; ``None`` is treated as an empty history.

    Returns:
        The newly generated text (prompt tokens excluded), whitespace-stripped.
    """
    # Build the prompt transcript: collect the pieces and join once instead of
    # repeatedly concatenating strings.
    speaker_for_role = {"user": "Human", "assistant": "Assistant"}
    parts = []
    for turn in history or []:
        if isinstance(turn, dict):
            # "messages" layout: one dict per message. Pair-unpacking a dict
            # would iterate its keys, so handle this layout explicitly.
            speaker = speaker_for_role.get(turn.get("role"))
            if speaker is None:
                continue  # no transcript label for this role; skip it
            parts.append(f"{speaker}: {turn.get('content') or ''}\n")
        else:
            human, assistant = turn
            # Either side of a turn may be None (e.g. a bot-only message);
            # render it as empty text rather than the literal string "None".
            parts.append(f"Human: {human or ''}\nAssistant: {assistant or ''}\n")
    history_text = "".join(parts)
    prompt = f"{history_text}Human: {message}\nAssistant:"

    # Generate the reply.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        # NOTE(review): up to 10000 new tokens from a 14B model may outlast the
        # GPU time granted by @spaces.GPU() without a duration argument -- confirm.
        max_new_tokens=10000,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        # Reuse EOS as the pad token (presumably the tokenizer has no separate
        # pad token -- confirm).
        pad_token_id=tokenizer.eos_token_id
    )
    # Decode only the newly generated tokens, dropping the echoed prompt.
    prompt_len = inputs.input_ids.shape[1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    return response.strip()

# Create the Gradio chat UI; each submitted message is handled by predict().
demo = gr.ChatInterface(
    predict,
    title="测试Zhi-writing-dsr1-14b",  # title: "Test Zhi-writing-dsr1-14b"
    description="Zhihu-ai/Zhi-writing-dsr1-14b",
    # Example prompts, roughly: "Write ~500 characters in Lu Xun's voice
    # describing how cute orange cats are!", "Who is 桔了个仔?",
    # "Introduce yourself".
    examples=["鲁迅口吻写五百字,描述桔猫的可爱!", "桔了个仔是谁", "介绍自己"],
    theme=gr.themes.Soft()
)

# with gr.Blocks(theme=gr.themes.Soft()) as demo:
#     gr.Markdown("# Zhi-writing-dsr1-14")
#     gr.Markdown("这是一个基于Zhi-writing-dsr1-14的文章生成器")
    
#     chatbot = gr.Chatbot()
#     msg = gr.Textbox(label="输入消息")
#     clear = gr.Button("清除对话")
    
#     def respond(message, chat_history):
#         bot_message = ""
#         for response in predict(message, chat_history):
#             bot_message = response
#             chat_history.append((message, bot_message))
#             yield chat_history
#         return "", chat_history
    
#     msg.submit(respond, [msg, chatbot], [msg, chatbot])
#     clear.click(lambda: None, None, chatbot, queue=False)

# Script entry point: launch the `demo` chat interface when run directly.
# NOTE(review): share=True requests a public share link -- confirm this is
# wanted (it is likely unnecessary when hosted on Hugging Face Spaces).
if __name__ == "__main__":
    demo.launch(share=True)