File size: 6,220 Bytes
4dda180
 
 
 
 
 
5f8627a
d04a7a9
4dda180
 
 
 
 
 
 
 
 
 
 
 
5f8627a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d04a7a9
5f8627a
d04a7a9
b604ded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
---
base_model:
- llm-jp/llm-jp-3-3.7b-instruct2
license: apache-2.0
datasets:
- p1atdev/gsm8k-ja-slim
- SyntheticVeryEasyMath5k
- SyntheticWhichIsGreater5k
language:
- ja
library_name: transformers
tags:
- grpo
- trl
---

additional instruction:

```
回答する際は、思考過程を<think></think>ブロック内に記述し、最終的な答えを数値のみで<answer></answer>ブロック内に記述してください。
```

## Example

```py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("p1atdev/llm-jp-3-3.7b-instruct2-R27")
model = AutoModelForCausalLM.from_pretrained("p1atdev/llm-jp-3-3.7b-instruct2-R27", torch_dtype=torch.float16)
model = model.eval().to("cuda")


additional_instruction = "回答する際は、思考過程を<think></think>ブロック内に記述し、最終的な答えを数値のみで<answer></answer>ブロック内に記述してください。"
question = "ナタリアは4月に48人の友人にクリップを販売し、その後5月にはその半分の数のクリップを販売しました。ナタリアは4月と5月の合計でいくつのクリップを販売しましたか?"

inputs = tokenizer.apply_chat_template(
  [
    {
        "role": "user",
        "content": question
    },
  ],
  additional_instruction=additional_instruction, # pass the additional instruction
  tokenize=False,
  add_generation_prompt=True, # append "### 応答:"
  return_tensors="pt",
)
inputs = tokenizer(inputs, return_tensors="pt").to(model.device)

with torch.inference_mode():
  outputs = model.generate(
      **inputs,
      do_sample=True,
      temperature=0.9,
      top_p=0.6,
      top_k=20,
      max_new_tokens=256,
      repetition_penalty=1.0,
      eos_token_id=tokenizer.eos_token_id,
      pad_token_id=tokenizer.pad_token_id,
  )
print(tokenizer.decode(outputs[0][len(inputs.input_ids[0]):]))
```

the output:

```
<think>
4月にナタリアは48人の友人にクリップを販売しました。
5月にはその半分の数のクリップを販売したので、48 ÷ 2 = 24人の友人にクリップを販売したことになります。
したがって、4月と5月の合計でナタリアは48 + 24 = 72人の友人にクリップを販売したことになります。</think>
<answer>72</answer></s>
```


## Dataset

- 日本語訳した GSM8K ([p1atdev/gsm8k-ja-slim](https://huggingface.co/datasets/p1atdev/gsm8k-ja-slim))
  - うち、苗字や人名に関する問題を除外
- SyntheticVeryEasyMath5k
  - 機械的に合成した、整数の四則演算問題 5,000問
- SyntheticWhichIsGreater5k
  - 機械的に合成した、二つの小数のどちらが大きいかを回答する問題 5,000問
 
下の二つのデータは、[math_problem.py](https://huggingface.co/p1atdev/llm-jp-3-3.7b-instruct2-R27/blob/main/math_problem.py) の関数を使って以下のように合成しました。

```py
def generate_int_problem(
    num_generation: int,
    max_int: int,
    min_int: int,
    max_terms: int,
):
    for i in range(num_generation):
        text, tex, result = create_integer_arithmetic_problem(
            max_val=max_int,
            min_val=min_int,
            max_terms=max_terms,
        )
        formula = random.choice([text, tex])
        templates = [
            "{formula} = ?",
            "{formula} を計算してください。",
            "次の式を計算し、計算結果を解答してください。\n{formula}",
            "計算して\n{formula}",
            "次の式を計算してください。\n{formula}",
            "次の式の答えは何ですか?\n{formula}",
            "? に当てはまる数字を答えてください。\n{formula}",
            "{formula}\n計算して",
            "{formula}\n↑の答えを求めてください。",
        ]
        instruction = random.choice(templates).format(formula=formula)

        yield {
            "ground_truth": str(result),
            "instruction": instruction,
            "source": "synthetic_int_problem",
            "answer_dtype": "int",
            "skip_check": False,
        }

def generate_wig_problem(
    num_generation: int,
    max_num: float,
    min_num: float,
    precision: int,
):
    for i in range(num_generation):
        num_1, num_2, greater = create_two_decimals(
            min_val=min_num,
            max_val=max_num,
            precision=precision,
        )
        templates = [
            "次の数字のうち、どちらが大きいですか?\n{num_1}\n{num_2}",
            "{num_1} と {num_2} のうちどちらが大きいですか?",
            "{num_1} と {num_2} はどっちが大きい?",
            "大きいほうを選んで: {num_1} {num_2}",
            "次の数値を比較し、大きい方を選んでください。\n{num_1} {num_2}",
        ]
        instruction = random.choice(templates).format(num_1=num_1, num_2=num_2)

        yield {
            "ground_truth": str(greater),
            "instruction": instruction,
            "source": "synthetic_which_is_greater",
            "answer_dtype": "float",
            "skip_check": False,
        }


# generate dataset
ds_easy_int = Dataset.from_generator(
    generate_int_problem,
    gen_kwargs={
        "num_generation": 5000,
        "max_int": 10,
        "min_int": -10,
        "max_terms": 5,
    },
)
assert isinstance(ds_easy_int, Dataset)
print("easy_int:", ds_easy_int)

ds_wig = Dataset.from_generator(
    generate_wig_problem,
    gen_kwargs={
        "num_generation": 5000,
        "max_num": 100.0,
        "min_num": -100.0,
        "precision": 3,
    },
)
assert isinstance(ds_wig, Dataset)
print("wig:", ds_wig)

# japanese gsm8k
ds_gsm8k = load_dataset("p1atdev/gsm8k-ja-slim", split="train")
assert isinstance(ds_gsm8k, Dataset)
ds_gsm8k = ds_gsm8k.filter(filter_gsm8k_ja, batched=True)
ds_gsm8k = ds_gsm8k.map(
    map_gsm8k_ja_instruction, remove_columns=ds_gsm8k.column_names
)
print("gsm8k:", ds_gsm8k)

# concat
ds = concatenate_datasets( # from datasets import concatenate_datasets
    [
        ds_gsm8k,
        ds_easy_int,
        ds_wig,
    ]
)
print("total:", ds)
```