username committed on
Commit 06d92b4 · 1 Parent(s): 457d89f
Files changed (1)
  1. app.py +30 -25
app.py CHANGED
@@ -1,49 +1,54 @@
 import torch
 import gradio as gr
 import spaces
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+from threading import Thread
 
-tokenizer = AutoTokenizer.from_pretrained("llm-jp/llm-jp-3-8x1.8b-instruct3")
-model = None
+model_id = "llm-jp/llm-jp-3-8x1.8b-instruct3"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    device_map="auto",
+    torch_dtype=torch.bfloat16
+)
 
 @spaces.GPU
 def generate_text(system_prompt, user_input, max_length=512, temperature=0.7, top_p=0.95):
-    global model
-    if model is None:
-        model = AutoModelForCausalLM.from_pretrained(
-            "llm-jp/llm-jp-3-8x1.8b-instruct3",
-            device_map="auto",
-            torch_dtype=torch.bfloat16
-        )
-
     chat = [
         {"role": "system", "content": system_prompt},
         {"role": "user", "content": user_input},
     ]
 
-    tokenized_input = tokenizer.apply_chat_template(
+    input_ids = tokenizer.apply_chat_template(
         chat,
         add_generation_prompt=True,
-        tokenize=True,
         return_tensors="pt"
     ).to(model.device)
 
-    with torch.no_grad():
-        output = model.generate(
-            tokenized_input,
-            max_new_tokens=max_length,
-            do_sample=True,
-            top_p=top_p,
-            temperature=temperature,
-            repetition_penalty=1.05,
-        )[0]
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+    generate_kwargs = {
+        "input_ids": input_ids,
+        "streamer": streamer,
+        "max_new_tokens": max_length,
+        "do_sample": True,
+        "temperature": temperature,
+        "top_p": top_p,
+        "repetition_penalty": 1.05
+    }
+
+    thread = Thread(target=model.generate, kwargs=generate_kwargs)
+    thread.start()
+
+    response = ""
+    for text in streamer:
+        response += text
 
-    generated_text = tokenizer.decode(output, skip_special_tokens=True)
-    return generated_text
+    return response
 
 with gr.Blocks() as demo:
     gr.Markdown("# LLM-JP-3-8x1.8b-instruct3 非公式デモ")
-    gr.Markdown("国立情報学研究所大規模言語モデル研究開発センターの開発した日本語大規模言語モデル「LLM-JP-3」の非公式デモ。詳細は[こちらの記事](https://llm-jp.nii.ac.jp/blog/2025/03/27/moe3.html)をご覧ください。ZeroGPU を使用しています。")
+    gr.Markdown("国立情報学研究所大規模言語モデル研究開発センター(LLMC)が開発した日本語大規模言語モデル LLM-jp-3 MoE 8x1.8B の非公式デモです。ZeroGPU を使用しています。")
 
     with gr.Row():
         with gr.Column():
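One caveat on the new version: generate_text still accumulates the whole stream into response and returns it once, so TextIteratorStreamer here only moves decoding onto a background thread; the Gradio UI receives a single update. If token-by-token display were the goal, a minimal sketch (my assumption, not part of this commit) is to make the handler a generator, since Gradio treats functions that yield as streaming handlers. The name generate_text_streaming is hypothetical; the setup reuses the module-level tokenizer and model from the diff above.

@spaces.GPU
def generate_text_streaming(system_prompt, user_input, max_length=512, temperature=0.7, top_p=0.95):
    # Sketch only, not from this commit: same setup as generate_text above,
    # but yielding partial text so Gradio can re-render on every chunk.
    chat = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_input},
    ]
    input_ids = tokenizer.apply_chat_template(
        chat, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    Thread(target=model.generate, kwargs={
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_length,
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "repetition_penalty": 1.05,
    }).start()

    response = ""
    for text in streamer:
        response += text
        yield response  # each yield pushes the partial response to the bound output

Wired to a button click in the Blocks layout (the rest of which is cut off by this hunk), the output textbox would then fill in incrementally instead of appearing all at once.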