merterbak commited on
Commit
3f07255
·
verified ·
1 Parent(s): 6372c40

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -0
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Streaming chat demo for ByteDance Seed-Coder-8B-Instruct (Hugging Face Space).
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
from threading import Thread
import gradio as gr
import spaces

# Hub repo id of the instruction-tuned coding model served by this app.
model_id = "ByteDance-Seed/Seed-Coder-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# bfloat16 halves the fp32 memory footprint; device_map="auto" lets
# accelerate place layers on the available device(s). eval() disables
# dropout/training-mode layers for inference.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto"
).eval()
def format_conversation_history(chat_history):
    """Normalize a Gradio chat history into plain chat messages.

    Each history item is a dict with "role" and "content". Gradio may store
    the content either as a plain string or as a list of parts; for a list
    we keep the text of the first part when that part is a dict carrying a
    "text" key, otherwise we fall back to ``str(content)``.

    Args:
        chat_history: iterable of ``{"role": str, "content": str | list}``.

    Returns:
        list of ``{"role": str, "content": str}`` messages.
    """
    messages = []
    for item in chat_history:
        role = item["role"]
        content = item["content"]
        if isinstance(content, list):
            # Bug fix: the old check (`"text" in content[0]`) crashed with a
            # TypeError when the first element was not a dict — e.g. a plain
            # string containing the substring "text" passed the `in` test and
            # was then indexed with the key "text". Only treat dict parts as
            # {"text": ...} payloads.
            first = content[0] if content else None
            if isinstance(first, dict) and "text" in first:
                content = first["text"]
            else:
                content = str(content)
        messages.append({"role": role, "content": content})
    return messages
@spaces.GPU()
def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
    """Stream a model reply for the latest user turn.

    Builds the full chat prompt (optional system message + prior turns +
    the new user message), launches ``model.generate`` on a background
    thread, and yields the accumulated reply text as tokens arrive so the
    Gradio UI updates incrementally.

    Args:
        input_data: the new user message (string).
        chat_history: prior turns in Gradio "messages" format.
        max_new_tokens / temperature / top_p / top_k / repetition_penalty:
            sampling controls forwarded to ``model.generate``.
        system_prompt: optional system instruction; skipped when empty.

    Yields:
        The reply text accumulated so far (grows with each chunk).
    """
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    conversation.extend(format_conversation_history(chat_history))
    conversation.append({"role": "user", "content": input_data})

    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
    ).to(model.device)

    # skip_prompt drops the echoed input; skip_special_tokens strips EOS etc.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Run generation off the request thread so we can consume the streamer here.
    Thread(
        target=model.generate,
        kwargs=dict(
            input_ids=input_ids,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
        ),
    ).start()

    pieces = []
    for chunk in streamer:
        pieces.append(chunk)
        yield "".join(pieces)
# Chat UI wiring: sampling controls are exposed as additional inputs so
# users can tune generation at inference time; examples seed the textbox.
demo = gr.ChatInterface(
    fn=generate_response,
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=64, maximum=4096, step=1, value=1024),
        gr.Textbox(
            label="System Prompt",
            value="You are a helpful coding assistant specializing in generating accurate and efficient code.",
            lines=4,
            placeholder="Change system prompt"
        ),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7),
        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
        gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.0)
    ],
    examples=[
        [{"text": "Develop a Python Dijkstra’s algorithm to find the shortest path between nodes in a weighted graph for a navigation app"}],
        [{"text": "Write an SQL query to retrieve the top 5 most-accessed files in a cloud storage system by download count, including file type and size"}],
        [{"text": "Write a JavaScript function to validate email address and telephone number using regular expressions."}],
        [{"text": "Write an HTML/CSS stylesheet to style a multi-level navigation menu with hover effects and mobile compatibility"}],
    ],
    cache_examples=False,
    type="messages",
    # Bug fix: the previous description was ungrammatical user-facing text
    # ("This model excelling in ... It pre-trained on 6 trillion token dataset").
    description="""
    # Seed-Coder-8B-Instruct
    This model, developed by the ByteDance Seed team, excels at code generation, code completion, code editing and software engineering tasks.
    It was pre-trained on a 6-trillion-token dataset supporting 89 programming languages.
    """,
    fill_height=True,
    textbox=gr.Textbox(
        label="Query Input",
        placeholder="Type your prompt"
    ),
    stop_btn="Stop Generation",
    multimodal=False,
    theme=gr.themes.Soft()
)
# Launch the Gradio server only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()