arnocandel pseudotensor commited on
Commit
c7b6f0f
·
0 Parent(s):

Duplicate from h2oai/h2ogpt-chatbot

Browse files

Co-authored-by: Jonathan McKinney <pseudotensor@users.noreply.huggingface.co>

Files changed (11) hide show
  1. .gitattributes +34 -0
  2. LICENSE +201 -0
  3. README.md +14 -0
  4. app.py +1573 -0
  5. client_test.py +121 -0
  6. finetune.py +930 -0
  7. h2o-logo.svg +1 -0
  8. prompter.py +106 -0
  9. requirements.txt +48 -0
  10. stopping.py +139 -0
  11. utils.py +89 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: H2ogpt Chatbot
3
+ emoji: 📚
4
+ colorFrom: yellow
5
+ colorTo: yellow
6
+ sdk: gradio
7
+ sdk_version: 3.27.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ duplicated_from: h2oai/h2ogpt-chatbot
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,1573 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import inspect
3
+ import sys
4
+ import os
5
+ import traceback
6
+ import typing
7
+
8
+ from utils import set_seed, flatten_list, clear_torch_cache, system_info_print
9
+
10
+ SEED = 1236
11
+ set_seed(SEED)
12
+
13
+ os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
14
+ from typing import Union
15
+ import numpy as np
16
+ import pandas as pd
17
+
18
+ import fire
19
+ import torch
20
+ from peft import PeftModel
21
+ from transformers import GenerationConfig, StoppingCriteriaList, AutoModel
22
+ from accelerate import init_empty_weights, infer_auto_device_map
23
+
24
+ from prompter import Prompter
25
+
26
+ from finetune import get_loaders, example_data_points, generate_prompt, get_githash, prompt_types_strings, \
27
+ human, bot, prompt_type_to_model_name, inv_prompt_type_to_model_lower
28
+ from stopping import CallbackToGenerator, Stream, StoppingCriteriaSub
29
+
30
+
31
+ def main(
32
+ load_8bit: bool = False,
33
+ load_half: bool = True,
34
+ infer_devices: bool = True,
35
+ base_model: str = '',
36
+ tokenizer_base_model: str = '',
37
+ lora_weights: str = "",
38
+ force_1_gpu: bool = True,
39
+
40
+ prompt_type: Union[int, str] = None,
41
+ # input to generation
42
+ temperature: float = None,
43
+ top_p: float = None,
44
+ top_k: int = None,
45
+ num_beams: int = None,
46
+ repetition_penalty: float = None,
47
+ num_return_sequences: int = None,
48
+ do_sample: bool = None,
49
+ max_new_tokens: int = None,
50
+ min_new_tokens: int = None,
51
+ early_stopping: Union[bool, str] = None,
52
+ max_time: float = None,
53
+
54
+ llama_type: bool = None,
55
+ debug: bool = False,
56
+ share: bool = True,
57
+ local_files_only: bool = False,
58
+ resume_download: bool = True,
59
+ use_auth_token: Union[str, bool] = False, # True requires CLI did huggingface-cli login before running
60
+
61
+ src_lang: str = "English",
62
+ tgt_lang: str = "Russian",
63
+
64
+ gradio: bool = True,
65
+ gradio_avoid_processing_markdown: bool = False,
66
+ chat: bool = True,
67
+ chat_history: int = 4096, # character length of chat context/history
68
+ stream_output: bool = True,
69
+ show_examples: bool = None,
70
+ verbose: bool = False,
71
+ h2ocolors: bool = True,
72
+ height: int = 400,
73
+ show_lora: bool = True,
74
+ # set to True to load --base_model after client logs in,
75
+ # to be able to free GPU memory when model is swapped
76
+ login_mode_if_model0: bool = False,
77
+
78
+ sanitize_user_prompt: bool = True,
79
+ sanitize_bot_response: bool = True,
80
+
81
+ extra_model_options: typing.List[str] = [],
82
+ extra_lora_options: typing.List[str] = [],
83
+
84
+ score_model: str = 'OpenAssistant/reward-model-deberta-v3-large-v2',
85
+ auto_score: bool = True,
86
+
87
+ eval_sharegpt_prompts_only: int = 0,
88
+ eval_sharegpt_prompts_only_seed: int = 1234,
89
+ eval_sharegpt_as_output: bool = False,
90
+ ):
91
+ # allow set token directly
92
+ use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
93
+ # override share if in spaces
94
+ if os.environ.get("HUGGINGFACE_SPACES"):
95
+ share = False
96
+ base_model = 'h2oai/h2ogpt-oasst1-512-12b'
97
+ load_8bit = True
98
+ temperature = 0.7
99
+ top_p = 1
100
+ top_k = 100
101
+ do_sample = True
102
+
103
+ # get defaults
104
+ model_lower = base_model.lower()
105
+ if not gradio:
106
+ # force, else not single response like want to look at
107
+ stream_output = False
108
+ # else prompt removal can mess up output
109
+ chat = False
110
+
111
+ placeholder_instruction, placeholder_input, \
112
+ stream_output, show_examples, \
113
+ prompt_type, temperature, top_p, top_k, num_beams, \
114
+ max_new_tokens, min_new_tokens, early_stopping, max_time, \
115
+ repetition_penalty, num_return_sequences, \
116
+ do_sample, \
117
+ src_lang, tgt_lang, \
118
+ examples, \
119
+ task_info = \
120
+ get_generate_params(model_lower, chat,
121
+ stream_output, show_examples,
122
+ prompt_type, temperature, top_p, top_k, num_beams,
123
+ max_new_tokens, min_new_tokens, early_stopping, max_time,
124
+ repetition_penalty, num_return_sequences,
125
+ do_sample,
126
+ )
127
+
128
+ if not gradio:
129
+ if eval_sharegpt_prompts_only > 0:
130
+ # override default examples with shareGPT ones for human-level eval purposes only
131
+ filename = 'ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json'
132
+ if not os.path.isfile(filename):
133
+ os.system('wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/%s' % filename)
134
+ import json
135
+ data = json.load(open(filename, 'rt'))
136
+ # focus on data that starts with human, else likely chopped from other data
137
+ turn_start = 0 # odd in general
138
+ data = [x for x in data if len(x['conversations']) > turn_start + 1 and
139
+ x['conversations'][turn_start]['from'] == 'human' and
140
+ x['conversations'][turn_start + 1]['from'] == 'gpt']
141
+ np.random.seed(eval_sharegpt_prompts_only_seed)
142
+ example1 = examples[-1] # pick reference example
143
+ examples = []
144
+ responses = []
145
+ for i in list(np.random.randint(0, len(data), size=eval_sharegpt_prompts_only)):
146
+ assert data[i]['conversations'][turn_start]['from'] == 'human'
147
+ instruction = data[i]['conversations'][turn_start]['value']
148
+ assert data[i]['conversations'][turn_start + 1]['from'] == 'gpt'
149
+ output = data[i]['conversations'][turn_start + 1]['value']
150
+ examplenew = example1.copy()
151
+ examplenew[0] = instruction
152
+ examplenew[1] = '' # no input
153
+ examplenew[2] = '' # no context
154
+ examples.append(examplenew)
155
+ responses.append(output)
156
+
157
+ with torch.device("cuda"):
158
+ # ensure was set right above before examples generated
159
+ assert not stream_output, "stream_output=True does not make sense with example loop"
160
+ import time
161
+ from functools import partial
162
+
163
+ # get score model
164
+ smodel, stokenizer, sdevice = get_score_model(**locals())
165
+
166
+ if not eval_sharegpt_as_output:
167
+ model, tokenizer, device = get_model(**locals())
168
+ model_state = [model, tokenizer, device, base_model]
169
+ fun = partial(evaluate, model_state, debug=debug, chat=chat)
170
+ else:
171
+ assert eval_sharegpt_prompts_only > 0
172
+
173
+ def get_response(*args, exi=0):
174
+ # assumes same ordering of examples and responses
175
+ yield responses[exi]
176
+
177
+ fun = get_response
178
+ t0 = time.time()
179
+ score_dump = []
180
+ num_examples = len(examples)
181
+
182
+ import matplotlib.pyplot as plt
183
+
184
+ for exi, ex in enumerate(examples):
185
+ clear_torch_cache()
186
+ print("")
187
+ print("START" + "=" * 100)
188
+ print("Question: %s %s" % (ex[0], ('input=%s' % ex[1] if ex[1] else '')))
189
+ print("-" * 105)
190
+ # fun yields as generator, so have to iterate over it
191
+ # Also means likely do NOT want --stream_output=True, else would show all generations
192
+ for res in fun(*tuple(ex), exi=exi):
193
+ print(res)
194
+ if smodel:
195
+ score_with_prompt = False
196
+ if score_with_prompt:
197
+ data_point = dict(instruction=ex[0], input=ex[1])
198
+ prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
199
+ prompt = prompter.generate_prompt(data_point)
200
+ else:
201
+ # just raw input and output
202
+ assert ex[1] in [None, ''] # should be no iinput
203
+ assert ex[2] in [None, ''] # should be no context
204
+ prompt = ex[0]
205
+ cutoff_len = 768 if os.environ.get("HUGGINGFACE_SPACES") else 2048
206
+ inputs = stokenizer(prompt, res,
207
+ return_tensors="pt",
208
+ truncation=True,
209
+ max_length=cutoff_len)
210
+ try:
211
+ score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
212
+ except torch.cuda.OutOfMemoryError as e:
213
+ print("GPU OOM: question: %s answer: %s exception: %s" % (prompt, res, str(e)), flush=True)
214
+ traceback.print_exc()
215
+ score = 0.0
216
+ clear_torch_cache()
217
+ except RuntimeError as e:
218
+ if 'Expected all tensors to be on the same device' in str(
219
+ e) or 'expected scalar type Half but found Float' in str(e):
220
+ print("GPU error: question: %s answer: %s exception: %s" % (prompt, res, str(e)),
221
+ flush=True)
222
+ traceback.print_exc()
223
+ score = 0.0
224
+ clear_torch_cache()
225
+ else:
226
+ raise
227
+ print("SCORE %s: %s" % (exi, score), flush=True)
228
+ score_dump.append(ex + [prompt, res, score])
229
+ # dump every score in case abort
230
+ scoring_path = 'scoring'
231
+ os.makedirs(scoring_path, exist_ok=True)
232
+ if eval_sharegpt_as_output:
233
+ used_base_model = 'gpt35'
234
+ used_lora_weights = ''
235
+ else:
236
+ used_base_model = str(base_model.split('/')[-1])
237
+ used_lora_weights = str(lora_weights.split('/')[-1])
238
+ df_scores = pd.DataFrame(score_dump, columns=eval_func_param_names + ['prompt', 'response', 'score'])
239
+ filename = "df_scores_%s_%s_%s_%s_%s_%s.parquet" % (num_examples, eval_sharegpt_prompts_only,
240
+ eval_sharegpt_prompts_only_seed,
241
+ eval_sharegpt_as_output,
242
+ used_base_model,
243
+ used_lora_weights)
244
+ filename = os.path.join(scoring_path, filename)
245
+ df_scores.to_parquet(filename, index=False)
246
+ # plot histogram so far
247
+ plt.figure(figsize=(10, 10))
248
+ plt.hist(df_scores['score'], bins=20)
249
+ score_avg = np.mean(df_scores['score'])
250
+ score_median = np.median(df_scores['score'])
251
+ plt.title("Score avg: %s median: %s" % (score_avg, score_median))
252
+ plt.savefig(filename.replace('.parquet', '.png'))
253
+ plt.close()
254
+
255
+ print("END" + "=" * 102)
256
+ print("")
257
+ t2 = time.time()
258
+ print("Time taken so far: %.4f about %.4g per example" % (t2 - t0, (t2 - t0) / (1 + exi)))
259
+ t1 = time.time()
260
+ print("Total time taken: %.4f about %.4g per example" % (t1 - t0, (t1 - t0) / num_examples))
261
+ return
262
+ if gradio:
263
+ go_gradio(**locals())
264
+
265
+
266
+ def get_device():
267
+ if torch.cuda.is_available():
268
+ device = "cuda"
269
+ else:
270
+ raise RuntimeError("only cuda supported")
271
+
272
+ return device
273
+
274
+
275
+ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type, force_1_gpu=True, use_auth_token=False):
276
+ """
277
+ Ensure model gets on correct device
278
+ :param base_model:
279
+ :param model_loader:
280
+ :param load_half:
281
+ :param model_kwargs:
282
+ :param reward_type:
283
+ :return:
284
+ """
285
+ with init_empty_weights():
286
+ from transformers import AutoConfig
287
+ config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token)
288
+ model = AutoModel.from_config(
289
+ config,
290
+ )
291
+
292
+ # NOTE: Can specify max_memory={0: max_mem, 1: max_mem}, to shard model
293
+ # NOTE: Some models require avoiding sharding some layers,
294
+ # then would pass no_split_module_classes and give list of those layers.
295
+ device_map = infer_auto_device_map(
296
+ model,
297
+ dtype=torch.float16 if load_half else torch.float32,
298
+ )
299
+ if hasattr(model, 'model'):
300
+ device_map_model = infer_auto_device_map(
301
+ model.model,
302
+ dtype=torch.float16 if load_half else torch.float32,
303
+ )
304
+ device_map.update(device_map_model)
305
+ print('device_map: %s' % device_map, flush=True)
306
+
307
+ if force_1_gpu:
308
+ # FIXME: If really distributes model, tend to get things like: ValueError: gpt_neox.embed_in.weight doesn't have any device set.
309
+ # So avoid for now, just put on first GPU, unless score_model, put on last
310
+ n_gpus = torch.cuda.device_count()
311
+ if reward_type:
312
+ device_map = {'': n_gpus - 1}
313
+ else:
314
+ device_map = {'': 0}
315
+
316
+ load_in_8bit = model_kwargs.get('load_in_8bit', False)
317
+ model_kwargs['device_map'] = device_map
318
+
319
+ if load_in_8bit or not load_half:
320
+ model = model_loader.from_pretrained(
321
+ base_model,
322
+ **model_kwargs,
323
+ )
324
+ else:
325
+ model = model_loader.from_pretrained(
326
+ base_model,
327
+ **model_kwargs,
328
+ ).half()
329
+ return model
330
+
331
+
332
+ def get_model(
333
+ load_8bit: bool = False,
334
+ load_half: bool = True,
335
+ infer_devices: bool = True,
336
+ base_model: str = '',
337
+ tokenizer_base_model: str = '',
338
+ lora_weights: str = "",
339
+ force_1_gpu: bool = False,
340
+
341
+ llama_type: bool = None,
342
+ reward_type: bool = None,
343
+ local_files_only: bool = False,
344
+ resume_download: bool = True,
345
+ use_auth_token: Union[str, bool] = False,
346
+ compile: bool = True,
347
+ **kwargs,
348
+ ):
349
+ """
350
+
351
+ :param load_8bit: load model in 8-bit, not supported by all models
352
+ :param load_half: load model in 16-bit
353
+ :param infer_devices: Use torch infer of optimal placement of layers on devices (for non-lora case)
354
+ For non-LORA case, False will spread shards across multiple GPUs, but this can lead to cuda:x cuda:y mismatches
355
+ So it is not the default
356
+ :param base_model: name/path of base model
357
+ :param tokenizer_base_model: name/path of tokenizer
358
+ :param lora_weights: name/path
359
+ :param force_1_gpu:
360
+ :param llama_type: whether LLaMa type model
361
+ :param reward_type: reward type model for sequence classification
362
+ :param local_files_only: use local files instead of from HF
363
+ :param resume_download: resume downloads from HF
364
+ :param use_auth_token: assumes user did on CLI `huggingface-cli login` to access private repo
365
+ :parm compile: whether to compile torch model
366
+ :param kwargs:
367
+ :return:
368
+ """
369
+ print("Get %s model" % base_model, flush=True)
370
+ if lora_weights is not None and lora_weights.strip():
371
+ print("Get %s lora weights" % lora_weights, flush=True)
372
+ device = get_device()
373
+
374
+ if 'gpt2' in base_model.lower():
375
+ # RuntimeError: where expected condition to be a boolean tensor, but got a tensor with dtype Half
376
+ load_8bit = False
377
+
378
+ assert base_model.strip(), (
379
+ "Please choose a base model with --base_model (CLI) or in Models Tab (gradio)"
380
+ )
381
+ llama_type = llama_type or "llama" in base_model
382
+ model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=reward_type)
383
+ if not tokenizer_base_model:
384
+ tokenizer_base_model = base_model
385
+
386
+ if tokenizer_loader is not None and not isinstance(tokenizer_loader, str):
387
+ tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
388
+ local_files_only=local_files_only,
389
+ resume_download=resume_download,
390
+ use_auth_token=use_auth_token,
391
+ )
392
+ else:
393
+ tokenizer = tokenizer_loader
394
+
395
+ if isinstance(tokenizer, str):
396
+ # already a pipeline, tokenizer_loader is string for task
397
+ model = model_loader(tokenizer,
398
+ model=base_model,
399
+ device=0 if device == "cuda" else -1,
400
+ torch_dtype=torch.float16)
401
+ else:
402
+ assert device == "cuda", "Unsupported device %s" % device
403
+ model_kwargs = dict(local_files_only=local_files_only,
404
+ torch_dtype=torch.float16,
405
+ resume_download=resume_download,
406
+ use_auth_token=use_auth_token)
407
+ if 'mbart-' not in base_model.lower():
408
+ model_kwargs.update(dict(load_in_8bit=load_8bit,
409
+ device_map={"": 0} if load_8bit else "auto",
410
+ ))
411
+ if 'OpenAssistant/reward-model'.lower() in base_model.lower():
412
+ # could put on other GPUs
413
+ model_kwargs['device_map'] = {"": 0}
414
+ model_kwargs.pop('torch_dtype', None)
415
+
416
+ if not lora_weights:
417
+ with torch.device("cuda"):
418
+ if infer_devices:
419
+ model = get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
420
+ force_1_gpu=force_1_gpu, use_auth_token=use_auth_token)
421
+ else:
422
+ if load_half and not load_8bit:
423
+ model = model_loader.from_pretrained(
424
+ base_model,
425
+ **model_kwargs).half()
426
+ else:
427
+ model = model_loader.from_pretrained(
428
+ base_model,
429
+ **model_kwargs)
430
+ elif load_8bit:
431
+ model = model_loader.from_pretrained(
432
+ base_model,
433
+ **model_kwargs
434
+ )
435
+ model = PeftModel.from_pretrained(
436
+ model,
437
+ lora_weights,
438
+ torch_dtype=torch.float16,
439
+ local_files_only=local_files_only,
440
+ resume_download=resume_download,
441
+ use_auth_token=use_auth_token,
442
+ device_map={"": 0}, # seems to be required
443
+ )
444
+ else:
445
+ with torch.device("cuda"):
446
+ model = model_loader.from_pretrained(
447
+ base_model,
448
+ **model_kwargs
449
+ )
450
+ model = PeftModel.from_pretrained(
451
+ model,
452
+ lora_weights,
453
+ torch_dtype=torch.float16,
454
+ local_files_only=local_files_only,
455
+ resume_download=resume_download,
456
+ use_auth_token=use_auth_token,
457
+ device_map="auto",
458
+ )
459
+ if load_half:
460
+ model.half()
461
+
462
+ # unwind broken decapoda-research config
463
+ if llama_type:
464
+ model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
465
+ model.config.bos_token_id = 1
466
+ model.config.eos_token_id = 2
467
+ if 'gpt2' in base_model.lower():
468
+ # add special tokens that otherwise all share the same id
469
+ tokenizer.add_special_tokens({'bos_token': '<bos>',
470
+ 'eos_token': '<eos>',
471
+ 'pad_token': '<pad>'})
472
+
473
+ if not isinstance(tokenizer, str):
474
+ model.eval()
475
+ if torch.__version__ >= "2" and sys.platform != "win32" and compile:
476
+ model = torch.compile(model)
477
+
478
+ return model, tokenizer, device
479
+
480
+
481
+ def get_score_model(**kwargs):
482
+ # score model
483
+ if kwargs.get('score_model') is not None and kwargs.get('score_model').strip():
484
+ score_all_kwargs = kwargs.copy()
485
+ score_all_kwargs['load_8bit'] = False
486
+ score_all_kwargs['load_half'] = False
487
+ score_all_kwargs['base_model'] = kwargs.get('score_model').strip()
488
+ score_all_kwargs['tokenizer_base_model'] = ''
489
+ score_all_kwargs['lora_weights'] = ''
490
+ score_all_kwargs['llama_type'] = False
491
+ score_all_kwargs['compile'] = False
492
+ smodel, stokenizer, sdevice = get_model(**score_all_kwargs)
493
+ else:
494
+ smodel, stokenizer, sdevice = None, None, None
495
+ return smodel, stokenizer, sdevice
496
+
497
+
498
+ def go_gradio(**kwargs):
499
+
500
+ # get default model
501
+ all_kwargs = kwargs.copy()
502
+ all_kwargs.update(locals())
503
+ if kwargs.get('base_model') and not kwargs['login_mode_if_model0']:
504
+ model0, tokenizer0, device = get_model(**all_kwargs)
505
+ else:
506
+ # if empty model, then don't load anything, just get gradio up
507
+ model0, tokenizer0, device = None, None, None
508
+ model_state0 = [model0, tokenizer0, device, kwargs['base_model']]
509
+
510
+ # get score model
511
+ smodel, stokenizer, sdevice = get_score_model(**all_kwargs)
512
+
513
+ if 'mbart-' in kwargs['model_lower']:
514
+ instruction_label = "Text to translate"
515
+ else:
516
+ instruction_label = "Instruction"
517
+ if kwargs['chat']:
518
+ instruction_label = "You (Shift-Enter or push Submit to send message)"
519
+
520
+ title = 'h2oGPT'
521
+ if kwargs['verbose']:
522
+ description = f"""Model {kwargs['base_model']} Instruct dataset.
523
+ For more information, visit [the project's website](https://github.com/h2oai/h2ogpt).
524
+ Command: {str(' '.join(sys.argv))}
525
+ Hash: {get_githash()}
526
+ """
527
+ else:
528
+ description = "For more information, visit [the project's website](https://github.com/h2oai/h2ogpt).<br>"
529
+ if os.environ.get("HUGGINGFACE_SPACES"):
530
+ description += """<p><b> DISCLAIMERS: </b><ul><i><li>The data used to train this model include The Pile and other sources. These may contain objectionable content, so the model may reproduce that material. Use application and responses at own risk.</i></li>"""
531
+ if kwargs['load_8bit']:
532
+ description += """<i><li> Model is loaded in 8-bit and HF spaces version has other limitations in order to fit on HF GPUs, so UX can be worse than native app.</i></li>"""
533
+ description += """<i><li>Model loading and unloading disabled on HF SPACES to avoid GPU OOM for multi-user environment.</i></li></ul></p>"""
534
+
535
+ if kwargs['verbose']:
536
+ task_info_md = f"""
537
+ ### Task: {kwargs['task_info']}"""
538
+ else:
539
+ task_info_md = ''
540
+
541
+ css_code = """footer {visibility: hidden}
542
+ body{background-image:url("https://h2o.ai/content/experience-fragments/h2o/us/en/site/header/master/_jcr_content/root/container/header_copy/logo.coreimg.svg/1678976605175/h2o-logo.svg");}}"""
543
+
544
+ from gradio.themes.utils import colors, fonts, sizes
545
+ if kwargs['h2ocolors']:
546
+ colors_dict = dict(primary_hue=colors.yellow,
547
+ secondary_hue=colors.yellow,
548
+ neutral_hue=colors.gray,
549
+ spacing_size=sizes.spacing_md,
550
+ radius_size=sizes.radius_md,
551
+ text_size=sizes.text_md,
552
+ )
553
+ else:
554
+ colors_dict = dict(primary_hue=colors.indigo,
555
+ secondary_hue=colors.indigo,
556
+ neutral_hue=colors.gray,
557
+ spacing_size=sizes.spacing_md,
558
+ radius_size=sizes.radius_md,
559
+ text_size=sizes.text_md,
560
+ )
561
+
562
+ import gradio as gr
563
+
564
+ if kwargs['gradio_avoid_processing_markdown']:
565
+ from gradio_client import utils as client_utils
566
+ from gradio.components import Chatbot
567
+
568
+ # gradio has issue with taking too long to process input/output for markdown etc.
569
+ # Avoid for now, allow raw html to render, good enough for chatbot.
570
+ def _postprocess_chat_messages(self, chat_message: str):
571
+ if chat_message is None:
572
+ return None
573
+ elif isinstance(chat_message, (tuple, list)):
574
+ filepath = chat_message[0]
575
+ mime_type = client_utils.get_mimetype(filepath)
576
+ filepath = self.make_temp_copy_if_needed(filepath)
577
+ return {
578
+ "name": filepath,
579
+ "mime_type": mime_type,
580
+ "alt_text": chat_message[1] if len(chat_message) > 1 else None,
581
+ "data": None, # These last two fields are filled in by the frontend
582
+ "is_file": True,
583
+ }
584
+ elif isinstance(chat_message, str):
585
+ return chat_message
586
+ else:
587
+ raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
588
+ Chatbot._postprocess_chat_messages = _postprocess_chat_messages
589
+
590
+ demo = gr.Blocks(theme=gr.themes.Soft(**colors_dict), css=css_code, title="h2oGPT", analytics_enabled=False)
591
+ callback = gr.CSVLogger()
592
+ # css_code = 'body{background-image:url("https://h2o.ai/content/experience-fragments/h2o/us/en/site/header/master/_jcr_content/root/container/header_copy/logo.coreimg.svg/1678976605175/h2o-logo.svg");}'
593
+ # demo = gr.Blocks(theme='gstaff/xkcd', css=css_code)
594
+
595
+ model_options = flatten_list(list(prompt_type_to_model_name.values())) + kwargs['extra_model_options']
596
+ if kwargs['base_model'].strip() not in model_options:
597
+ lora_options = [kwargs['base_model'].strip()] + model_options
598
+ lora_options = kwargs['extra_lora_options']
599
+ if kwargs['lora_weights'].strip() not in lora_options:
600
+ lora_options = [kwargs['lora_weights'].strip()] + lora_options
601
+ # always add in no lora case
602
+ # add fake space so doesn't go away in gradio dropdown
603
+ lora_options = [' '] + kwargs['extra_lora_options']
604
+
605
+ output_label0 = f'h2oGPT [Model: {kwargs.get("base_model")}]' if kwargs.get('base_model') else 'h2oGPT [ !!! Please Load Model in Models Tab !!! ]'
606
+
607
+ with demo:
608
+ # avoid actual model/tokenizer here or anything that would be bad to deepcopy
609
+ # https://github.com/gradio-app/gradio/issues/3558
610
+ model_state = gr.State(['model', 'tokenizer', device, kwargs['base_model']])
611
+ model_options_state = gr.State([model_options])
612
+ lora_options_state = gr.State([lora_options])
613
+ gr.Markdown(
614
+ f"""
615
+ <h1 align="center"> {title}</h1>
616
+
617
+ {description}
618
+ {task_info_md}
619
+ """)
620
+ if os.environ.get("HUGGINGFACE_SPACES"):
621
+ gr.HTML('''<center><a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate this Space to skip the queue and run in a private space</center>''')
622
+
623
+ # go button visible if
624
+ base_wanted = bool(kwargs['base_model']) and kwargs['login_mode_if_model0']
625
+ go_btn = gr.Button(value="LOGIN", visible=base_wanted, variant="primary")
626
+ normal_block = gr.Row(visible=not base_wanted)
627
+ with normal_block:
628
+ with gr.Tabs():
629
+ with gr.Row():
630
+ if not kwargs['chat']:
631
+ with gr.Column():
632
+ instruction = gr.Textbox(
633
+ lines=4, label=instruction_label,
634
+ placeholder=kwargs['placeholder_instruction'],
635
+ )
636
+ iinput = gr.Textbox(lines=4, label="Input",
637
+ placeholder=kwargs['placeholder_input'])
638
+ flag_btn = gr.Button("Flag")
639
+ if kwargs['score_model']:
640
+ if not kwargs['auto_score']:
641
+ with gr.Column():
642
+ score_btn = gr.Button("Score last prompt & response")
643
+ score_text = gr.Textbox("Response Score: NA", show_label=False)
644
+ else:
645
+ score_text = gr.Textbox("Response Score: NA", show_label=False)
646
+ with gr.Column():
647
+ if kwargs['chat']:
648
+ text_output = gr.Chatbot(label=output_label0).style(height=kwargs['height'] or 400)
649
+ with gr.Row():
650
+ with gr.Column(scale=50):
651
+ instruction = gr.Textbox(
652
+ lines=4, label=instruction_label,
653
+ placeholder=kwargs['placeholder_instruction'],
654
+ )
655
+ with gr.Row(): # .style(equal_height=False, equal_width=False):
656
+ submit = gr.Button(value='Submit').style(full_width=False, size='sm')
657
+ stop_btn = gr.Button(value="Stop").style(full_width=False, size='sm')
658
+ with gr.Row():
659
+ clear = gr.Button("New Conversation")
660
+ flag_btn = gr.Button("Flag")
661
+ if kwargs['score_model']:
662
+ if not kwargs['auto_score']:
663
+ with gr.Column():
664
+ score_btn = gr.Button("Score last prompt & response").style(full_width=False, size='sm')
665
+ score_text = gr.Textbox("Response Score: NA", show_label=False)
666
+ else:
667
+ score_text = gr.Textbox("Response Score: NA", show_label=False)
668
+ retry = gr.Button("Regenerate")
669
+ undo = gr.Button("Undo")
670
+ else:
671
+ text_output = gr.Textbox(lines=5, label=output_label0)
672
+ with gr.TabItem("Input/Output"):
673
+ with gr.Row():
674
+ if 'mbart-' in kwargs['model_lower']:
675
+ src_lang = gr.Dropdown(list(languages_covered().keys()),
676
+ value=kwargs['src_lang'],
677
+ label="Input Language")
678
+ tgt_lang = gr.Dropdown(list(languages_covered().keys()),
679
+ value=kwargs['tgt_lang'],
680
+ label="Output Language")
681
+ with gr.TabItem("Expert"):
682
+ with gr.Row():
683
+ with gr.Column():
684
+ stream_output = gr.components.Checkbox(label="Stream output",
685
+ value=kwargs['stream_output'])
686
+ prompt_type = gr.Dropdown(prompt_types_strings,
687
+ value=kwargs['prompt_type'], label="Prompt Type",
688
+ visible=not os.environ.get("HUGGINGFACE_SPACES"))
689
+ temperature = gr.Slider(minimum=0, maximum=3,
690
+ value=kwargs['temperature'],
691
+ label="Temperature",
692
+ info="Lower is deterministic (but may lead to repeats), Higher more creative (but may lead to hallucinations)")
693
+ top_p = gr.Slider(minimum=0, maximum=1,
694
+ value=kwargs['top_p'], label="Top p",
695
+ info="Cumulative probability of tokens to sample from")
696
+ top_k = gr.Slider(
697
+ minimum=0, maximum=100, step=1,
698
+ value=kwargs['top_k'], label="Top k",
699
+ info='Num. tokens to sample from'
700
+ )
701
+ max_beams = 8 if not os.environ.get("HUGGINGFACE_SPACES") else 2
702
+ num_beams = gr.Slider(minimum=1, maximum=max_beams, step=1,
703
+ value=min(max_beams, kwargs['num_beams']), label="Beams",
704
+ info="Number of searches for optimal overall probability. "
705
+ "Uses more GPU memory/compute")
706
+ max_max_new_tokens = 2048 if not os.environ.get("HUGGINGFACE_SPACES") else kwargs['max_new_tokens']
707
+ max_new_tokens = gr.Slider(
708
+ minimum=1, maximum=max_max_new_tokens, step=1,
709
+ value=min(max_max_new_tokens, kwargs['max_new_tokens']), label="Max output length",
710
+ )
711
+ min_new_tokens = gr.Slider(
712
+ minimum=0, maximum=max_max_new_tokens, step=1,
713
+ value=min(max_max_new_tokens, kwargs['min_new_tokens']), label="Min output length",
714
+ )
715
+ early_stopping = gr.Checkbox(label="EarlyStopping", info="Stop early in beam search",
716
+ value=kwargs['early_stopping'])
717
+ max_max_time = 60 * 5 if not os.environ.get("HUGGINGFACE_SPACES") else 60
718
+ max_time = gr.Slider(minimum=0, maximum=max_max_time, step=1,
719
+ value=min(max_max_time, kwargs['max_time']), label="Max. time",
720
+ info="Max. time to search optimal output.")
721
+ repetition_penalty = gr.Slider(minimum=0.01, maximum=3.0,
722
+ value=kwargs['repetition_penalty'],
723
+ label="Repetition Penalty")
724
+ num_return_sequences = gr.Slider(minimum=1, maximum=10, step=1,
725
+ value=kwargs['num_return_sequences'],
726
+ label="Number Returns", info="Must be <= num_beams",
727
+ visible=not os.environ.get("HUGGINGFACE_SPACES"))
728
+ do_sample = gr.Checkbox(label="Sample", info="Sample, for diverse output(s)",
729
+ value=kwargs['do_sample'])
730
+ if kwargs['chat']:
731
+ iinput = gr.Textbox(lines=4, label="Input",
732
+ placeholder=kwargs['placeholder_input'],
733
+ visible=not os.environ.get("HUGGINGFACE_SPACES"))
734
+ # nominally empty for chat mode
735
+ context = gr.Textbox(lines=1, label="Context",
736
+ info="Ignored in chat mode.",
737
+ visible=not os.environ.get("HUGGINGFACE_SPACES"))
738
+
739
+ with gr.TabItem("Models"):
740
+ with gr.Row():
741
+ with gr.Column():
742
+ with gr.Row(scale=1):
743
+ with gr.Column(scale=50):
744
+ model_choice = gr.Dropdown(model_options_state.value[0], label="Choose Model", value=kwargs['base_model'])
745
+ lora_choice = gr.Dropdown(lora_options_state.value[0], label="Choose LORA", value=kwargs['lora_weights'], visible=kwargs['show_lora'])
746
+ with gr.Column(scale=1):
747
+ load_msg = "Load Model/LORA" if not os.environ.get("HUGGINGFACE_SPACES") \
748
+ else "LOAD DISABLED ON HF SPACES"
749
+ load_model_button = gr.Button(load_msg)
750
+ model_used = gr.Textbox(label="Current Model", value=kwargs['base_model'])
751
+ lora_used = gr.Textbox(label="Current LORA", value=kwargs['lora_weights'], visible=kwargs['show_lora'])
752
+ with gr.Row(scale=1):
753
+ with gr.Column(scale=50):
754
+ new_model = gr.Textbox(label="New Model HF name/path")
755
+ new_lora = gr.Textbox(label="New LORA HF name/path", visible=kwargs['show_lora'])
756
+ with gr.Column(scale=1):
757
+ add_model_button = gr.Button("Add new model name")
758
+ add_lora_button = gr.Button("Add new LORA name", visible=kwargs['show_lora'])
759
+ with gr.TabItem("System"):
760
+ with gr.Row():
761
+ with gr.Column():
762
+ system_text = gr.Textbox(label='System Info')
763
+ system_btn = gr.Button(value='Get System Info')
764
+
765
+
766
+ inputs_list = get_inputs_list(locals(), kwargs['model_lower'])
767
+ from functools import partial
768
+ all_kwargs = kwargs.copy()
769
+ all_kwargs.update(locals())
770
+ kwargs_evaluate = {k: v for k, v in all_kwargs.items() if k in inputs_kwargs_list}
771
+ fun = partial(evaluate,
772
+ **kwargs_evaluate)
773
+
774
+ dark_mode_btn = gr.Button("Dark Mode", variant="primary").style(
775
+ size="sm",
776
+ )
777
+ dark_mode_btn.click(
778
+ None,
779
+ None,
780
+ None,
781
+ _js="""() => {
782
+ if (document.querySelectorAll('.dark').length) {
783
+ document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
784
+ } else {
785
+ document.querySelector('body').classList.add('dark');
786
+ }
787
+ }""",
788
+ api_name="dark",
789
+ )
790
+ if not kwargs['chat']:
791
+ submit = gr.Button("Submit")
792
+ submit_event = submit.click(fun, inputs=[model_state] + inputs_list, outputs=text_output, api_name='submit')
793
+
794
+ # examples after submit or any other buttons for chat or no chat
795
+ if kwargs['examples'] is not None and kwargs['show_examples']:
796
+ gr.Examples(examples=kwargs['examples'], inputs=inputs_list)
797
+
798
+ # Score
799
+ def score_last_response(*args):
800
+ """ Similar to user() """
801
+ args_list = list(args)
802
+ history = args_list[-1]
803
+ if history is None:
804
+ print("Bad history in scoring last response, fix for now", flush=True)
805
+ history = []
806
+ if smodel is not None and \
807
+ stokenizer is not None and \
808
+ sdevice is not None and \
809
+ history is not None and len(history) > 0 and \
810
+ history[-1] is not None and \
811
+ len(history[-1]) >= 2:
812
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
813
+
814
+ max_length_tokenize = 512 if os.environ.get("HUGGINGFACE_SPACES") else 2048
815
+ cutoff_len = max_length_tokenize*4 # restrict deberta related to max for LLM
816
+
817
+ question = history[-1][0]
818
+ question = question[-cutoff_len:]
819
+
820
+ answer = history[-1][1]
821
+ answer = answer[-cutoff_len:]
822
+
823
+ inputs = stokenizer(question, answer,
824
+ return_tensors="pt",
825
+ truncation=True,
826
+ max_length=max_length_tokenize).to(smodel.device)
827
+ try:
828
+ score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
829
+ except torch.cuda.OutOfMemoryError as e:
830
+ print("GPU OOM: question: %s answer: %s exception: %s" % (question, answer, str(e)), flush=True)
831
+ del inputs
832
+ traceback.print_exc()
833
+ clear_torch_cache()
834
+ return 'Response Score: GPU OOM'
835
+ except RuntimeError as e:
836
+ if 'Expected all tensors to be on the same device' in str(e) or 'expected scalar type Half but found Float' in str(e):
837
+ print("GPU Error: question: %s answer: %s exception: %s" % (question, answer, str(e)), flush=True)
838
+ traceback.print_exc()
839
+ clear_torch_cache()
840
+ return 'Response Score: GPU Error'
841
+ else:
842
+ raise
843
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
844
+ return 'Response Score: {:.1%}'.format(score)
845
+ else:
846
+ return 'Response Score: NA'
847
+
848
+ if kwargs['score_model']:
849
+ score_args = dict(fn=score_last_response,
850
+ inputs=inputs_list + [text_output],
851
+ outputs=[score_text],
852
+ )
853
+ if not kwargs['auto_score']:
854
+ score_event = score_btn.click(**score_args, queue=stream_output, api_name='score')
855
+
856
+ if kwargs['chat']:
857
+ def user(*args, undo=False, sanitize_user_prompt=True):
858
+ args_list = list(args)
859
+ user_message = args_list[0]
860
+ input1 = args_list[1]
861
+ context1 = args_list[2]
862
+ if input1 and not user_message.endswith(':'):
863
+ user_message1 = user_message + ":" + input1
864
+ elif input1:
865
+ user_message1 = user_message + input1
866
+ else:
867
+ user_message1 = user_message
868
+ if sanitize_user_prompt:
869
+ from better_profanity import profanity
870
+ user_message1 = profanity.censor(user_message1)
871
+
872
+ history = args_list[-1]
873
+ if undo and history:
874
+ history.pop()
875
+ args_list = args_list[:-1]
876
+ if history is None:
877
+ print("Bad history, fix for now", flush=True)
878
+ history = []
879
+ if undo:
880
+ return "", history
881
+ else:
882
+ return "", history + [[user_message1, None]]
883
+
884
+ def bot(*args, retry=False):
885
+ args_list = list(args)
886
+ history = args_list[-1]
887
+ if retry and history:
888
+ history.pop()
889
+ if not history:
890
+ print("No history", flush=True)
891
+ return
892
+ instruction1 = history[-1][0]
893
+ context1 = ''
894
+ if kwargs['chat_history'] > 0:
895
+ prompt_type1 = args_list[prompt_type_arg_id]
896
+ context1 = ''
897
+ for histi in range(len(history) - 1):
898
+ data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
899
+ context1 += generate_prompt(data_point, prompt_type1, kwargs['chat'], reduced=True)[0].replace(
900
+ '<br>', '\n')
901
+ if not context1.endswith('\n'):
902
+ context1 += '\n'
903
+ if context1 and not context1.endswith('\n'):
904
+ context1 += '\n' # ensure if terminates abruptly, then human continues on next line
905
+ args_list[0] = instruction1
906
+ # only include desired chat history
907
+ args_list[2] = context1[-kwargs['chat_history']:]
908
+ model_state1 = args_list[-2]
909
+ args_list = args_list[:-2]
910
+ fun1 = partial(evaluate,
911
+ model_state1,
912
+ **kwargs_evaluate)
913
+ try:
914
+ for output in fun1(*tuple(args_list)):
915
+ bot_message = output
916
+ history[-1][1] = bot_message
917
+ yield history
918
+ except StopIteration:
919
+ yield history
920
+ except RuntimeError as e:
921
+ if "generator raised StopIteration" in str(e):
922
+ # assume last entry was bad, undo
923
+ history.pop()
924
+ yield history
925
+ raise
926
+ except Exception as e:
927
+ # put error into user input
928
+ history[-1][0] = "Exception: %s" % str(e)
929
+ yield history
930
+ raise
931
+ return
932
+
933
+ user_args = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt']),
934
+ inputs=inputs_list + [text_output],
935
+ outputs=[instruction, text_output],
936
+ )
937
+ bot_args = dict(fn=bot,
938
+ inputs=inputs_list + [model_state] + [text_output],
939
+ outputs=[text_output],
940
+ )
941
+ retry_bot_args = dict(fn=functools.partial(bot, retry=True),
942
+ inputs=inputs_list + [model_state] + [text_output],
943
+ outputs=[text_output],
944
+ )
945
+ undo_user_args = dict(fn=functools.partial(user, undo=True),
946
+ inputs=inputs_list + [text_output],
947
+ outputs=[instruction, text_output],
948
+ )
949
+
950
+ if kwargs['auto_score']:
951
+ submit_event = instruction.submit(**user_args, queue=stream_output, api_name='instruction').then(
952
+ **bot_args, api_name='instruction_bot',
953
+ ).then(**score_args, api_name='instruction_bot_score').then(clear_torch_cache)
954
+ submit_event2 = submit.click(**user_args, queue=stream_output, api_name='submit').then(
955
+ **bot_args, api_name='submit_bot',
956
+ ).then(**score_args, api_name='submit_bot_score').then(clear_torch_cache)
957
+ submit_event3 = retry.click(**user_args, queue=stream_output, api_name='retry').then(
958
+ **retry_bot_args, api_name='retry_bot',
959
+ ).then(**score_args, api_name='retry_bot_score').then(clear_torch_cache)
960
+ submit_event4 = undo.click(**undo_user_args, queue=stream_output, api_name='undo').then(**score_args, api_name='undo_score')
961
+ else:
962
+ submit_event = instruction.submit(**user_args, queue=stream_output, api_name='instruction').then(
963
+ **bot_args, api_name='instruction_bot',
964
+ ).then(clear_torch_cache)
965
+ submit_event2 = submit.click(**user_args, queue=stream_output, api_name='submit').then(
966
+ **bot_args, api_name='submit_bot',
967
+ ).then(clear_torch_cache)
968
+ submit_event3 = retry.click(**user_args, queue=stream_output, api_name='retry').then(
969
+ **retry_bot_args, api_name='retry_bot',
970
+ ).then(clear_torch_cache)
971
+ submit_event4 = undo.click(**undo_user_args, queue=stream_output, api_name='undo')
972
+ clear.click(lambda: None, None, text_output, queue=False, api_name='clear')
973
+
974
+ def load_model(model_name, lora_weights, model_state_old, prompt_type_old):
975
+ # ensure old model removed from GPU memory
976
+ if kwargs['debug']:
977
+ print("Pre-switch pre-del GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)
978
+
979
+ if isinstance(model_state_old[0], str) and model0 is not None:
980
+ # best can do, move model loaded at first to CPU
981
+ model0.cpu()
982
+
983
+ if model_state_old[0] is not None and not isinstance(model_state_old[0], str):
984
+ try:
985
+ model_state_old[0].cpu()
986
+ except Exception as e:
987
+ # sometimes hit NotImplementedError: Cannot copy out of meta tensor; no data!
988
+ print("Unable to put model on CPU: %s" % str(e), flush=True)
989
+ del model_state_old[0]
990
+ model_state_old[0] = None
991
+
992
+ if model_state_old[1] is not None and not isinstance(model_state_old[1], str):
993
+ del model_state_old[1]
994
+ model_state_old[1] = None
995
+
996
+ clear_torch_cache()
997
+ if kwargs['debug']:
998
+ print("Pre-switch post-del GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)
999
+ all_kwargs['base_model'] = model_name.strip()
1000
+ model_lower = model_name.strip().lower()
1001
+ if model_lower in inv_prompt_type_to_model_lower:
1002
+ prompt_type1 = inv_prompt_type_to_model_lower[model_lower]
1003
+ else:
1004
+ prompt_type1 = prompt_type_old
1005
+
1006
+ all_kwargs['lora_weights'] = lora_weights.strip()
1007
+ model1, tokenizer1, device1 = get_model(**all_kwargs)
1008
+ clear_torch_cache()
1009
+
1010
+ if kwargs['debug']:
1011
+ print("Post-switch GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)
1012
+ return {model_state: [model1, tokenizer1, device1, model_name],
1013
+ model_used: model_name,
1014
+ lora_used: lora_weights,
1015
+ prompt_type: prompt_type1}
1016
+
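clear_torch_cache() is imported from elsewhere in this repo (utils); a rough sketch of what such a helper is assumed to do, not necessarily the exact implementation:

import gc
import torch

def clear_torch_cache_sketch():
    # release Python references first, then ask CUDA to return cached blocks
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()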
1017
+ def dropdown_prompt_type_list(x):
1018
+ return gr.Dropdown.update(value=x)
1019
+
1020
+ def chatbot_list(x, model_used_in):
1021
+ return gr.Textbox.update(label=f'h2oGPT [Model: {model_used_in}]')
1022
+
1023
+ load_model_args = dict(fn=load_model,
1024
+ inputs=[model_choice, lora_choice, model_state, prompt_type],
1025
+ outputs=[model_state, model_used, lora_used, prompt_type])
1026
+ prompt_update_args = dict(fn=dropdown_prompt_type_list, inputs=prompt_type, outputs=prompt_type)
1027
+ chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output)
1028
+ if not os.environ.get("HUGGINGFACE_SPACES"):
1029
+ load_model_event = load_model_button.click(**load_model_args) \
1030
+ .then(**prompt_update_args) \
1031
+ .then(**chatbot_update_args) \
1032
+ .then(clear_torch_cache)
1033
+
1034
+ def dropdown_model_list(list0, x):
1035
+ new_state = [list0[0] + [x]]
1036
+ new_options = [*new_state[0]]
1037
+ return gr.Dropdown.update(value=x, choices=new_options), '', new_state
1038
+
1039
+ add_model_event = add_model_button.click(fn=dropdown_model_list,
1040
+ inputs=[model_options_state, new_model],
1041
+ outputs=[model_choice, new_model, model_options_state])
1042
+
1043
+ def dropdown_lora_list(list0, x):
1044
+ new_state = [list0[0] + [x]]
1045
+ new_options = [*new_state[0]]
1046
+ return gr.Dropdown.update(value=x, choices=new_options), '', new_state
1047
+
1048
+ add_lora_event = add_lora_button.click(fn=dropdown_lora_list,
1049
+ inputs=[lora_options_state, new_lora],
1050
+ outputs=[lora_choice, new_lora, lora_options_state])
1051
+
1052
+ go_btn.click(lambda: gr.update(visible=False), None, go_btn, api_name="go") \
1053
+ .then(lambda: gr.update(visible=True), None, normal_block) \
1054
+ .then(**load_model_args).then(**prompt_update_args)
1055
+
1056
+ # callback for logging flagged input/output
1057
+ callback.setup(inputs_list + [text_output], "flagged_data_points")
1058
+ flag_btn.click(lambda *args: callback.flag(args), inputs_list + [text_output], None, preprocess=False,
1059
+ api_name='flag')
1060
+
1061
+ def get_system_info():
1062
+ return gr.Textbox.update(value=system_info_print())
1063
+
1064
+ system_event = system_btn.click(get_system_info, outputs=system_text, api_name='system_info')
1065
+
1066
+ if kwargs['chat']:
1067
+
1068
+ # don't pass text_output, don't want to clear output, just stop it
1069
+ # FIXME: have to click once to stop the output and a second time to stop the GPUs from continuing
1070
+ stop_btn.click(lambda: None, None, None, cancels=[submit_event, submit_event2, submit_event3],
1071
+ queue=False, api_name='stop').then(clear_torch_cache)
1072
+
1073
+ demo.queue(concurrency_count=1)
1074
+ favicon_path = "h2o-logo.svg"
1075
+ demo.launch(share=kwargs['share'], server_name="0.0.0.0", show_error=True,
1076
+ favicon_path=favicon_path, prevent_thread_lock=True) # , enable_queue=True)
1077
+ print("Started GUI", flush=True)
1078
+ demo.block_thread()
1079
+
1080
+
1081
+ input_args_list = ['model_state']
1082
+ inputs_kwargs_list = ['debug', 'chat', 'hard_stop_list', 'sanitize_bot_response', 'model_state0']
1083
+
1084
+
1085
+ def get_inputs_list(inputs_dict, model_lower):
1086
+ inputs_list_names = list(inspect.signature(evaluate).parameters)
1087
+ inputs_list = []
1088
+ for k in inputs_list_names:
1089
+ if k == 'kwargs':
1090
+ continue
1091
+ if k in input_args_list + inputs_kwargs_list:
1092
+ # these are added via partial, not taken as input
1093
+ continue
1094
+ if 'mbart-' not in model_lower and k in ['src_lang', 'tgt_lang']:
1095
+ continue
1096
+ inputs_list.append(inputs_dict[k])
1097
+ return inputs_list
1098
+
1099
+
1100
+ # index of prompt_type in evaluate function, after model_state
1101
+ prompt_type_arg_id = 4
1102
+
1103
+ eval_func_param_names = ['instruction',
1104
+ 'iinput',
1105
+ 'context',
1106
+ 'stream_output',
1107
+ 'prompt_type',
1108
+ 'temperature',
1109
+ 'top_p',
1110
+ 'top_k',
1111
+ 'num_beams',
1112
+ 'max_new_tokens',
1113
+ 'min_new_tokens',
1114
+ 'early_stopping',
1115
+ 'max_time',
1116
+ 'repetition_penalty',
1117
+ 'num_return_sequences',
1118
+ 'do_sample',
1119
+ ]
1120
+
1121
+
1122
+ def evaluate(
1123
+ model_state,
1124
+ # START NOTE: Examples must have same order of parameters
1125
+ instruction,
1126
+ iinput,
1127
+ context,
1128
+ stream_output,
1129
+ prompt_type,
1130
+ temperature,
1131
+ top_p,
1132
+ top_k,
1133
+ num_beams,
1134
+ max_new_tokens,
1135
+ min_new_tokens,
1136
+ early_stopping,
1137
+ max_time,
1138
+ repetition_penalty,
1139
+ num_return_sequences,
1140
+ do_sample,
1141
+ # END NOTE: Examples must have same order of parameters
1142
+ src_lang=None,
1143
+ tgt_lang=None,
1144
+ debug=False,
1145
+ chat=False,
1146
+ hard_stop_list=None,
1147
+ sanitize_bot_response=True,
1148
+ model_state0=None,
1149
+ **kwargs,
1150
+ ):
1151
+ if debug:
1152
+ locals_dict = locals().copy()
1153
+ locals_dict.pop('model_state', None)
1154
+ print(locals_dict)
1155
+
1156
+ no_model_msg = "Please choose a base model with --base_model (CLI) or in Models Tab (gradio).\nThen start New Conversation"
1157
+
1158
+ if model_state is not None and len(model_state) == 4 and not isinstance(model_state[0], str):
1159
+ # try to free up the original model (i.e. the list was passed by reference)
1160
+ if model_state0 is not None and model_state0[0] is not None:
1161
+ model_state0[0].cpu()
1162
+ model_state0[0] = None
1163
+ # try to free up the original tokenizer (i.e. the list was passed by reference)
1164
+ if model_state0 is not None and model_state0[1] is not None:
1165
+ model_state0[1] = None
1166
+ clear_torch_cache()
1167
+ model, tokenizer, device, base_model = model_state
1168
+ elif model_state0 is not None and len(model_state0) == 4 and model_state0[0] is not None:
1169
+ assert isinstance(model_state[0], str)
1170
+ model, tokenizer, device, base_model = model_state0
1171
+ else:
1172
+ raise AssertionError(no_model_msg)
1173
+
1174
+ assert base_model.strip(), no_model_msg
1175
+ assert model, "Model is missing"
1176
+ assert tokenizer, "Tokenizer is missing"
1177
+
1178
+ data_point = dict(context=context, instruction=instruction, input=iinput)
1179
+ prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
1180
+ prompt = prompter.generate_prompt(data_point)
1181
+
1182
+ if hard_stop_list is None:
1183
+ # acts like undo on user entry and bot response
1184
+ hard_stop_list = []
1185
+
1186
+ if isinstance(tokenizer, str):
1187
+ # pipeline
1188
+ if tokenizer == "summarization":
1189
+ key = 'summary_text'
1190
+ else:
1191
+ raise RuntimeError("No such task type %s" % tokenizer)
1192
+ # NOTE: uses max_length only
1193
+ yield model(prompt, max_length=max_new_tokens)[0][key]
+ return  # pipeline path ends here; the code below assumes a real tokenizer
1194
+
1195
+ if 'mbart-' in base_model.lower():
1196
+ assert src_lang is not None
1197
+ tokenizer.src_lang = languages_covered()[src_lang]
1198
+
1199
+ if chat:
1200
+ # override, ignore user change
1201
+ num_return_sequences = 1
1202
+ if prompt_type in ['human_bot', 'instruct_vicuna', 'instruct_with_end']:
1203
+ if prompt_type == 'human_bot':
1204
+ # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
1205
+ # stopping only starts once output is beyond prompt
1206
+ # 1 human is enough to trigger, but we need 2 bots, because the very first response seen will be the bot token we added
1207
+ stop_words = [human, bot]
1208
+ encounters = [1, 2]
1209
+ elif prompt_type == 'instruct_vicuna':
1210
+ # even the list below is not enough: these are generic strings with many possible token encodings
1211
+ stop_words = [
1212
+ '### Human:',
1213
+ """
1214
+ ### Human:""",
1215
+ """
1216
+ ### Human:
1217
+ """,
1218
+ '### Assistant:',
1219
+ """
1220
+ ### Assistant:""",
1221
+ """
1222
+ ### Assistant:
1223
+ """,
1224
+ ]
1225
+ encounters = [1, 2]
1226
+ else:
1227
+ # some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
1228
+ stop_words = ['### End']
1229
+ encounters = [1]
1230
+ stop_words_ids = [
1231
+ tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
1232
+ # handle single token case
1233
+ stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
1234
+ stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
1235
+ # avoid padding in front of tokens
1236
+ if tokenizer.pad_token:
1237
+ stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
1238
+ stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters)])
1239
+ else:
1240
+ stopping_criteria = StoppingCriteriaList()
1241
+
1242
+ # help to avoid errors like:
1243
+ # RuntimeError: The size of tensor a (2048) must match the size of tensor b (2049) at non-singleton dimension 3
1244
+ # RuntimeError: expected scalar type Half but found Float
1245
+ # hence the "- 256" safety margin applied below
1246
+ max_length_tokenize = 768 - 256 if os.environ.get("HUGGINGFACE_SPACES") else 2048 - 256
1247
+ cutoff_len = max_length_tokenize * 4 # if the prompt reaches the limit, no new tokens can be generated
1248
+ output_smallest = 30 * 4
1249
+ prompt = prompt[-cutoff_len - output_smallest:]
1250
+ inputs = tokenizer(prompt,
1251
+ return_tensors="pt",
1252
+ truncation=True,
1253
+ max_length=max_length_tokenize)
1254
+ if debug and len(inputs["input_ids"]) > 0:
1255
+ print('input_ids length', len(inputs["input_ids"][0]), flush=True)
1256
+ input_ids = inputs["input_ids"].to(device)
1257
+ generation_config = GenerationConfig(
1258
+ temperature=float(temperature),
1259
+ top_p=float(top_p),
1260
+ top_k=top_k,
1261
+ num_beams=num_beams,
1262
+ do_sample=do_sample,
1263
+ repetition_penalty=float(repetition_penalty),
1264
+ num_return_sequences=num_return_sequences,
1265
+ renormalize_logits=True,
1266
+ remove_invalid_values=True,
1267
+ **kwargs,
1268
+ )
1269
+
1270
+ gen_kwargs = dict(input_ids=input_ids,
1271
+ generation_config=generation_config,
1272
+ return_dict_in_generate=True,
1273
+ output_scores=True,
1274
+ max_new_tokens=max_new_tokens, # prompt + new
1275
+ min_new_tokens=min_new_tokens, # prompt + new
1276
+ early_stopping=early_stopping, # False, True, "never"
1277
+ max_time=max_time,
1278
+ stopping_criteria=stopping_criteria,
1279
+ )
1280
+ if 'gpt2' in base_model.lower():
1281
+ gen_kwargs.update(dict(bos_token_id=tokenizer.bos_token_id, pad_token_id=tokenizer.eos_token_id))
1282
+ elif 'mbart-' in base_model.lower():
1283
+ assert tgt_lang is not None
1284
+ tgt_lang = languages_covered()[tgt_lang]
1285
+ gen_kwargs.update(dict(forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang]))
1286
+ else:
1287
+ gen_kwargs.update(dict(pad_token_id=tokenizer.eos_token_id))
1288
+
1289
+ decoder = functools.partial(tokenizer.decode,
1290
+ skip_special_tokens=True,
1291
+ clean_up_tokenization_spaces=True,
1292
+ )
1293
+ decoder_raw = functools.partial(tokenizer.decode,
1294
+ skip_special_tokens=False,
1295
+ clean_up_tokenization_spaces=True,
1296
+ )
1297
+
1298
+ with torch.no_grad():
1299
+ # decoded tokenized prompt can deviate from prompt due to special characters
1300
+ inputs_decoded = decoder(input_ids[0])
1301
+ inputs_decoded_raw = decoder_raw(input_ids[0])
1302
+ if inputs_decoded == prompt:
1303
+ # normal
1304
+ pass
1305
+ elif inputs_decoded.lstrip() == prompt.lstrip():
1306
+ # sometimes there is an extra space in front; make the prompt match so it can be stripped from the output
1307
+ prompt = inputs_decoded
1308
+ elif inputs_decoded_raw == prompt:
1309
+ # some models specify special tokens that are part of normal prompt, so can't skip them
1310
+ inputs_decoded_raw = inputs_decoded
1311
+ decoder = decoder_raw
1312
+ else:
1313
+ print("WARNING: Special characters in prompt", flush=True)
1314
+ if stream_output:
1315
+ def generate(callback=None, **kwargs):
1316
+ # re-order stopping criteria so Stream runs first and emits all chunks before stopping for other reasons
1317
+ stopping_criteria0 = kwargs.get('stopping_criteria', StoppingCriteriaList()).copy()
1318
+ kwargs['stopping_criteria'] = StoppingCriteriaList()
1319
+ kwargs['stopping_criteria'].append(Stream(func=callback))
1320
+ for stopping_criteria1 in stopping_criteria0:
1321
+ kwargs['stopping_criteria'].append(stopping_criteria1)
1322
+
1323
+ try:
1324
+ model.generate(**kwargs)
1325
+ except torch.cuda.OutOfMemoryError as e:
1326
+ print("GPU OOM: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)),
1327
+ flush=True)
1328
+ if kwargs['input_ids'] is not None:
1329
+ kwargs['input_ids'].cpu()
1330
+ kwargs['input_ids'] = None
1331
+ traceback.print_exc()
1332
+ clear_torch_cache()
1333
+ return
1334
+ except RuntimeError as e:
1335
+ if 'Expected all tensors to be on the same device' in str(
1336
+ e) or 'expected scalar type Half but found Float' in str(e):
1337
+ print(
1338
+ "GPU Error: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)),
1339
+ flush=True)
1340
+ traceback.print_exc()
1341
+ clear_torch_cache()
1342
+ return
1343
+ else:
1344
+ raise
1345
+
1346
+ for output in CallbackToGenerator(generate, callback=None, **gen_kwargs):
1347
+ decoded_output = decoder(output)
1348
+ if output[-1] in [tokenizer.eos_token_id]:
1349
+ if debug:
1350
+ print("HIT EOS", flush=True)
1351
+ break
1352
+ if any(ele in decoded_output for ele in hard_stop_list):
1353
+ raise StopIteration
1354
+ yield prompter.get_response(decoded_output, prompt=inputs_decoded,
1355
+ sanitize_bot_response=sanitize_bot_response)
1356
+ return
1357
+ else:
1358
+ outputs = model.generate(**gen_kwargs)
1359
+ outputs = [decoder(s) for s in outputs.sequences]
1360
+ yield prompter.get_response(outputs, prompt=inputs_decoded,
1361
+ sanitize_bot_response=sanitize_bot_response)
1362
+
1363
+
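The StoppingCriteriaSub used above is defined in stopping.py (not shown in this excerpt); a minimal sketch of a stop-word criterion with per-stop-word encounter counts might look like the following (illustrative only; the real class may differ):

import torch
from transformers import StoppingCriteria

class StopOnTokens(StoppingCriteria):
    def __init__(self, stops, encounters):
        super().__init__()
        self.stops = stops            # list of 1-D tensors of token ids
        self.encounters = encounters  # required number of occurrences per stop sequence

    def __call__(self, input_ids, scores, **kwargs):
        seq = input_ids[0]
        for stop, needed in zip(self.stops, self.encounters):
            stop = stop.to(seq.device)
            n = stop.shape[0]
            hits = sum(torch.equal(stop, seq[i:i + n]) for i in range(len(seq) - n + 1))
            if hits >= needed:
                return True
        return False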
1364
+ def get_generate_params(model_lower, chat,
1365
+ stream_output, show_examples,
1366
+ prompt_type, temperature, top_p, top_k, num_beams,
1367
+ max_new_tokens, min_new_tokens, early_stopping, max_time,
1368
+ repetition_penalty, num_return_sequences,
1369
+ do_sample):
1370
+ use_defaults = False
1371
+ use_default_examples = True
1372
+ examples = []
1373
+ task_info = f"{prompt_type}"
1374
+ if model_lower:
1375
+ print(f"Using Model {model_lower}", flush=True)
1376
+ else:
1377
+ print("No model defined yet", flush=True)
1378
+
1379
+ min_new_tokens = min_new_tokens if min_new_tokens is not None else 0
1380
+ early_stopping = early_stopping if early_stopping is not None else False
1381
+ max_time_defaults = 60 * 3
1382
+ max_time = max_time if max_time is not None else max_time_defaults
1383
+
1384
+ if not prompt_type and model_lower in inv_prompt_type_to_model_lower:
1385
+ prompt_type = inv_prompt_type_to_model_lower[model_lower]
1386
+
1387
+ if show_examples is None:
1388
+ if chat:
1389
+ show_examples = False
1390
+ else:
1391
+ show_examples = True
1392
+
1393
+ summarize_example1 = """Jeff: Can I train a 🤗 Transformers model on Amazon SageMaker?
1394
+ Philipp: Sure you can use the new Hugging Face Deep Learning Container.
1395
+ Jeff: ok.
1396
+ Jeff: and how can I get started?
1397
+ Jeff: where can I find documentation?
1398
+ Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-partnership-amazon-sagemaker-and-hugging-face"""
1399
+
1400
+ if 'bart-large-cnn-samsum' in model_lower or 'flan-t5-base-samsum' in model_lower:
1401
+ placeholder_instruction = summarize_example1
1402
+ placeholder_input = ""
1403
+ use_defaults = True
1404
+ use_default_examples = False
1405
+ examples += [
1406
+ [placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults,
1407
+ 1.0, 1,
1408
+ False]]
1409
+ task_info = "Summarization"
1410
+ elif 't5-' in model_lower or 't5' == model_lower or 'flan-' in model_lower:
1411
+ placeholder_instruction = "The square root of x is the cube root of y. What is y to the power of 2, if x = 4?"
1412
+ placeholder_input = ""
1413
+ use_defaults = True
1414
+ use_default_examples = True
1415
+ task_info = "Multi-Task: Q/A, translation, Chain-of-Thought, Logical Reasoning, Summarization, etc. Best to use task prefix as trained on, e.g. `translate English to German: ` (space after colon)"
1416
+ elif 'mbart-' in model_lower:
1417
+ placeholder_instruction = "The girl has long hair."
1418
+ placeholder_input = ""
1419
+ use_defaults = True
1420
+ use_default_examples = False
1421
+ examples += [
1422
+ [placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults,
1423
+ 1.0, 1,
1424
+ False]]
1425
+ elif 'gpt2' in model_lower:
1426
+ placeholder_instruction = "The sky is"
1427
+ placeholder_input = ""
1428
+ prompt_type = prompt_type or 'plain'
1429
+ use_default_examples = True # some will be odd "continuations" but can be ok
1430
+ examples += [
1431
+ [placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults,
1432
+ 1.0, 1,
1433
+ False]]
1434
+ task_info = "Auto-complete phrase, code, etc."
1435
+ use_defaults = True
1436
+ else:
1437
+ if chat:
1438
+ placeholder_instruction = "Enter a question or imperative."
1439
+ else:
1440
+ placeholder_instruction = "Give detailed answer for whether Einstein or Newton is smarter."
1441
+ placeholder_input = ""
1442
+ if model_lower:
1443
+ prompt_type = prompt_type or 'human_bot'
1444
+ else:
1445
+ prompt_type = ''
1446
+ examples += [[summarize_example1, 'Summarize' if prompt_type not in ['plain', 'instruct_simple'] else '', "",
1447
+ stream_output, prompt_type or 'plain', 0.1, 0.75, 40, 4, 256, 0, False, max_time_defaults, 1.0, 1, False]]
1448
+ task_info = "No task"
1449
+ if prompt_type == 'instruct':
1450
+ task_info = "Answer question or follow imperative as instruction with optionally input."
1451
+ elif prompt_type == 'plain':
1452
+ task_info = "Auto-complete phrase, code, etc."
1453
+ elif prompt_type == 'human_bot':
1454
+ if chat:
1455
+ task_info = "Chat (Shift-Enter to give question/imperative, input concatenated with instruction)"
1456
+ else:
1457
+ task_info = "Ask question/imperative (input concatenated with instruction)"
1458
+
1459
+ # revert to plain if still nothing
1460
+ prompt_type = prompt_type or 'plain'
1461
+ if use_defaults:
1462
+ temperature = 1.0 if temperature is None else temperature
1463
+ top_p = 1.0 if top_p is None else top_p
1464
+ top_k = 40 if top_k is None else top_k
1465
+ num_beams = num_beams or 1
1466
+ max_new_tokens = max_new_tokens or 128
1467
+ repetition_penalty = repetition_penalty or 1.07
1468
+ num_return_sequences = min(num_beams, num_return_sequences or 1)
1469
+ do_sample = False if do_sample is None else do_sample
1470
+ else:
1471
+ temperature = 0.1 if temperature is None else temperature
1472
+ top_p = 0.75 if top_p is None else top_p
1473
+ top_k = 40 if top_k is None else top_k
1474
+ if chat:
1475
+ num_beams = num_beams or 1
1476
+ else:
1477
+ num_beams = num_beams or 4
1478
+ max_new_tokens = max_new_tokens or 256
1479
+ repetition_penalty = repetition_penalty or 1.07
1480
+ num_return_sequences = min(num_beams, num_return_sequences or 1)
1481
+ do_sample = False if do_sample is None else do_sample
1482
+ params_list = ["", stream_output, prompt_type, temperature, top_p, top_k, num_beams, max_new_tokens, min_new_tokens,
1483
+ early_stopping, max_time, repetition_penalty, num_return_sequences, do_sample]
1484
+
1485
+ if use_default_examples:
1486
+ examples += [
1487
+ ["Translate English to French", "Good morning"] + params_list,
1488
+ ["Give detailed answer for whether Einstein or Newton is smarter.", ''] + params_list,
1489
+ ["Explain in detailed list, all the best practices for coding in python.", ''] + params_list,
1490
+ [
1491
+ "Create a markdown table with 3 rows for the primary colors, and 2 columns, with color name and hex codes.",
1492
+ ''] + params_list,
1493
+ ['Translate to German: My name is Arthur', ''] + params_list,
1494
+ ["Please answer to the following question. Who is going to be the next Ballon d'or?", ''] + params_list,
1495
+ ['Can Geoffrey Hinton have a conversation with George Washington? Give the rationale before answering.',
1496
+ ''] + params_list,
1497
+ ['Please answer the following question. What is the boiling point of Nitrogen?', ''] + params_list,
1498
+ ['Answer the following yes/no question. Can you write a whole Haiku in a single tweet?', ''] + params_list,
1499
+ ["Simplify the following expression: (False or False and True). Explain your answer.", ''] + params_list,
1500
+ [
1501
+ "Premise: At my age you will probably have learnt one lesson. Hypothesis: It's not certain how many lessons you'll learn by your thirties. Does the premise entail the hypothesis?",
1502
+ ''] + params_list,
1503
+ ['The square root of x is the cube root of y. What is y to the power of 2, if x = 4?', ''] + params_list,
1504
+ [
1505
+ 'Answer the following question by reasoning step by step. The cafeteria had 23 apples. If they used 20 for lunch, and bought 6 more, how many apple do they have?',
1506
+ ''] + params_list,
1507
+ ["""def area_of_rectangle(a: float, b: float):
1508
+ \"\"\"Return the area of the rectangle.\"\"\"""", ''] + params_list,
1509
+ ["""# a function in native python:
1510
+ def mean(a):
1511
+ return sum(a)/len(a)
1512
+
1513
+ # the same function using numpy:
1514
+ import numpy as np
1515
+ def mean(a):""", ''] + params_list,
1516
+ ["""X = np.random.randn(100, 100)
1517
+ y = np.random.randint(0, 1, 100)
1518
+
1519
+ # fit random forest classifier with 20 estimators""", ''] + params_list,
1520
+ ]
1521
+
1522
+ src_lang = "English"
1523
+ tgt_lang = "Russian"
1524
+
1525
+ return placeholder_instruction, placeholder_input, \
1526
+ stream_output, show_examples, \
1527
+ prompt_type, temperature, top_p, top_k, num_beams, \
1528
+ max_new_tokens, min_new_tokens, early_stopping, max_time, \
1529
+ repetition_penalty, num_return_sequences, \
1530
+ do_sample, \
1531
+ src_lang, tgt_lang, \
1532
+ examples, \
1533
+ task_info
1534
+
1535
+
1536
+ def languages_covered():
1537
+ # https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt#languages-covered
1538
+ covered = """Arabic (ar_AR), Czech (cs_CZ), German (de_DE), English (en_XX), Spanish (es_XX), Estonian (et_EE), Finnish (fi_FI), French (fr_XX), Gujarati (gu_IN), Hindi (hi_IN), Italian (it_IT), Japanese (ja_XX), Kazakh (kk_KZ), Korean (ko_KR), Lithuanian (lt_LT), Latvian (lv_LV), Burmese (my_MM), Nepali (ne_NP), Dutch (nl_XX), Romanian (ro_RO), Russian (ru_RU), Sinhala (si_LK), Turkish (tr_TR), Vietnamese (vi_VN), Chinese (zh_CN), Afrikaans (af_ZA), Azerbaijani (az_AZ), Bengali (bn_IN), Persian (fa_IR), Hebrew (he_IL), Croatian (hr_HR), Indonesian (id_ID), Georgian (ka_GE), Khmer (km_KH), Macedonian (mk_MK), Malayalam (ml_IN), Mongolian (mn_MN), Marathi (mr_IN), Polish (pl_PL), Pashto (ps_AF), Portuguese (pt_XX), Swedish (sv_SE), Swahili (sw_KE), Tamil (ta_IN), Telugu (te_IN), Thai (th_TH), Tagalog (tl_XX), Ukrainian (uk_UA), Urdu (ur_PK), Xhosa (xh_ZA), Galician (gl_ES), Slovene (sl_SI)"""
1539
+ covered = covered.split(', ')
1540
+ covered = {x.split(' ')[0]: x.split(' ')[1].replace(')', '').replace('(', '') for x in covered}
1541
+ return covered
1542
+
1543
+
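For reference, a quick illustration of how the mapping returned by languages_covered() is used for mBART source/target language selection (language codes as listed on the model card):

codes = languages_covered()
print(codes['English'], codes['Russian'])            # en_XX ru_RU
# tokenizer.src_lang = codes['English']              # set before encoding the prompt
# forced_bos_token_id = tokenizer.lang_code_to_id[codes['Russian']]  # target language at generation time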
1544
+ def test_test_prompt(prompt_type='instruct', data_point=0):
1545
+ example_data_point = example_data_points[data_point]
1546
+ example_data_point.pop('output', None)
1547
+ return generate_prompt(example_data_point, prompt_type, False, False)
1548
+
1549
+
1550
+ if __name__ == "__main__":
1551
+ print("""
1552
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 --master_port=1234 generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights=lora-alpaca_6B
1553
+ python generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights='lora-alpaca_6B'
1554
+ python generate.py --base_model='EleutherAI/gpt-neox-20b' --lora_weights='lora-alpaca_20B'
1555
+
1556
+ # generate without lora weights, no prompt
1557
+ python generate.py --base_model='EleutherAI/gpt-neox-20b' --prompt_type='plain'
1558
+ python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='dai_faq'
1559
+
1560
+ python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='dai_faq' --lora_weights='lora_20B_daifaq'
1561
+ # OpenChatKit settings:
1562
+ python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --debug=True --num_beams=1 --temperature=0.6 --top_k=40 --top_p=1.0
1563
+
1564
+ python generate.py --base_model='distilgpt2' --prompt_type='plain' --debug=True --num_beams=1 --temperature=0.6 --top_k=40 --top_p=1.0 --share=False
1565
+ python generate.py --base_model='t5-large' --prompt_type='simple_instruct'
1566
+ python generate.py --base_model='philschmid/bart-large-cnn-samsum'
1567
+ python generate.py --base_model='philschmid/flan-t5-base-samsum'
1568
+ python generate.py --base_model='facebook/mbart-large-50-many-to-many-mmt'
1569
+
1570
+ python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --lora_weights='GPT-NeoXT-Chat-Base-20B.merged.json.8_epochs.57b2892c53df5b8cefac45f84d019cace803ef26.28'
1571
+
1572
+ """, flush=True)
1573
+ fire.Fire(main)
client_test.py ADDED
@@ -0,0 +1,121 @@
1
+ """
2
+ Client test. Simplest case is chat=False and stream_output=False
3
+
4
+ Run server with same choices:
5
+
6
+ python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-256-6.9b --chat=False --stream_output=False
7
+
8
+ NOTE: For private models, add --use_auth_token=True
9
+
10
+ NOTE: --infer_devices=True (default) must be used for multi-GPU in case you see failures with cuda:x cuda:y mismatches.
11
+ Currently, this will force model to be on a single GPU.
12
+
13
+ Then run this client as:
14
+
15
+ python client_test.py
16
+ """
17
+
18
+ debug = False
19
+
20
+ import time
21
+ import os
22
+ os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
23
+ from gradio_client import Client
24
+
25
+ client = Client("http://localhost:7860")
26
+ if debug:
27
+ print(client.view_api(all_endpoints=True))
28
+
29
+ instruction = "Who are you?"
30
+ iinput = ''
31
+ context = ''
32
+ # streaming output is supported, loops over and outputs each generation in streaming mode
33
+ # but leave stream_output=False for simple input/output mode
34
+ stream_output = False
35
+ prompt_type = 'human_bot'
36
+ temperature = 0.1
37
+ top_p = 0.75
38
+ top_k = 40
39
+ num_beams = 1
40
+ max_new_tokens = 500
41
+ min_new_tokens = 0
42
+ early_stopping = False
43
+ max_time = 180
44
+ repetition_penalty = 1.0
45
+ num_return_sequences = 1
46
+ do_sample = True
47
+
48
+ # CHOOSE: must match server
49
+ # NOTE chat mode works through files on gradio
50
+ # and client currently would have to work through those files
51
+ # in tmp, so not best for client. So default to False
52
+ chat = False
53
+
54
+
55
+ def test_client_basic():
56
+ args = [instruction,
57
+ iinput,
58
+ context,
59
+ stream_output,
60
+ prompt_type,
61
+ temperature,
62
+ top_p,
63
+ top_k,
64
+ num_beams,
65
+ max_new_tokens,
66
+ min_new_tokens,
67
+ early_stopping,
68
+ max_time,
69
+ repetition_penalty,
70
+ num_return_sequences,
71
+ do_sample]
72
+
73
+ if not chat:
74
+ # requires generate.py to run with --chat=False
75
+ api_name = '/submit'
76
+ res = client.predict(
77
+ *tuple(args),
78
+ api_name=api_name,
79
+ )
80
+ print(md_to_text(res))
81
+ else:
82
+ api_name = '/instruction'
83
+ import json
84
+ foofile = '/tmp/foo.json'
85
+ with open(foofile, 'wt') as f:
86
+ json.dump([['', None]], f)
87
+ args += [foofile]
88
+ if not stream_output:
89
+ for res in client.predict(
90
+ *tuple(args),
91
+ api_name=api_name,
92
+ ):
93
+ print(res)
94
+ res_file = client.predict(*tuple(args), api_name='/instruction_bot')
95
+ res = json.load(open(res_file, "rt"))[-1][-1]
96
+ print(md_to_text(res))
97
+ else:
98
+ print("streaming instruction_bot", flush=True)
99
+ job = client.submit(*tuple(args), api_name='/instruction_bot')
100
+ while not job.done():
101
+ outputs_list = job.communicator.job.outputs
102
+ if outputs_list:
103
+ res_file = job.communicator.job.outputs[-1]
104
+ res = json.load(open(res_file, "rt"))[-1][-1]
105
+ print(md_to_text(res))
106
+ time.sleep(0.1)
107
+ print(job.outputs())
108
+
109
+
110
+ import markdown # pip install markdown
111
+ from bs4 import BeautifulSoup # pip install beautifulsoup4
112
+
113
+
114
+ def md_to_text(md):
115
+ html = markdown.markdown(md)
116
+ soup = BeautifulSoup(html, features='html.parser')
117
+ return soup.get_text()
118
+
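A quick check of md_to_text(), which strips the markdown from the server's response before printing (assumes the markdown and beautifulsoup4 packages are installed):

print(md_to_text("**h2oGPT** replies in *markdown*."))   # -> h2oGPT replies in markdown.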
119
+
120
+ if __name__ == '__main__':
121
+ test_client_basic()
finetune.py ADDED
@@ -0,0 +1,930 @@
1
+ import os
2
+ import pathlib
3
+ import random
4
+ import shutil
5
+ import subprocess
6
+ import sys
7
+ import time
8
+ from datetime import datetime
9
+ from typing import List, Union
10
+ import fire
11
+ import numpy as np
12
+ import torch
13
+ from datasets import load_dataset, concatenate_datasets
14
+ import transformers
15
+ import torch.distributed as dist
16
+
17
+ from peft import (
18
+ prepare_model_for_int8_training,
19
+ LoraConfig,
20
+ get_peft_model,
21
+ get_peft_model_state_dict,
22
+ set_peft_model_state_dict,
23
+ )
24
+
25
+ from peft import mapping
26
+ lora_mappings = mapping.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
27
+
28
+
29
+ def log(*args, **kwargs):
30
+ if int(os.environ.get("LOCAL_RANK", 0)) == 0:
31
+ print(*args, **kwargs)
32
+
33
+
34
+ try:
35
+ import neptune
36
+ from transformers.integrations import NeptuneCallback
37
+
38
+ neptune_run = neptune.init_run(
39
+ source_files=[],
40
+ )
41
+ log("Connected to Neptune.")
42
+ except ImportError:
43
+ neptune_run = None
44
+ log("Please pip install neptune for tracking.")
45
+ except neptune.exceptions.NeptuneMissingApiTokenException:
46
+ neptune_run = None
47
+ os.environ["NEPTUNE_MODE"] = 'debug'
48
+ log("No neptune configured, set NEPTUNE_API_TOKEN env var.")
49
+
50
+ from enum import Enum
51
+
52
+
53
+ class PromptType(Enum):
54
+ plain = 0
55
+ instruct = 1
56
+ quality = 2
57
+ human_bot = 3
58
+ dai_faq = 4
59
+ summarize = 5
60
+ simple_instruct = 6
61
+ instruct_vicuna = 7
62
+ instruct_with_end = 8
63
+ human_bot_orig = 9
64
+
65
+
66
+ prompt_type_to_model_name = {
67
+ 'plain': [
68
+ 'EleutherAI/gpt-j-6B',
69
+ 'EleutherAI/pythia-6.9b',
70
+ 'EleutherAI/pythia-12b',
71
+ 'EleutherAI/pythia-12b-deduped',
72
+ 'EleutherAI/gpt-neox-20b',
73
+ 'decapoda-research/llama-7b-hf',
74
+ 'decapoda-research/llama-13b-hf',
75
+ 'decapoda-research/llama-30b-hf',
76
+ 'facebook/mbart-large-50-many-to-many-mmt',
77
+ 'philschmid/bart-large-cnn-samsum',
78
+ 'philschmid/flan-t5-base-samsum',
79
+ 'gpt2',
80
+ 'distilgpt2',
81
+ ],
82
+ 'instruct': [],
83
+ 'instruct_with_end': ['databricks/dolly-v2-12b'],
84
+ 'quality': [],
85
+ 'human_bot': [
86
+ 'h2oai/h2ogpt-oig-oasst1-256-12b',
87
+ 'h2oai/h2ogpt-oasst1-512-12b',
88
+ 'h2oai/h2ogpt-oasst1-256-20b',
89
+ 'h2oai/h2ogpt-oasst1-512-20b',
90
+ 'h2oai/h2ogpt-oig-oasst1-256-6.9b',
91
+ ],
92
+ 'dai_faq': [],
93
+ 'summarize': [],
94
+ 'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
95
+ 'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b'],
96
+ 'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
97
+ }
98
+
99
+ inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
100
+ inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
101
+
102
+ human = '<human>:'
103
+ bot = "<bot>:"
104
+
105
+ prompt_types_strings = []
106
+ for p in PromptType:
107
+ prompt_types_strings.extend([p.name])
108
+
109
+
110
+ prompt_types = []
111
+ for p in PromptType:
112
+ prompt_types.extend([p.name, p.value, str(p.value)])
113
+
114
+
115
+ # supported by huggingface evaluate
116
+ supported_metrics = ['bleu', 'rouge', 'sacrebleu', 'meteor']
117
+
118
+
119
+ def train(
120
+ save_code: bool = False,
121
+ run_id: int = None,
122
+
123
+ base_model: str = 'EleutherAI/gpt-neox-20b',
124
+ # base_model: str = 'EleutherAI/pythia-12b-deduped',
125
+ # base_model: str = 'togethercomputer/GPT-NeoXT-Chat-Base-20B',
126
+ # base_model: str = 'decapoda-research/llama-7b-hf',
127
+ # base_model: str = 'decapoda-research/llama-13b-hf',
128
+ # base_model: str = 'decapoda-research/llama-30b-hf',
129
+ # base_model: str = 'EleutherAI/gpt-j-6B',
130
+
131
+ # only needed if base_model is self-exported HF state without tokenizer
132
+ tokenizer_base_model: str = None,
133
+ # tokenizer_base_model: str = 'EleutherAI/gpt-neox-20b',
134
+
135
+ data_path: str = None,
136
+ data_col_dict: dict = None,
137
+ # data_path: str = "./dai_docs.train.json",
138
+ prompt_type: Union[str, int] = "plain", # "plain", "instruct", "quality", "human_bot", "dai_faq"
139
+
140
+ valid_path: str = None,
141
+ # valid_path: str = "./dai_docs.valid.json",
142
+
143
+ # data_mix_in_path: str = "laion/OIG", # way too big, medium quality
144
+ data_mix_in_path: str = "0-hero/OIG-small-chip2", # high quality, 50 MB, good enough for now
145
+ data_mix_in_factor: float = 0.0, # >1: more mix-in data, <1: more of data_path data
146
+ data_mix_in_col_dict: dict = {'user': 'instruction', 'chip2': 'output'},
147
+ data_mix_in_prompt_type: str = "instruct", # just instruction->output, same as instruct
148
+
149
+ output_dir: str = None,
150
+
151
+ # LoRA checkpoint continuation
152
+ lora_weights: str = "",
153
+
154
+ # batching training hyperparams
155
+ batch_size: int = 128,
156
+ micro_batch_size: int = 4,
157
+ gradient_checkpointing=False, # unnecessary with gradient accumulation enabled
158
+ fp16=True,
159
+
160
+ # general training hyperparams
161
+ num_epochs: float = 1,
162
+ learning_rate: float = 3e-4,
163
+
164
+ # validation settings
165
+ val_set_size: int = None,
166
+ val_metrics: List[str] = [],
167
+ eval_steps: int = None, # to control eval steps via steps
168
+ eval_epochs: float = None, # to control eval steps via epochs
169
+
170
+ # lora hyperparams
171
+ lora_r: int = 8,
172
+ lora_alpha: int = 16,
173
+ lora_dropout: float = 0.05,
174
+ lora_target_modules: List[str] = None,
175
+ llama_type: bool = None,
176
+
177
+ # llm hyperparams
178
+ train_on_inputs: bool = True, # if False, masks out inputs in loss
179
+ group_by_length: bool = False, # if True, faster, but produces an odd training loss curve
180
+ resume_from_checkpoint: str = None, # either training checkpoint or final adapter
181
+ cutoff_len: int = 1024, # Good default, especially when have high quality non-trivial data
182
+
183
+ # torch training params
184
+ ddp: bool = True, # set to False if OOM with True, for multi-GPU model parallelism
185
+ local_files_only: bool = False, # else will download new versions, normally unwanted
186
+ resume_download: bool = True,
187
+ use_auth_token: Union[str, bool] = False, # True requires CLI did huggingface-cli login before running
188
+ warmup_steps: int = 100,
189
+ logging_steps: int = 1,
190
+ save_steps: int = None, # must be round multiple of eval_steps
191
+ add_eos_token: bool = False,
192
+ ):
193
+ # allow set token directly
194
+ use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
195
+
196
+ prompt_type = str(prompt_type) # migration from integers
197
+ assert prompt_type in prompt_types
198
+
199
+ world_size = int(os.getenv("WORLD_SIZE", 1))
200
+ local_rank = int(os.getenv("LOCAL_RANK", 0))
201
+ rank = int(os.getenv("RANK", 0))
202
+ print(f"local_rank: {local_rank}")
203
+ print(f"global rank: {rank}")
204
+
205
+ gpus = max(world_size, torch.cuda.device_count())
206
+ run_id = run_id or 0
207
+ if not data_path:
208
+ raise ValueError("No data_path provided")
209
+ if not output_dir:
210
+ output_dir = f"{base_model.split('/')[-1]}.{data_path.replace('/', '')}.{num_epochs}_epochs.{get_githash() or 'nogit'}.{run_id}"
211
+ if os.path.exists(output_dir) and not resume_from_checkpoint:
212
+ raise FileExistsError(f"output_dir based on run_id {run_id} already exists. Please pick a different run_id.")
213
+ else:
214
+ if os.path.exists(output_dir) and not resume_from_checkpoint:
215
+ raise FileExistsError(f"output_dir {output_dir} already exists. Please pick a different output_dir, or specify a run_id instead.")
216
+ device_map = "auto"
217
+
218
+ if save_code:
219
+ copy_code(run_id)
220
+ if tokenizer_base_model is None:
221
+ tokenizer_base_model = base_model
222
+ if llama_type is None:
223
+ llama_type = "llama" in base_model.lower()
224
+ assert (
225
+ base_model
226
+ ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
227
+ gradient_accumulation_steps = batch_size // micro_batch_size
228
+ assert gradient_accumulation_steps >= world_size, "must increase batch_size for multi-GPU"
229
+
230
+ device_map = "auto"
231
+
232
+ locals_dict = locals()
233
+ locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
234
+ log(f"Training model with params:\n{locals_print}")
235
+ log("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), get_githash()))
236
+
237
+ max_memory = None
238
+ if gpus > 1:
239
+ if ddp:
240
+ log("Distributed: data parallel")
241
+ device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
242
+ gradient_accumulation_steps = gradient_accumulation_steps // world_size
243
+ else:
244
+ free_in_GB = int(min(torch.cuda.mem_get_info()) / 1024 ** 3)
245
+ max_memory = f"{free_in_GB - 2}GB"
246
+ max_memory = {i: max_memory for i in range(gpus)}
247
+ log("world_size: %d" % world_size)
248
+ log("num_gpus: %d" % gpus)
249
+ log("max mem: %s" % max_memory)
250
+
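A worked example of the batching arithmetic above, with illustrative numbers (the defaults are batch_size=128 and micro_batch_size=4):

batch_size, micro_batch_size, world_size = 128, 4, 4
grad_accum = batch_size // micro_batch_size        # 32
grad_accum_per_rank = grad_accum // world_size     # 8 when ddp=True and gpus > 1
effective_batch = world_size * micro_batch_size * grad_accum_per_rank
print(effective_batch)                              # 128 samples per optimizer step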
251
+ model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=False)
252
+
253
+ model = model_loader.from_pretrained(
254
+ base_model,
255
+ load_in_8bit=True,
256
+ device_map=device_map,
257
+ torch_dtype=torch.float16,
258
+ max_memory=max_memory,
259
+ local_files_only=local_files_only,
260
+ resume_download=resume_download,
261
+ use_auth_token=use_auth_token,
262
+ )
263
+ if gpus > 1:
264
+ if not ddp:
265
+ log("model parallel")
266
+ model.is_parallelizable = True
267
+ model.model_parallel = True
268
+
269
+ tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
270
+ local_files_only=local_files_only,
271
+ resume_download=resume_download,
272
+ use_auth_token=use_auth_token)
273
+
274
+ tokenizer.pad_token_id = 0 # different from the eos token
275
+ # when generating, we will use the logits of right-most token to predict the next token
276
+ # so the padding should be on the left,
277
+ # e.g. see: https://huggingface.co/transformers/v4.11.3/model_doc/t5.html#inference
278
+ tokenizer.padding_side = "left" # Allow batched inference
279
+
280
+ def tokenize(prompt, add_eos_token=True):
281
+ # there's probably a way to do this with the tokenizer settings
282
+ # but again, gotta move fast
283
+ result = tokenizer(
284
+ prompt,
285
+ truncation=True,
286
+ max_length=cutoff_len,
287
+ padding=False,
288
+ return_tensors=None,
289
+ )
290
+ if (
291
+ result["input_ids"][-1] != tokenizer.eos_token_id
292
+ and len(result["input_ids"]) < cutoff_len
293
+ and add_eos_token
294
+ ):
295
+ result["input_ids"].append(tokenizer.eos_token_id)
296
+ result["attention_mask"].append(1)
297
+
298
+ result["labels"] = result["input_ids"].copy()
299
+
300
+ return result
301
+
302
+ def generate_and_tokenize_prompt(data_point, add_eos=add_eos_token):
303
+ full_prompt, _, _ = generate_prompt(data_point, prompt_type, False, False)
304
+ tokenized_full_prompt = tokenize(full_prompt)
305
+ if not train_on_inputs:
306
+ user_prompt, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, False, False)
307
+ tokenized_user_prompt = tokenize(user_prompt, add_eos_token=add_eos)
308
+ user_prompt_len = len(tokenized_user_prompt["input_ids"])
309
+ if add_eos:
310
+ user_prompt_len -= 1
311
+
312
+ # ignore_index=-100 ensures torch/tf don't include padding token id in CrossEntropyLoss
313
+ tokenized_full_prompt["labels"] = [
314
+ -100
315
+ ] * user_prompt_len + tokenized_full_prompt["labels"][
316
+ user_prompt_len:
317
+ ] # could be sped up, probably
318
+ return tokenized_full_prompt
319
+
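A tiny numeric sketch of the -100 label masking above (token ids and shapes are made up): CrossEntropyLoss ignores index -100, so only the response tokens contribute to the loss.

import torch
from torch.nn import CrossEntropyLoss

vocab = 50257
logits = torch.randn(1, 6, vocab)                       # [batch, seq, vocab]
labels = torch.tensor([[-100, -100, -100, 7, 33, 2]])   # prompt positions masked out
loss = CrossEntropyLoss(ignore_index=-100)(logits.view(-1, vocab), labels.view(-1))
print(loss)                                             # computed only from positions whose label is not -100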
320
+ if "gpt-neox" not in base_model or True:
321
+ model = prepare_model_for_int8_training(model)
322
+ else:
323
+ model = prepare_model_for_int8_training(
324
+ model,
325
+ output_embedding_layer_name="embed_out", # keep output logits in float32
326
+ layer_norm_names=["layer_norm", "layernorm"], # keep all layer norms in higher precision
327
+ )
328
+ if lora_weights:
329
+ from peft import PeftModel
330
+ model = PeftModel.from_pretrained(
331
+ model,
332
+ lora_weights,
333
+ torch_dtype=torch.float16,
334
+ device_map=device_map,
335
+ local_files_only=local_files_only,
336
+ resume_download=resume_download,
337
+ use_auth_token=use_auth_token,
338
+ )
339
+ else:
340
+ if lora_target_modules is None:
341
+ base_model_lower = base_model.lower()
342
+ if base_model_lower in lora_mappings:
343
+ lora_target_modules_cand = [lora_mappings[base_model_lower]]
344
+ else:
345
+ lora_target_modules_cand = [["query_key_value"], ["q_proj", "v_proj"]]
346
+ else:
347
+ lora_target_modules_cand = [lora_target_modules]
348
+
349
+ for lora_target_modules in lora_target_modules_cand:
350
+ try:
351
+ config = LoraConfig(
352
+ r=lora_r,
353
+ lora_alpha=lora_alpha,
354
+ target_modules=lora_target_modules,
355
+ lora_dropout=lora_dropout,
356
+ bias="none",
357
+ task_type="CAUSAL_LM",
358
+ )
359
+ model = get_peft_model(model, config)
360
+ break
361
+ except ValueError as e:
362
+ if "Target modules" in str(e) and "not found" in str(e):
363
+ continue
364
+ else:
365
+ raise
366
+ from peft import PeftModel
367
+ assert isinstance(model, PeftModel), "LoRA failed. Please provide --lora_target_modules explicitly."
368
+ if resume_from_checkpoint:
369
+ # Check the available weights and load them
370
+ checkpoint_name = os.path.join(
371
+ resume_from_checkpoint, "pytorch_model.bin"
372
+ ) # Full checkpoint
373
+ if not os.path.exists(checkpoint_name):
374
+ checkpoint_name = os.path.join(
375
+ resume_from_checkpoint, "adapter_model.bin"
376
+ ) # only LoRA model - LoRA config above has to fit
377
+ resume_from_checkpoint = False # So the trainer won't try loading its state
378
+ # The two files above have a different name depending on how they were saved, but are actually the same.
379
+ if os.path.exists(checkpoint_name):
380
+ log(f"Restarting from {checkpoint_name}")
381
+ adapters_weights = torch.load(checkpoint_name)
382
+ model = set_peft_model_state_dict(model, adapters_weights)
383
+ else:
384
+ log(f"Checkpoint {checkpoint_name} not found")
385
+
386
+ print(model)
387
+ model.print_trainable_parameters() # Be more transparent about the % of trainable params.
388
+
389
+ metrics = {}
390
+ for name in supported_metrics:
391
+ if name in val_metrics:
392
+ import evaluate # Causes hang for 'python generate.py' on dual 4090 if imported early, 100% reproducible
393
+ metrics[name] = evaluate.load(name)
394
+ log("Using Validation Metrics: %s" % str(list(metrics.keys())))
395
+ log("Supported Metrics: %s" % supported_metrics)
396
+
397
+ if val_set_size is None:
398
+ if len(metrics) == 0:
399
+ val_set_size = 1000
400
+ else:
401
+ val_set_size = 100
402
+ log("Auto set val_set_size %s" % val_set_size)
403
+ elif val_set_size < 1.0 and val_set_size != 0:
404
+ raise RuntimeError("Fractional validation size not supported.")
405
+
406
+ if valid_path:
407
+ data = load_dataset("json", data_files={"train": data_path, "valid": valid_path})
408
+ else:
409
+ if "json" in data_path:
410
+ data = load_dataset("json", data_files={"train": data_path})
411
+ else:
412
+ data = load_dataset(data_path)
413
+ data = data.rename_columns(data_col_dict or {})
414
+
415
+ valid_data = None
416
+ train_data_mix_in = None
417
+ valid_data_mix_in = None
418
+
419
+ if data_mix_in_path and data_mix_in_factor > 0:
420
+ # get mix-in training/validation data - to keep model "sane"
421
+ num_rows = data["train"].num_rows
422
+ log("Loading mix-in dataset: %s" % data_mix_in_path)
423
+ if "json" in data_mix_in_path:
424
+ data_mix_in = load_dataset("json", data_files={"train": data_mix_in_path})["train"]
425
+ else:
426
+ data_mix_in = load_dataset(data_mix_in_path)["train"] # can be large
427
+ data_mix_in = data_mix_in.rename_columns(data_mix_in_col_dict or {})
428
+
429
+ # only get as much as we need to balance
430
+ valid_size = min(data_mix_in.num_rows // 2, val_set_size or 0)
431
+ train_size = max(1, min(data_mix_in.num_rows - valid_size, int(num_rows * data_mix_in_factor)))
432
+ mixin_small = data_mix_in.train_test_split(
433
+ test_size=train_size + valid_size,
434
+ shuffle=True, seed=np.random.randint(10000),
435
+ )["test"]
436
+ if valid_size:
437
+ mixin_train_test = mixin_small.train_test_split(
438
+ test_size=valid_size, shuffle=False,
439
+ )
440
+ train_data_mix_in = mixin_train_test["train"]
441
+ valid_data_mix_in = mixin_train_test["test"]
442
+ else:
443
+ train_data_mix_in = mixin_small
444
+
445
+ if "prompt_type" not in train_data_mix_in.column_names:
446
+ train_data_mix_in = train_data_mix_in.add_column(
447
+ "prompt_type",
448
+ [data_mix_in_prompt_type] * train_data_mix_in.num_rows,
449
+ )
450
+ log("Added prompt type %s to mix-in training data" % data_mix_in_prompt_type)
451
+ if valid_data_mix_in and "prompt_type" not in valid_data_mix_in.column_names:
452
+ valid_data_mix_in = valid_data_mix_in.add_column(
453
+ "prompt_type",
454
+ [data_mix_in_prompt_type] * valid_data_mix_in.num_rows,
455
+ )
456
+ log("Added prompt type %s to mix-in validation data" % data_mix_in_prompt_type)
457
+ log("Created mix-in data:\nTrain %s\nValid %s" % (train_data_mix_in, valid_data_mix_in))
458
+
459
+ # get our own training/validation data - for fine-tuning
460
+ if val_set_size > 0 and not valid_path and not data_mix_in_path:
461
+ # create valid split from train
462
+ train_val = data["train"].train_test_split(
463
+ test_size=val_set_size, shuffle=True, seed=42
464
+ )
465
+ train_data = train_val["train"]
466
+ valid_data = train_val["test"]
467
+ else:
468
+ train_data = data["train"]
469
+ if valid_path:
470
+ # use given valid split, has priority over data_mix_in_path
471
+ valid_data = data["valid"]
472
+ if "prompt_type" not in train_data.column_names:
473
+ train_data = train_data.add_column(
474
+ "prompt_type",
475
+ [prompt_type] * train_data.num_rows,
476
+ )
477
+ log("Added prompt type %s to training data" % prompt_type)
478
+ if valid_data and "prompt_type" not in valid_data.column_names:
479
+ valid_data = valid_data.add_column(
480
+ "prompt_type",
481
+ [prompt_type] * valid_data.num_rows,
482
+ )
483
+ log("Added prompt type %s to validation data" % prompt_type)
484
+
485
+ assert train_data is not None
486
+
487
+ # shuffle and tokenize data
488
+ if train_data_mix_in:
489
+ train_data = concatenate_datasets([train_data, train_data_mix_in])
490
+ train_data = train_data.shuffle().map(generate_and_tokenize_prompt, num_proc=os.cpu_count() // torch.cuda.device_count())
491
+ train_set_size = len(train_data)
492
+
493
+ if valid_data and valid_data_mix_in:
494
+ valid_data = concatenate_datasets([valid_data, valid_data_mix_in])
495
+ elif valid_data_mix_in:
496
+ valid_data = valid_data_mix_in
497
+
498
+ if valid_data:
499
+ valid_data = valid_data.shuffle().map(generate_and_tokenize_prompt, num_proc=os.cpu_count() // torch.cuda.device_count())
500
+ val_set_size = len(valid_data)
501
+ else:
502
+ val_set_size = 0
503
+ log("Final fine-tuning data:\nTrain %s\nValid %s" % (train_data, valid_data))
504
+ sample_row_dict = train_data[:1]
505
+ del sample_row_dict['input_ids']
506
+ del sample_row_dict['attention_mask']
507
+ del sample_row_dict['labels']
508
+ log("Sample input: %s" % sample_row_dict)
509
+
510
+ if neptune_run:
511
+ neptune_callback = NeptuneCallback(run=neptune_run)
512
+ callbacks = [neptune_callback]
513
+ else:
514
+ from transformers.integrations import TensorBoardCallback, is_tensorboard_available
515
+ if is_tensorboard_available:
516
+ # tensorboard --logdir=runs/
517
+ from torch.utils.tensorboard import SummaryWriter
518
+ tb_writer = SummaryWriter()
519
+ callbacks = [TensorBoardCallback(tb_writer=tb_writer)]
520
+ else:
521
+ callbacks = []
522
+
523
+ expected_steps = (train_set_size * num_epochs) // batch_size
524
+ if eval_steps is None and eval_epochs is None:
525
+ # 20 evaluations for a run
526
+ eval_steps = max(1, int(expected_steps / 20))
527
+ log("Auto set eval_steps to %s out of %s total training steps" % (eval_steps, expected_steps))
528
+ elif eval_steps is None and eval_epochs is not None:
529
+ eval_steps = max(1, int(expected_steps * eval_epochs / num_epochs))
530
+ log("Auto converted eval_epochs=%s to eval_steps %s"
531
+ " out of %s total training steps" % (eval_epochs, eval_steps, expected_steps))
532
+ if save_steps is None:
533
+ save_steps = eval_steps
534
+ log("Auto step save_steps to %s" % save_steps)
535
+ elif save_steps > eval_steps:
536
+ # save_steps must be a round multiple of eval_steps
537
+ save_steps0 = save_steps
538
+ save_steps = max(1, (save_steps//eval_steps)) * eval_steps
539
+ if save_steps0 != save_steps:
540
+ log("Auto converted save_steps from %s to %s" % (save_steps0, save_steps))
541
+
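A worked example of the eval/save step schedule above, with illustrative sizes:

train_set_size, num_epochs, batch_size = 20000, 1, 128
expected_steps = (train_set_size * num_epochs) // batch_size   # 156
eval_steps = max(1, int(expected_steps / 20))                  # 7 -> roughly 20 evaluations per run
save_steps = max(1, 10 // eval_steps) * eval_steps             # a requested save_steps=10 rounds down to 7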
542
+ def compute_metrics(eval_preds):
543
+ # e.g. see: https://huggingface.co/docs/transformers/v4.25.1/en/tasks/translation#evaluate
544
+ inputs = eval_preds.inputs
545
+ label_ids = eval_preds.label_ids
546
+ predictions = eval_preds.predictions
547
+
548
+ #inputs = np.where(inputs != -100, inputs, tokenizer.pad_token_id)
549
+ #decoded_inputs = tokenizer.batch_decode(inputs, skip_special_tokens=True)
550
+ #decoded_inputs = [pred.strip() for pred in decoded_inputs]
551
+
552
+ label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
553
+ # tokenizer behavior like generate time
554
+ decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True,
555
+ clean_up_tokenization_spaces=True)
556
+ decoded_labels = [pred.strip() for pred in decoded_labels]
557
+
558
+ predictions = np.argmax(predictions, -1)
559
+ predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
560
+ # tokenizer behavior like generate time
561
+ decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True,
562
+ clean_up_tokenization_spaces=True)
563
+ decoded_predictions = [pred.strip() for pred in decoded_predictions]
564
+
565
+ result = {}
566
+ for metric in metrics.values():
567
+ result1 = metric.compute(predictions=decoded_predictions, references=decoded_labels)
568
+ # get rid of lists, for precision etc., for now
569
+ numeric_results = {k: v for k, v in result1.items() if isinstance(v, (int, float))}
570
+ result.update(numeric_results)
571
+ return result
572
+
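As a standalone sanity check of one of the supported metrics (requires pip install evaluate); the list-valued fields it returns are the ones filtered out above:

import evaluate

bleu = evaluate.load('bleu')
scores = bleu.compute(predictions=["the cat sat on the mat"],
                      references=[["the cat sat on the mat"]])
print(scores['bleu'])   # 1.0 for an exact match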
573
+ # the callback that computes metrics of interest
574
+ if val_metrics:
575
+ trainer_kwargs = dict(compute_metrics=compute_metrics)
576
+ else:
577
+ trainer_kwargs = dict()
578
+
579
+ trainer = transformers.Trainer(
580
+ model=model,
581
+ tokenizer=tokenizer,
582
+ train_dataset=train_data,
583
+ eval_dataset=valid_data,
584
+ # NOTE: CausalLM is not supporting Seq2SeqTrainingArguments arguments, but not incompatible
585
+ args=transformers.Seq2SeqTrainingArguments(
586
+ per_device_train_batch_size=micro_batch_size,
587
+ per_device_eval_batch_size=1,
588
+ eval_accumulation_steps=10,
589
+ # predict_with_generate=True, # SEQ2SEQ only
590
+ include_inputs_for_metrics=True,
591
+ gradient_accumulation_steps=gradient_accumulation_steps,
592
+ warmup_steps=warmup_steps,
593
+ num_train_epochs=num_epochs,
594
+ learning_rate=learning_rate,
595
+ gradient_checkpointing=gradient_checkpointing,
596
+ fp16=fp16,
597
+ # consider 8-bit adam: https://huggingface.co/docs/transformers/v4.18.0/en/performance#8bit-adam
598
+ optim="adamw_torch", # consider "adafactor" to save memory
599
+ logging_steps=logging_steps,
600
+ logging_strategy="steps",
601
+ evaluation_strategy="steps" if val_set_size > 0 else "no",
602
+ save_strategy="steps",
603
+ eval_steps=eval_steps if val_set_size > 0 else None,
604
+ save_steps=save_steps,
605
+ output_dir=output_dir,
606
+ save_total_limit=3,
607
+ load_best_model_at_end=True if val_set_size > 0 else False,
608
+ ddp_find_unused_parameters=False if ddp else None,
609
+ group_by_length=group_by_length,
610
+ #fsdp="shard_grad_op auto_wrap" if gpus > 1 and not ddp else None,
611
+ #fsdp_min_num_params=20000 if gpus > 1 and not ddp else None,
612
+ report_to='tensorboard' if not neptune_run else 'neptune',
613
+ ),
614
+ data_collator=transformers.DataCollatorForSeq2Seq(
615
+ tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
616
+ ),
617
+ callbacks=callbacks,
618
+ **trainer_kwargs,
619
+ )
620
+ model.config.use_cache = False
621
+
622
+ old_state_dict = model.state_dict
623
+ model.state_dict = (
624
+ lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
625
+ ).__get__(model, type(model))
626
+
627
+ if torch.__version__ >= "2" and sys.platform != "win32":
628
+ model = torch.compile(model)
629
+ # WIP (not generally replacing layers until pytorch 2.1)
630
+ torch.backends.cuda.enable_flash_sdp(True)
631
+
632
+ if gpus > 1 and not ddp:
633
+ assert trainer.is_model_parallel
634
+ else:
635
+ assert not trainer.is_model_parallel
636
+ trainer.train(resume_from_checkpoint=resume_from_checkpoint)
637
+
638
+ model.save_pretrained(output_dir)
639
+
640
+ log("\n If there's a warning about missing keys above, please disregard :)")
641
+
642
+
643
+ def get_loaders(llama_type, model_name, reward_type):
644
+ # NOTE: Some models need specific new prompt_type
645
+ # E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT"
646
+ if llama_type:
647
+ from transformers import LlamaForCausalLM, LlamaTokenizer
648
+ model_loader = LlamaForCausalLM
649
+ tokenizer_loader = LlamaTokenizer
650
+ elif 'gpt2' in model_name.lower():
651
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
652
+ return GPT2LMHeadModel, GPT2Tokenizer
653
+ elif 'mbart-' in model_name.lower():
654
+ from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
655
+ return MBartForConditionalGeneration, MBart50TokenizerFast
656
+ elif 't5' == model_name.lower() or \
657
+ 't5-' in model_name.lower() or \
658
+ 'flan-' in model_name.lower():
659
+ from transformers import AutoTokenizer, T5ForConditionalGeneration
660
+ return T5ForConditionalGeneration, AutoTokenizer
661
+ elif 'bigbird' in model_name:
662
+ from transformers import BigBirdPegasusForConditionalGeneration, AutoTokenizer
663
+ return BigBirdPegasusForConditionalGeneration, AutoTokenizer
664
+ elif 'bart-large-cnn-samsum' in model_name or 'flan-t5-base-samsum' in model_name:
665
+ from transformers import pipeline
666
+ return pipeline, "summarization"
667
+ elif reward_type or 'OpenAssistant/reward-model'.lower() in model_name.lower():
668
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
669
+ return AutoModelForSequenceClassification, AutoTokenizer
670
+ else:
671
+ from transformers import AutoTokenizer, AutoModelForCausalLM
672
+ model_loader = AutoModelForCausalLM
673
+ tokenizer_loader = AutoTokenizer
674
+ return model_loader, tokenizer_loader
675
+
676
+
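For illustration, a minimal sketch of how the loaders returned by get_loaders() are typically used (the model name is a placeholder; note the samsum branches above instead return a transformers pipeline plus task name):

base_model = 'EleutherAI/gpt-j-6B'  # illustrative placeholder
model_loader, tokenizer_loader = get_loaders(llama_type=False, model_name=base_model, reward_type=False)
tokenizer = tokenizer_loader.from_pretrained(base_model)
model = model_loader.from_pretrained(base_model)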
677
+ def get_githash():
678
+ try:
679
+ githash = subprocess.run(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE).stdout.decode('utf-8')[0:-1]
680
+ except Exception:
681
+ githash = ''
682
+ return githash
683
+
684
+
685
+ def copy_code(run_id):
686
+ """
687
+ Copy this source file into the run directory so per-run code changes are tracked.
688
+ :param run_id:
689
+ :return:
690
+ """
691
+ rnd_num = str(random.randint(0, 2 ** 31))
692
+ run_id = 'run_' + str(run_id)
693
+ os.makedirs(run_id, exist_ok=True)
694
+ me_full = os.path.join(pathlib.Path(__file__).parent.resolve(), __file__)
695
+ me_file = os.path.basename(__file__)
696
+ new_me = os.path.join(run_id, me_file + '_' + get_githash())
697
+ if os.path.isfile(new_me):
698
+ new_me = os.path.join(run_id, me_file + '_' + get_githash() + '_' + rnd_num)
699
+ shutil.copy(me_full, new_me)
700
+ else:
701
+ shutil.copy(me_full, new_me)
702
+
703
+
704
+ def get_prompt(prompt_type, chat, context, reduced):
705
+ if prompt_type in [-1, "-1", "plain"]:
706
+ promptA = promptB = PreInstruct = PreInput = PreResponse = ''
707
+ terminate_response = []
708
+ elif prompt_type == 'simple_instruct':
709
+ promptA = promptB = PreInstruct = PreInput = PreResponse = None
710
+ terminate_response = []
711
+ elif prompt_type in [0, "0", "instruct"] or prompt_type in [7, "7", "instruct_with_end"]:
712
+ promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''
713
+ promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''
714
+
715
+ PreInstruct = """
716
+ ### Instruction:
717
+ """
718
+
719
+ PreInput = """
720
+ ### Input:
721
+ """
722
+
723
+ PreResponse = """
724
+ ### Response:
725
+ """
726
+ if prompt_type in [7, "7", "instruct_with_end"]:
727
+ terminate_response = ['### End']
728
+ else:
729
+ terminate_response = None
730
+ elif prompt_type in [1, "1", "quality"]:
731
+ promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (chat and reduced) else ''
732
+ promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (chat and reduced) else ''
733
+
734
+ PreInstruct = """
735
+ ### Instruction:
736
+ """
737
+
738
+ PreInput = """
739
+ ### Input:
740
+ """
741
+
742
+ PreResponse = """
743
+ ### Response:
744
+ """
745
+ terminate_response = None
746
+ elif prompt_type in [2, "2", "human_bot", 9, "9", "human_bot_orig"]:
747
+ if reduced or context or prompt_type in [2, "2", "human_bot"]:
748
+ preprompt = ''
749
+ else:
750
+ cur_date = time.strftime('%Y-%m-%d')
751
+ cur_time = time.strftime('%H:%M:%S %p %Z')
752
+
753
+ PRE_PROMPT = """\
754
+ Current Date: {}
755
+ Current Time: {}
756
+
757
+ """
758
+ preprompt = PRE_PROMPT.format(cur_date, cur_time)
759
+ start = human
760
+ promptB = promptA = '%s%s ' % (preprompt, start)
761
+
762
+ PreInstruct = ""
763
+
764
+ PreInput = None
765
+
766
+ PreResponse = bot
767
+
768
+ terminate_response = [start, PreResponse]
769
+ elif prompt_type in [3, "3", "dai_faq"]:
770
+ promptA = ''
771
+ promptB = 'Answer the following Driverless AI question.\n'
772
+
773
+ PreInstruct = """
774
+ ### Driverless AI frequently asked question:
775
+ """
776
+
777
+ PreInput = None
778
+
779
+ PreResponse = """
780
+ ### Driverless AI documentation answer:
781
+ """
782
+ terminate_response = ['\n\n']
783
+ elif prompt_type in [5, "5", "summarize"]:
784
+ promptA = promptB = PreInput = ''
785
+ PreInstruct = '## Main Text\n\n'
786
+ PreResponse = '\n\n## Summary\n\n'
787
+ terminate_response = None
788
+ elif prompt_type in [6, "6", "instruct_vicuna"]:
789
+ promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
790
+ "The assistant gives helpful, detailed, and polite answers to the human's questions." if not (chat and reduced) else ''
791
+
792
+ PreInstruct = """
793
+ ### Human:
794
+ """
795
+
796
+ PreInput = None
797
+
798
+ PreResponse = """
799
+ ### Assistant:
800
+ """
801
+ terminate_response = ['### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
802
+ else:
803
+ raise RuntimeError("No such prompt_type=%s" % prompt_type)
804
+
805
+ return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response
806
+
807
+
808
+ def generate_prompt(data_point, prompt_type, chat, reduced):
809
+ context = data_point.get('context') if chat else ''
810
+ if context is None:
811
+ context = ''
812
+ instruction = data_point.get('instruction')
813
+ input = data_point.get('input')
814
+ output = data_point.get('output')
815
+ prompt_type = data_point.get('prompt_type', prompt_type)
816
+ assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
817
+ promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response = get_prompt(prompt_type, chat, context, reduced)
818
+
819
+ prompt = context
820
+
821
+ if input and promptA:
822
+ prompt += f"""{promptA}"""
823
+ elif promptB:
824
+ prompt += f"""{promptB}"""
825
+
826
+ if instruction and PreInstruct is not None and input and PreInput is not None:
827
+ prompt += f"""{PreInstruct}{instruction}{PreInput}{input}"""
828
+ prompt = inject_newline(prompt_type, prompt)
829
+ elif instruction and input and PreInstruct is None and PreInput is not None:
830
+ prompt += f"""{PreInput}{instruction}
831
+ {input}"""
832
+ prompt = inject_newline(prompt_type, prompt)
833
+ elif input and instruction and PreInput is None and PreInstruct is not None:
834
+ prompt += f"""{PreInstruct}{instruction}
835
+ {input}"""
836
+ prompt = inject_newline(prompt_type, prompt)
837
+ elif instruction and PreInstruct is not None:
838
+ prompt += f"""{PreInstruct}{instruction}"""
839
+ prompt = inject_newline(prompt_type, prompt)
840
+ elif input and PreInput is not None:
841
+ prompt += f"""{PreInput}{input}"""
842
+ prompt = inject_newline(prompt_type, prompt)
843
+ elif input and instruction and PreInput is not None:
844
+ prompt += f"""{PreInput}{instruction}{input}"""
845
+ prompt = inject_newline(prompt_type, prompt)
846
+ elif input and instruction and PreInstruct is not None:
847
+ prompt += f"""{PreInstruct}{instruction}{input}"""
848
+ prompt = inject_newline(prompt_type, prompt)
849
+ elif input and instruction:
850
+ # i.e. for simple_instruct
851
+ prompt += f"""{instruction}: {input}"""
852
+ prompt = inject_newline(prompt_type, prompt)
853
+ elif input:
854
+ prompt += f"""{input}"""
855
+ prompt = inject_newline(prompt_type, prompt)
856
+ elif instruction:
857
+ prompt += f"""{instruction}"""
858
+ prompt = inject_newline(prompt_type, prompt)
859
+
860
+ if PreResponse is not None:
861
+ prompt += f"""{PreResponse}"""
862
+ pre_response = PreResponse # Don't use strip
863
+ else:
864
+ pre_response = ''
865
+
866
+ if output:
867
+ prompt += f"""{output}"""
868
+
869
+ return prompt, pre_response, terminate_response
870
+
871
+
872
+ def inject_newline(prompt_type, prompt):
873
+ if prompt_type not in [-1, '-1', 'plain', 'simple_instruct']:
874
+ # only add new line if structured prompt, while 'plain' is just generation of next tokens from input
875
+ prompt += '\n'
876
+ return prompt
877
+
878
+
879
+ example_data_point0 = dict(instruction="Summarize",
880
+ input="Ducks eat seeds by the lake, then swim in the lake where fish eat small animals.",
881
+ output="Ducks eat and swim at the lake.")
882
+
883
+ example_data_point1 = dict(instruction="Who is smarter, Einstein or Newton?",
884
+ output="Einstein.")
885
+
886
+ example_data_point2 = dict(input="Who is smarter, Einstein or Newton?",
887
+ output="Einstein.")
888
+
889
+ example_data_points = [example_data_point0, example_data_point1, example_data_point2]
890
+
891
+
892
+ def test_train_prompt(prompt_type='instruct', data_point=0):
893
+ example_data_point = example_data_points[data_point]
894
+ return generate_prompt(example_data_point, prompt_type, False, False)
895
+
896
+
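For reference, with prompt_type='instruct' and example_data_point0 above, generate_prompt() composes approximately the following training prompt (reconstructed from the templates in get_prompt(); shown as a comment sketch):

# Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
#
# ### Instruction:
# Summarize
# ### Input:
# Ducks eat seeds by the lake, then swim in the lake where fish eat small animals.
#
# ### Response:
# Ducks eat and swim at the lake.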
897
+ def test_debug():
898
+ fire.Fire(train)
899
+
900
+
901
+ if __name__ == "__main__":
902
+ CONFIG = "NCCL_P2P_LEVEL=LOC WORLD_SIZE=5 torchrun --nnodes=5 --master_addr=10.10.10.2 --master_port=1111 --nproc_per_node=1"
903
+ CMD = "finetune.py --data_path=config.json --num_epochs=1 --base_model=decapoda-research/llama-13b-hf"
904
+ log(f"""
905
+ Example runs on 4 GPUs:
906
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='decapoda-research/llama-7b-hf' --data_path=data/config.json --run_id=0 &> 0.log
907
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='decapoda-research/llama-30b-hf' --data_path=data/config.json --batch_size=16 --micro_batch_size=1 --run_id=1 --save_code=True &> 1.log
908
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='EleutherAI/gpt-j-6B' --data_path=data/config.json --run_id=2 &> 2.log
909
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='EleutherAI/gpt-neox-20b' --data_path=data/config.json --run_id=8 --batch_size=16 --micro_batch_size=4 &> 8.log
910
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --data_path=data/config.json --prompt_type='dai_faq' --run_id=13 --batch_size=16 --micro_batch_size=4 --num_epochs=100 --val_set_size=0 --data_mix_in_path='' &> 13.log
911
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --data_path=data/config.json --run_id=28 --batch_size=16 --micro_batch_size=4 --num_epochs=8 --val_set_size=0 --data_mix_in_factor=0.1 --data_mix_in_prompt_type='human_bot' --save_code=True --cutoff_len=512 &> 28.log
912
+
913
+ All metrics:
914
+ CUDA_VISIBLE_DEVICES= python finetune.py --data_mix_in_factor=0 --eval_steps=100 --warmup_steps=2 --val_set_size=100 --val_metrics="['bleu', 'rouge', 'sacrebleu', 'meteor']"
915
+
916
+ # Fine-tune 20B on 24GB GPUs across 3 nodes with 3+2+2 GPUs
917
+ rippa>
918
+ NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1,2" torchrun --node_rank 0 --nproc_per_node=3 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank0
919
+ ova>
920
+ NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank 1 --nproc_per_node=2 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank1
921
+ timemachine>
922
+ NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank 2 --nproc_per_node=2 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank2
923
+
924
+ """, flush=True)
925
+
926
+ if os.environ.get("LOCAL_RANK") is None:
927
+ # then not using torchrun, so can't do distributed, ensure CVD set
928
+ assert os.environ.get("CUDA_VISIBLE_DEVICES") is not None, "Run python script using: torchrun finetune.py OR set CUDA_VISIBLE_DEVICES to single GPU"
929
+
930
+ fire.Fire(train)
h2o-logo.svg ADDED
prompter.py ADDED
@@ -0,0 +1,106 @@
1
+ from finetune import generate_prompt
2
+
3
+
4
+ class Prompter(object):
5
+ def __init__(self, prompt_type, debug=False, chat=False, stream_output=False, repeat_penalty=True,
6
+ allowed_repeat_line_length=10):
7
+ self.prompt_type = prompt_type
8
+ data_point = dict(instruction='', input='', output='')
9
+ _, self.pre_response, self.terminate_response = generate_prompt(data_point, prompt_type, chat, False)
10
+ self.debug = debug
11
+ self.chat = chat
12
+ self.stream_output = stream_output
13
+ self.repeat_penalty = repeat_penalty
14
+ self.allowed_repeat_line_length = allowed_repeat_line_length
15
+
16
+ def generate_prompt(self, data_point):
17
+ reduced = False
18
+ prompt, _, _ = generate_prompt(data_point, self.prompt_type, self.chat, reduced)
19
+ if self.debug:
20
+ print("prompt: ", prompt, flush=True)
21
+ self.prompt = prompt
22
+ return prompt
23
+
24
+ def get_response(self, outputs, prompt=None, sanitize_bot_response=True):
25
+ if isinstance(outputs, str):
26
+ outputs = [outputs]
27
+ if self.debug:
28
+ print("output: ", '\n\n'.join(outputs), flush=True)
29
+ if prompt is not None:
30
+ self.prompt = prompt
31
+
32
+ def clean_response(response):
33
+ meaningless_words = ['<pad>', '</s>', '<|endoftext|>', '”\n']
34
+ for word in meaningless_words:
35
+ response = response.replace(word, "")
36
+ if sanitize_bot_response:
37
+ from better_profanity import profanity
38
+ response = profanity.censor(response)
39
+ response = response.strip("\n")
40
+ return response
41
+
42
+ def clean_repeats(response):
43
+ lines = response.split('\n')
44
+ new_lines = []
45
+ [new_lines.append(line) for line in lines if
46
+ line not in new_lines or len(line) < self.allowed_repeat_line_length]
47
+ if self.debug and len(lines) != len(new_lines):
48
+ print("cleaned repeats: %s %s" % (len(lines), len(new_lines)), flush=True)
49
+ response = '\n'.join(new_lines)
50
+ return response
51
+
52
+ multi_output = len(outputs) > 1
53
+
54
+ for oi, output in enumerate(outputs):
55
+ if self.prompt_type in [0, '0', 'plain']:
56
+ output = clean_response(output)
57
+ else:
58
+ # find first instance of pre_response
59
+ # prompt sometimes has odd characters that mutate the length,
60
+ # so can't go by length alone
61
+ if self.pre_response:
62
+ outputi = output.find(prompt)
63
+ if outputi >= 0:
64
+ output = output[outputi + len(prompt):]
65
+ allow_terminate = True
66
+ else:
67
+ # subtraction is risky due to space offsets sometimes, so only do if necessary
68
+ output = output[len(prompt) - len(self.pre_response):]
69
+ # [1] to avoid repeated pre_response, just take first (after prompt - pre_response for chat)
70
+ if self.pre_response in output:
71
+ output = output.split(self.pre_response)[1]
72
+ allow_terminate = True
73
+ else:
74
+ print("Failure of parsing: %s" % output, flush=True)
75
+ allow_terminate = False
76
+ else:
77
+ allow_terminate = True
78
+ output = output[len(prompt):]
79
+ # clean after subtracting the prompt, so pre_response is removed correctly
80
+ output = clean_response(output).strip()
81
+ if self.repeat_penalty:
82
+ output = clean_repeats(output).strip()
83
+ if self.terminate_response and allow_terminate:
84
+ finds = []
85
+ for term in self.terminate_response:
86
+ finds.append(output.find(term))
87
+ finds = [x for x in finds if x >= 0]
88
+ if len(finds) > 0:
89
+ termi = finds[0]
90
+ output = output[:termi].strip()
91
+ else:
92
+ output = output.strip()
93
+ else:
94
+ output = output.strip()
95
+ if multi_output:
96
+ # prefix with output counter
97
+ output = "\n=========== Output %d\n\n" % (1 + oi) + output
98
+ if oi > 0:
99
+ # postfix outputs with separator
100
+ output += '\n'
101
+ outputs[oi] = output
102
+ # join all outputs, only one extra new line between outputs
103
+ output = '\n'.join(outputs)
104
+ if self.debug:
105
+ print("outputclean: ", '\n\n'.join(outputs), flush=True)
106
+ return output
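A minimal usage sketch of the Prompter class above (the instruction, input, and fake model output are illustrative placeholders; profanity sanitization is disabled so the optional dependency is not needed):

prompter = Prompter('instruct', debug=False, chat=False, stream_output=False)
prompt = prompter.generate_prompt(dict(instruction='Summarize the text.',
                                       input='Ducks eat seeds by the lake.',
                                       output=''))
# in real use this string would come from model.generate() plus tokenizer decoding
raw_output = prompt + 'Ducks feed near the lake.'
response = prompter.get_response(raw_output, prompt=prompt, sanitize_bot_response=False)
print(response)  # -> 'Ducks feed near the lake.'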
requirements.txt ADDED
@@ -0,0 +1,48 @@
1
+ # for generate (gradio server) and finetune
2
+ datasets==2.10.1
3
+ sentencepiece==0.1.97
4
+ accelerate==0.18.0
5
+ gradio==3.27.0
6
+ huggingface_hub==0.13.4
7
+ appdirs==1.4.4
8
+ fire==0.5.0
9
+ docutils==0.19
10
+ torch==2.0.0
11
+ evaluate==0.4.0
12
+ rouge_score==0.1.2
13
+ sacrebleu==2.3.1
14
+ scikit-learn==1.2.2
15
+ alt-profanity-check==1.2.2
16
+ better-profanity==0.6.1
17
+ numpy==1.24.2
18
+ pandas==1.5.3
19
+ matplotlib==3.7.1
20
+ loralib==0.1.1
21
+ bitsandbytes==0.38.1
22
+ git+https://github.com/huggingface/peft.git@098962fa6515f2e4fe83a757f5995d3ffbb1c373
23
+ transformers==4.28.1
24
+ tokenizers==0.13.3
25
+
26
+ # optional for generate
27
+ pynvml==11.5.0
28
+ psutil==5.9.4
29
+
30
+ # optional for finetune
31
+ tensorboard==2.12.1
32
+ neptune==1.1.1
33
+
34
+ # for gradio client
35
+ gradio_client==0.1.3
36
+ beautifulsoup4==4.12.2
37
+ markdown==3.4.1
38
+
39
+ # data and testing
40
+ pytest==7.2.2
41
+ pytest-xdist==3.2.1
42
+ nltk==3.8.1
43
+ textstat==0.7.3
44
+ pandoc==2.3
45
+ pypandoc==1.11
46
+ openpyxl==3.1.2
47
+ lm_dataformat==0.0.20
48
+ bioc==2.0
stopping.py ADDED
@@ -0,0 +1,139 @@
1
+ import traceback
2
+ from queue import Queue
3
+ from threading import Thread
4
+ import collections.abc
5
+
6
+ import torch
7
+ from transformers import StoppingCriteria
8
+
9
+
10
+ class StoppingCriteriaSub(StoppingCriteria):
11
+
12
+ def __init__(self, stops=[], encounters=[]):
13
+ super().__init__()
14
+ assert len(stops) % len(encounters) == 0, "Number of stops must be a multiple of number of encounters"
15
+ self.encounters = encounters
16
+ self.stops = [stop.to("cuda") for stop in stops]
17
+ self.num_stops = [0] * len(stops)
18
+
19
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
20
+ for stopi, stop in enumerate(self.stops):
21
+ if torch.all((stop == input_ids[0][-len(stop):])).item():
22
+ self.num_stops[stopi] += 1
23
+ if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
24
+ return True
25
+ # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
26
+ # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
27
+ return False
28
+
29
+
30
+ class Stream(StoppingCriteria):
31
+ """
32
+ This class can be used to callback during generation. Keep
33
+ in mind for decoder-only type of transformers, this will include the initial prompted tokens.
34
+
35
+ Args:
36
+ func (`callable`):
37
+ A callable function to apply on first input in list every iteration of generation
38
+ """
39
+
40
+ def __init__(self, func=None):
41
+ self.func = func
42
+
43
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
44
+ if self.func is not None:
45
+ # only consume first of multiple responses
46
+ self.func(input_ids[0])
47
+ return False
48
+
49
+
50
+ class CallbackToGenerator(collections.abc.Generator):
51
+ """
52
+ A generator wrapper for a function that invokes a callback multiple times.
53
+
54
+ Calling `send` on the generator emits a value from one callback, and returns
55
+ the next.
56
+
57
+ Note this starts a background thread
58
+ """
59
+
60
+ def __init__(self, func, *args, callback=None, **kwargs):
61
+ self.func = func
62
+ self.args = args
63
+ self.kwargs = kwargs
64
+ self.callback = callback
65
+
66
+ self._ready_queue = Queue(1)
67
+ self._done_queue = Queue(1)
68
+ self._done_holder = [False]
69
+
70
+ # local to avoid reference cycles
71
+ ready_queue = self._ready_queue
72
+ done_queue = self._done_queue
73
+ done_holder = self._done_holder
74
+
75
+ def val_callback(value):
76
+ done_queue.put((False, value))
77
+ cmd, val = ready_queue.get()
78
+ if cmd == 'send':
79
+ return val
80
+ elif cmd == 'throw':
81
+ raise val
82
+ else:
83
+ assert False # pragma: no cover
84
+
85
+ def thread_func():
86
+ while True:
87
+ cmd, val = ready_queue.get()
88
+ if cmd == 'send' and val is not None:
89
+ done_queue.put((True, TypeError("can't send non-None value to a just-started generator")))
90
+ continue
91
+ break
92
+ try:
93
+ if cmd == 'throw':
94
+ raise val
95
+ ret = func(callback=val_callback, **self.kwargs)
96
+ raise StopIteration(ret) if ret is not None else StopIteration
97
+ except BaseException as e:
98
+ done_holder[0] = True
99
+ done_queue.put((True, e))
100
+
101
+ self._thread = Thread(target=thread_func)
102
+ self._thread.start()
103
+
104
+ def _put(self, *args):
105
+ if self._done_holder[0]:
106
+ raise StopIteration
107
+ self._ready_queue.put(args)
108
+ is_exception, val = self._done_queue.get()
109
+ if is_exception:
110
+ try:
111
+ raise val
112
+ finally:
113
+ # prevent val's traceback containing a reference cycle
114
+ del val
115
+ else:
116
+ return val
117
+
118
+ def send(self, value):
119
+ return self._put('send', value)
120
+
121
+ def throw(self, exc):
122
+ return self._put('throw', exc)
123
+
124
+ def close(self):
125
+ try:
126
+ self.throw(GeneratorExit)
127
+ except StopIteration:
128
+ self._thread.join()
129
+ except GeneratorExit:
130
+ self._thread.join()
131
+ except BaseException:
132
+ self._thread.join()
133
+ raise
134
+ else:
135
+ # yielded again, can't clean up the thread
136
+ raise RuntimeError('Task with callback ignored GeneratorExit')
137
+
138
+ def __del__(self):
139
+ self.close()
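A small self-contained sketch of CallbackToGenerator: it turns a function that reports results through a callback into an ordinary generator (produce() below is an illustrative stand-in for a streaming generation call):

def produce(callback=None):
    # stand-in for a long-running job that reports partial results via its callback
    for piece in ["Hello", ", ", "world"]:
        callback(piece)

for chunk in CallbackToGenerator(produce):
    print(chunk, end="")
print()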
utils.py ADDED
@@ -0,0 +1,89 @@
1
+ import os
2
+ import gc
3
+ import random
4
+ import time
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torch
8
+
9
+
10
+ def set_seed(seed: int):
11
+ """
12
+ Sets the seed of the entire notebook so results are the same every time we run.
13
+ This is for REPRODUCIBILITY.
14
+ """
15
+ np.random.seed(seed)
16
+ random_state = np.random.RandomState(seed)
17
+ random.seed(seed)
18
+ torch.manual_seed(seed)
19
+ torch.cuda.manual_seed(seed)
20
+ torch.backends.cudnn.deterministic = True
21
+ torch.backends.cudnn.benchmark = False
22
+ os.environ['PYTHONHASHSEED'] = str(seed)
23
+ return random_state
24
+
25
+
26
+ def flatten_list(lis):
27
+ """Given a list, possibly nested to any level, return it flattened."""
28
+ new_lis = []
29
+ for item in lis:
30
+ if type(item) == type([]):
31
+ new_lis.extend(flatten_list(item))
32
+ else:
33
+ new_lis.append(item)
34
+ return new_lis
35
+
36
+
37
+ def clear_torch_cache():
38
+ if torch.cuda.is_available():
39
+ torch.cuda.empty_cache()
40
+ torch.cuda.ipc_collect()
41
+ gc.collect()
42
+
43
+
44
+ def system_info():
45
+ import psutil
46
+
47
+ system = {}
48
+ # https://stackoverflow.com/questions/48951136/plot-multiple-graphs-in-one-plot-using-tensorboard
49
+ # https://arshren.medium.com/monitoring-your-devices-in-python-5191d672f749
50
+ temps = psutil.sensors_temperatures(fahrenheit=False)
51
+ if 'coretemp' in temps:
52
+ coretemp = temps['coretemp']
53
+ temp_dict = {k.label: k.current for k in coretemp}
54
+ for k, v in temp_dict.items():
55
+ system['CPU_C/%s' % k] = v
56
+
57
+ # https://github.com/gpuopenanalytics/pynvml/blob/master/help_query_gpu.txt
58
+ from pynvml.smi import nvidia_smi
59
+ nvsmi = nvidia_smi.getInstance()
60
+
61
+ gpu_power_dict = {'W_gpu%d' % i: x['power_readings']['power_draw'] for i, x in
62
+ enumerate(nvsmi.DeviceQuery('power.draw')['gpu'])}
63
+ for k, v in gpu_power_dict.items():
64
+ system['GPU_W/%s' % k] = v
65
+
66
+ gpu_temp_dict = {'C_gpu%d' % i: x['temperature']['gpu_temp'] for i, x in
67
+ enumerate(nvsmi.DeviceQuery('temperature.gpu')['gpu'])}
68
+ for k, v in gpu_temp_dict.items():
69
+ system['GPU_C/%s' % k] = v
70
+
71
+ gpu_memory_free_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['free'] for i, x in
72
+ enumerate(nvsmi.DeviceQuery('memory.free')['gpu'])}
73
+ gpu_memory_total_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['total'] for i, x in
74
+ enumerate(nvsmi.DeviceQuery('memory.total')['gpu'])}
75
+ gpu_memory_frac_dict = {k: gpu_memory_free_dict[k] / gpu_memory_total_dict[k] for k in gpu_memory_total_dict}
76
+ for k, v in gpu_memory_frac_dict.items():
77
+ system['GPU_M/%s' % k] = v
78
+
79
+ return system
80
+
81
+
82
+ def system_info_print():
83
+ try:
84
+ df = pd.DataFrame.from_dict(system_info(), orient='index')
85
+ # avoid slamming GPUs
86
+ time.sleep(1)
87
+ return df.to_markdown()
88
+ except Exception as e:
89
+ return "Error: %s" % str(e)