Merge branch 'main' into hf-ui-demo
Changed files:
- LLaMA_LoRA.ipynb +31 -13
- README.md +19 -16
- app.py +2 -1
- llama_lora/globals.py +11 -5
- llama_lora/lib/finetune.py +21 -2
- llama_lora/models.py +103 -59
- llama_lora/ui/finetune_ui.py +111 -31
- llama_lora/ui/inference_ui.py +186 -40
- llama_lora/ui/main_page.py +48 -5
- llama_lora/ui/tokenizer_ui.py +4 -2
- llama_lora/utils/data.py +16 -0
- llama_lora/utils/lru_cache.py +31 -0
- requirements.lock.txt +120 -0
- templates/user_and_ai.json +7 -0
LLaMA_LoRA.ipynb
CHANGED

@@ -6,7 +6,6 @@
 "provenance": [],
 "private_outputs": true,
 "toc_visible": true,
-"authorship_tag": "ABX9TyMHMc4PwWLbRlhFol+WRzoT",
 "include_colab_link": true
 },
 "kernelspec": {

@@ -27,13 +26,13 @@
 "colab_type": "text"
 },
 "source": [
-"<a href=\"https://colab.research.google.com/github/zetavg/LLaMA-LoRA/blob/main/LLaMA_LoRA.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+"<a href=\"https://colab.research.google.com/github/zetavg/LLaMA-LoRA-Tuner/blob/main/LLaMA_LoRA.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
 ]
 },
 {
 "cell_type": "markdown",
 "source": [
-"# 🦙🎛️ LLaMA-LoRA\n",
+"# 🦙🎛️ LLaMA-LoRA Tuner\n",
 "\n",
 "TL;DR: **Runtime > Run All** (`⌘/Ctrl+F9`). Takes about 5 minutes to start. You will be promped to authorize Google Drive access."
 ],

@@ -55,6 +54,26 @@
 "execution_count": null,
 "outputs": []
 },
+{
+"cell_type": "code",
+"source": [
+"# @title A small workaround { display-mode: \"form\" }\n",
+"# @markdown Don't panic if you see an error here. Just click the `RESTART RUNTIME` button in the output below, then Run All again.\n",
+"# @markdown The error will disappear on the next run.\n",
+"!pip install Pillow==9.3.0\n",
+"import PIL\n",
+"major, minor = map(float, PIL.__version__.split(\".\")[:2])\n",
+"version_float = major + minor / 10**len(str(minor))\n",
+"print(version_float)\n",
+"if version_float < 9.003:\n",
+"  raise Exception(\"Restart the runtime by clicking the 'RESTART RUNTIME' button above (or Runtime > Restart Runtime).\")"
+],
+"metadata": {
+"id": "XcJ4WO3KhOX1"
+},
+"execution_count": null,
+"outputs": []
+},
 {
 "cell_type": "markdown",
 "source": [

@@ -72,9 +91,9 @@
 "# @title Git/Project { display-mode: \"form\", run: \"auto\" }\n",
 "# @markdown Project settings.\n",
 "\n",
-"# @markdown The URL of the LLaMA-LoRA project<br> (default: `https://github.com/zetavg/…
-"llama_lora_project_url = \"https://github.com/zetavg/…
-"# @markdown The branch to use for LLaMA-LoRA project:\n",
+"# @markdown The URL of the LLaMA-LoRA-Tuner project<br> (default: `https://github.com/zetavg/LLaMA-LoRA-Tuner.git`):\n",
+"llama_lora_project_url = \"https://github.com/zetavg/LLaMA-LoRA-Tuner.git\" # @param {type:\"string\"}\n",
+"# @markdown The branch to use for LLaMA-LoRA-Tuner project:\n",
 "llama_lora_project_branch = \"main\" # @param {type:\"string\"}\n",
 "\n",
 "# # @markdown Forces the local directory to be updated by the remote branch:\n",

@@ -97,7 +116,7 @@
 "# @markdown You can customize the location of the stored data here.\n",
 "\n",
 "# @markdown The folder in Google Drive where Colab Notebook data are stored<br /> **(WARNING: The content of this folder will be modified by this notebook)**:\n",
-"google_drive_folder = \"Colab Data/LLaMA…
+"google_drive_folder = \"Colab Data/LLaMA-LoRA Tuner\" # @param {type:\"string\"}\n",
 "# google_drive_colab_data_folder = \"Colab Notebooks/Notebook Data\"\n",
 "\n",
 "# Where Google Drive will be mounted in the Colab runtime.\n",

@@ -220,7 +239,7 @@
 "source": [
 "![ ! -d llama_lora ] && git clone -b {llama_lora_project_branch} --filter=tree:0 {llama_lora_project_url} llama_lora\n",
 "!cd llama_lora && git add --all && git stash && git fetch origin {llama_lora_project_branch} && git checkout {llama_lora_project_branch} && git reset origin/{llama_lora_project_branch} --hard\n",
-"![ ! -f llama-lora-requirements-installed ] && cd llama_lora && pip install -r requirements.txt && touch ../llama-lora-requirements-installed"
+"![ ! -f llama-lora-requirements-installed ] && cd llama_lora && pip install -r requirements.lock.txt && touch ../llama-lora-requirements-installed"
 ],
 "metadata": {
 "id": "JGYz2VDoAzC8"

@@ -262,7 +281,7 @@
 "\n",
 "# Set Configs\n",
 "from llama_lora.llama_lora.globals import Global\n",
-"Global.…
+"Global.default_base_model_name = base_model\n",
 "data_dir_realpath = !realpath ./data\n",
 "Global.data_dir = data_dir_realpath[0]\n",
 "Global.load_8bit = True\n",

@@ -270,12 +289,11 @@
 "# Prepare Data Dir\n",
 "import os\n",
 "from llama_lora.llama_lora.utils.data import init_data_dir\n",
-"init_data_dir()…
+"init_data_dir()",
 "\n",
 "# Load the Base Model\n",
-"from llama_lora.llama_lora.models import …
-"…
-"print(f\"Base model loaded: '{Global.base_model}'.\")"
+"from llama_lora.llama_lora.models import prepare_base_model\n",
+"prepare_base_model()\n"
 ],
 "metadata": {
 "id": "Yf6g248ylteP"
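A note on the new workaround cell above: it encodes the Pillow version as a float (`major + minor / 10**len(str(minor))`), which only behaves as intended for single-digit minor versions. A minimal sketch of the same restart gate using a tuple comparison instead (a suggested alternative, not part of this commit):

```python
# Hypothetical alternative to the notebook's float-based Pillow version gate:
# compare (major, minor) as integers instead of encoding the version as a float.
import PIL

major, minor = (int(part) for part in PIL.__version__.split(".")[:2])
print(PIL.__version__)
if (major, minor) < (9, 3):
    raise Exception(
        "Restart the runtime by clicking the 'RESTART RUNTIME' button above "
        "(or Runtime > Restart Runtime)."
    )
```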
README.md
CHANGED

@@ -20,16 +20,19 @@ git push -f hf-ui-demo hf-ui-demo:main
 
 ---
 
-# 🦙🎛️ LLaMA-LoRA
+# 🦙🎛️ LLaMA-LoRA Tuner
 
-…
+<a href="https://colab.research.google.com/github/zetavg/LLaMA-LoRA-Tuner/blob/main/LLaMA_LoRA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
 
 Making evaluating and fine-tuning LLaMA models with low-rank adaptation (LoRA) easy.
 
 
 ## Features
 
-…
+**[See a demo on Hugging Face](https://huggingface.co/spaces/zetavg/LLaMA-LoRA-UI-Demo)** **Only serves UI demonstration. To try training or text generation, [run on Colab](#run-on-google-colab).*
+
+* **[1-click up and running in Google Colab](#run-on-google-colab)** with a standard GPU runtime.
 * Loads and stores data in Google Drive.
 * Evaluate various LLaMA LoRA models stored in your folder or from Hugging Face.<br /><a href="https://youtu.be/A3kb4VkDWyY"><img width="640px" src="https://user-images.githubusercontent.com/3784687/230272844-09f7a35b-46bf-4101-b15d-4ddf243b8bef.gif" /></a>
 * Fine-tune LLaMA models with different prompt templates and training dataset format.<br /><a href="https://youtu.be/5Db9U8PsaUk"><img width="640px" src="https://user-images.githubusercontent.com/3784687/230277315-9a91d983-1690-4594-9d54-912eda8963ee.gif" /></a>

@@ -47,7 +50,7 @@ There are various ways to run this app:
 
 ### Run On Google Colab
 
-Open [this Colab Notebook](https://colab.research.google.com/github/zetavg/LLaMA-LoRA/blob/main/LLaMA_LoRA.ipynb) and select **Runtime > Run All** (`⌘/Ctrl+F9`).
+Open [this Colab Notebook](https://colab.research.google.com/github/zetavg/LLaMA-LoRA-Tuner/blob/main/LLaMA_LoRA.ipynb) and select **Runtime > Run All** (`⌘/Ctrl+F9`).
 
 You will be prompted to authorize Google Drive access, as Google Drive will be used to store your data. See the "Config"/"Google Drive" section for settings and more info.

@@ -58,10 +61,10 @@ After approximately 5 minutes of running, you will see the public URL in the out
 After following the [installation guide of SkyPilot](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html), create a `.yaml` to define a task for running the app:
 
 ```yaml
-# llama-lora-…
+# llama-lora-tuner.yaml
 
 resources:
-  accelerators: A10:1 # 1x NVIDIA A10 GPU
+  accelerators: A10:1 # 1x NVIDIA A10 GPU, about US$ 0.6 / hr on Lambda Cloud.
   cloud: lambda # Optional; if left out, SkyPilot will automatically pick the cheapest cloud.
 
 file_mounts:

@@ -69,27 +72,27 @@ file_mounts:
   # (to store train datasets trained models)
   # See https://skypilot.readthedocs.io/en/latest/reference/storage.html for details.
   /data:
-    name: llama-lora-…
-    store: …
+    name: llama-lora-tuner-data # Make sure this name is unique or you own this bucket. If it does not exists, SkyPilot will try to create a bucket with this name.
+    store: s3 # Could be either of [s3, gcs]
     mode: MOUNT
 
-# Clone the LLaMA-LoRA repo and install its dependencies.
+# Clone the LLaMA-LoRA Tuner repo and install its dependencies.
 setup: |
-  git clone https://github.com/zetavg/LLaMA-LoRA.git
-  cd …
+  git clone https://github.com/zetavg/LLaMA-LoRA-Tuner.git llama_lora_tuner
+  cd llama_lora_tuner && pip install -r requirements.lock.txt
   cd ..
   echo 'Dependencies installed.'
 
 # Start the app.
 run: |
   echo 'Starting...'
-  python …
+  python llama_lora_tuner/app.py --data_dir='/data' --base_model='decapoda-research/llama-7b-hf' --share
 ```
 
 Then launch a cluster to run the task:
 
 ```
-sky launch -c llama-lora-…
+sky launch -c llama-lora-tuner llama-lora-tuner.yaml
 ```
 
 `-c ...` is an optional flag to specify a cluster name. If not specified, SkyPilot will automatically generate one.

@@ -106,13 +109,13 @@ When you are done, run `sky stop <cluster_name>` to stop the cluster. To termina
 <summary>Prepare environment with conda</summary>
 
 ```bash
-conda create -y -n llama-lora-…
-conda activate llama-lora-…
+conda create -y python=3.8 -n llama-lora-tuner
+conda activate llama-lora-tuner
 ```
 </details>
 
 ```bash
-pip install -r requirements.txt
+pip install -r requirements.lock.txt
 python app.py --data_dir='./data' --base_model='decapoda-research/llama-7b-hf' --share
 ```
app.py
CHANGED

@@ -16,6 +16,7 @@ def main(
     # Allows to listen on all interfaces by providing '0.0.0.0'.
     server_name: str = "127.0.0.1",
     share: bool = False,
+    skip_loading_base_model: bool = False,
     ui_show_sys_info: bool = True,
     ui_dev_mode: bool = False,
 ):

@@ -29,7 +30,7 @@
         data_dir
     ), "Please specify a --data_dir, e.g. --data_dir='./data'"
 
-    Global.…
+    Global.default_base_model_name = base_model
     Global.data_dir = os.path.abspath(data_dir)
     Global.load_8bit = load_8bit
llama_lora/globals.py
CHANGED

@@ -6,18 +6,17 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 from numba import cuda
 import nvidia_smi
 
+from .utils.lru_cache import LRUCache
 from .lib.finetune import train
 
 
 class Global:
     version = None
 
-    base_model: str = ""
     data_dir: str = ""
     load_8bit: bool = False
 
-    …
-    loaded_base_model: Any = None
+    default_base_model_name: str = ""
 
     # Functions
     train_fn: Any = train

@@ -25,8 +24,15 @@ class Global:
     # Training Control
     should_stop_training = False
 
+    # Generation Control
+    should_stop_generating = False
+    generation_force_stopped_at = None
+
     # Model related
-    …
+    loaded_models = LRUCache(1)
+    loaded_tokenizers = LRUCache(1)
+    new_base_model_that_is_ready_to_be_used = None
+    name_of_new_base_model_that_is_ready_to_be_used = None
 
     # GPU Info
     gpu_cc = None  # GPU compute capability

@@ -35,7 +41,7 @@ class Global:
     gpu_total_memory = None
 
     # UI related
-    ui_title: str = "LLaMA-LoRA"
+    ui_title: str = "LLaMA-LoRA Tuner"
     ui_emoji: str = "🦙🎛️"
     ui_subtitle: str = "Toolkit for evaluating and fine-tuning LLaMA models with low-rank adaptation (LoRA)."
     ui_show_sys_info: bool = True
llama_lora/lib/finetune.py
CHANGED

@@ -2,6 +2,8 @@ import os
 import sys
 from typing import Any, List
 
+import json
+
 import fire
 import torch
 import transformers

@@ -47,6 +49,10 @@ def train(
     # logging
     callbacks: List[Any] = []
 ):
+    if os.path.exists(output_dir):
+        if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
+            raise ValueError(f"The output directory already exists and is not empty. ({output_dir})")
+
     device_map = "auto"
     world_size = int(os.environ.get("WORLD_SIZE", 1))
     ddp = world_size != 1

@@ -202,6 +208,12 @@
         ),
         callbacks=callbacks,
     )
+
+    if not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+    with open(os.path.join(output_dir, "trainer_args.json"), 'w') as trainer_args_json_file:
+        json.dump(trainer.args.to_dict(), trainer_args_json_file, indent=2)
+
     model.config.use_cache = False
 
     old_state_dict = model.state_dict

@@ -214,9 +226,16 @@
     if torch.__version__ >= "2" and sys.platform != "win32":
         model = torch.compile(model)
 
-    …
+    train_output = trainer.train(resume_from_checkpoint=resume_from_checkpoint)
 
     model.save_pretrained(output_dir)
     print(f"Model saved to {output_dir}.")
 
-    …
+    with open(os.path.join(output_dir, "trainer_log_history.jsonl"), 'w') as trainer_log_history_jsonl_file:
+        trainer_log_history = "\n".join([json.dumps(line) for line in trainer.state.log_history])
+        trainer_log_history_jsonl_file.write(trainer_log_history)
+
+    with open(os.path.join(output_dir, "train_output.json"), 'w') as train_output_json_file:
+        json.dump(train_output, train_output_json_file, indent=2)
+
+    return train_output
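With these changes, each training run leaves `trainer_args.json`, `trainer_log_history.jsonl`, and `train_output.json` next to the saved adapter. A short sketch of reading those artifacts back, e.g. to inspect the loss curve (the output directory path is illustrative):

```python
# Reading back the artifacts written by train(); the directory name is an example.
import json
import os

output_dir = "data/lora_models/my-lora-model"  # hypothetical model directory

# Each line of trainer_log_history.jsonl is one entry of trainer.state.log_history.
with open(os.path.join(output_dir, "trainer_log_history.jsonl")) as f:
    log_history = [json.loads(line) for line in f if line.strip()]

losses = [(entry.get("step"), entry["loss"]) for entry in log_history if "loss" in entry]
print(losses[:5])

# train_output.json holds the value returned by trainer.train().
with open(os.path.join(output_dir, "train_output.json")) as f:
    print(json.load(f))
```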
llama_lora/models.py
CHANGED

@@ -3,9 +3,8 @@ import sys
 import gc
 
 import torch
-import …
+from transformers import LlamaForCausalLM, LlamaTokenizer
 from peft import PeftModel
-from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
 
 from .globals import Global

@@ -23,84 +22,140 @@ def get_device():
     pass
 
 
-def get_base_model():
-    load_base_model()
-    return Global.loaded_base_model
-
-    Global.model_has_been_used = True
-    …
-
-def get_tokenizer():
-    …
+def get_new_base_model(base_model_name):
+    if Global.ui_dev_mode:
+        return
+
+    if Global.new_base_model_that_is_ready_to_be_used:
+        if Global.name_of_new_base_model_that_is_ready_to_be_used == base_model_name:
+            model = Global.new_base_model_that_is_ready_to_be_used
+            Global.new_base_model_that_is_ready_to_be_used = None
+            Global.name_of_new_base_model_that_is_ready_to_be_used = None
+            return model
+        else:
+            Global.new_base_model_that_is_ready_to_be_used = None
+            Global.name_of_new_base_model_that_is_ready_to_be_used = None
+            clear_cache()
+
+    device = get_device()
+
+    if device == "cuda":
+        model = LlamaForCausalLM.from_pretrained(
+            base_model_name,
+            load_in_8bit=Global.load_8bit,
+            torch_dtype=torch.float16,
+            # device_map="auto",
+            # ? https://github.com/tloen/alpaca-lora/issues/21
+            device_map={'': 0},
+        )
+    elif device == "mps":
+        model = LlamaForCausalLM.from_pretrained(
+            base_model_name,
+            device_map={"": device},
+            torch_dtype=torch.float16,
+        )
+    else:
+        model = LlamaForCausalLM.from_pretrained(
+            base_model_name, device_map={"": device}, low_cpu_mem_usage=True
+        )
+
+    model.config.pad_token_id = get_tokenizer(base_model_name).pad_token_id = 0
+    model.config.bos_token_id = 1
+    model.config.eos_token_id = 2
+
+    return model
+
+
+def get_tokenizer(base_model_name):
+    if Global.ui_dev_mode:
+        return
+
+    loaded_tokenizer = Global.loaded_tokenizers.get(base_model_name)
+    if loaded_tokenizer:
+        return loaded_tokenizer
+
+    tokenizer = LlamaTokenizer.from_pretrained(base_model_name)
+    Global.loaded_tokenizers.set(base_model_name, tokenizer)
+
+    return tokenizer
+
+
+def get_model(
+        base_model_name,
+        peft_model_name=None):
     if Global.ui_dev_mode:
         return
 
+    if peft_model_name == "None":
+        peft_model_name = None
+
+    model_key = base_model_name
+    if peft_model_name:
+        model_key = f"{base_model_name}//{peft_model_name}"
+
+    loaded_model = Global.loaded_models.get(model_key)
+    if loaded_model:
+        return loaded_model
+
+    peft_model_name_or_path = peft_model_name
+
+    lora_models_directory_path = os.path.join(Global.data_dir, "lora_models")
+    possible_lora_model_path = os.path.join(
+        lora_models_directory_path, peft_model_name)
+    if os.path.isdir(possible_lora_model_path):
+        peft_model_name_or_path = possible_lora_model_path
+
+    Global.loaded_models.prepare_to_set()
+    clear_cache()
+
+    model = get_new_base_model(base_model_name)
+
+    if peft_model_name:
+        device = get_device()
+
         if device == "cuda":
+            model = PeftModel.from_pretrained(
+                model,
+                peft_model_name_or_path,
                 torch_dtype=torch.float16,
+                # ? https://github.com/tloen/alpaca-lora/issues/21
+                device_map={'': 0},
             )
         elif device == "mps":
+            model = PeftModel.from_pretrained(
+                model,
+                peft_model_name_or_path,
                 device_map={"": device},
                 torch_dtype=torch.float16,
             )
         else:
+            model = PeftModel.from_pretrained(
+                model,
+                peft_model_name_or_path,
+                device_map={"": device},
             )
 
+    model.config.pad_token_id = get_tokenizer(base_model_name).pad_token_id = 0
+    model.config.bos_token_id = 1
+    model.config.eos_token_id = 2
+
+    if not Global.load_8bit:
+        model.half()  # seems to fix bugs for some users.
+
+    model.eval()
+    if torch.__version__ >= "2" and sys.platform != "win32":
+        model = torch.compile(model)
+
+    Global.loaded_models.set(model_key, model)
+    clear_cache()
+
+    return model
+
+
+def prepare_base_model(base_model_name=Global.default_base_model_name):
+    Global.new_base_model_that_is_ready_to_be_used = get_new_base_model(base_model_name)
+    Global.name_of_new_base_model_that_is_ready_to_be_used = base_model_name
+
 
 def clear_cache():
     gc.collect()

@@ -111,17 +166,6 @@ def clear_cache():
 
 
 def unload_models():
-    …
-    Global.…
-    del Global.loaded_tokenizer
-    Global.loaded_tokenizer = None
+    Global.loaded_models.clear()
+    Global.loaded_tokenizers.clear()
     clear_cache()
-
-    Global.model_has_been_used = False
-
-
-def unload_models_if_already_used():
-    if Global.model_has_been_used:
-        unload_models()
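The refactor replaces the single implicit `Global.loaded_base_model` with explicit, name-keyed caches: models are cached under `base_model_name` or `base_model_name//peft_model_name`, and `LRUCache(1)` keeps only the most recently used model and tokenizer resident. A usage sketch (the adapter name is the inference UI's default; the base model is the README's example):

```python
# Illustrative use of the refactored loading API introduced in this commit.
from llama_lora.globals import Global
from llama_lora.models import get_model, get_tokenizer, unload_models

Global.default_base_model_name = "decapoda-research/llama-7b-hf"

tokenizer = get_tokenizer(Global.default_base_model_name)

# Cached under "decapoda-research/llama-7b-hf//tloen/alpaca-lora-7b";
# calling again with the same names returns the cached instance, while a
# different adapter evicts it (the cache capacity is 1 in globals.py).
model = get_model(Global.default_base_model_name, "tloen/alpaca-lora-7b")

# Drop everything from the caches, e.g. to free RAM before fine-tuning.
unload_models()
```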
llama_lora/ui/finetune_ui.py
CHANGED

@@ -10,8 +10,8 @@ from transformers import TrainerCallback
 
 from ..globals import Global
 from ..models import (
-    …
-    clear_cache,
+    get_new_base_model, get_tokenizer,
+    clear_cache, unload_models)
 from ..utils.data import (
     get_available_template_names,
     get_available_dataset_names,

@@ -258,22 +258,30 @@ def do_train(
     dataset_plain_text_data_separator,
     # Training Options
     max_seq_length,
+    evaluate_data_percentage,
     micro_batch_size,
     gradient_accumulation_steps,
     epochs,
     learning_rate,
+    train_on_inputs,
     lora_r,
     lora_alpha,
     lora_dropout,
+    lora_target_modules,
     model_name,
     progress=gr.Progress(track_tqdm=should_training_progress_track_tqdm),
 ):
     try:
-        …
+        base_model_name = Global.default_base_model_name
+        output_dir = os.path.join(Global.data_dir, "lora_models", model_name)
+        if os.path.exists(output_dir):
+            if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
+                raise ValueError(f"The output directory already exists and is not empty. ({output_dir})")
+
+        if not should_training_progress_track_tqdm:
+            progress(0, desc="Preparing train data...")
+
+        unload_models()  # Need RAM for training
 
         prompter = Prompter(template)
         variable_names = prompter.get_variable_names()

@@ -319,6 +327,7 @@
         data = process_json_dataset(data)
 
         data_count = len(data)
+        evaluate_data_count = math.ceil(data_count * evaluate_data_percentage)
 
         train_data = [
             {

@@ -356,13 +365,16 @@
 Train options: {json.dumps({
     'max_seq_length': max_seq_length,
+    'val_set_size': evaluate_data_count,
     'micro_batch_size': micro_batch_size,
     'gradient_accumulation_steps': gradient_accumulation_steps,
     'epochs': epochs,
     'learning_rate': learning_rate,
+    'train_on_inputs': train_on_inputs,
     'lora_r': lora_r,
     'lora_alpha': lora_alpha,
     'lora_dropout': lora_dropout,
+    'lora_target_modules': lora_target_modules,
     'model_name': model_name,
 }, indent=2)}

@@ -373,6 +385,9 @@ Train data (first 10):
             time.sleep(2)
             return message
 
+        if not should_training_progress_track_tqdm:
+            progress(0, desc="Preparing model for training...")
+
         log_history = []
 
         class UiTrainerCallback(TrainerCallback):

@@ -409,21 +424,51 @@ Train data (first 10):
 
         Global.should_stop_training = False
 
-        …
-        clear_cache()
-        …
-        base_model = get_base_model()
-        tokenizer = get_tokenizer()
+        base_model = get_new_base_model(base_model_name)
+        tokenizer = get_tokenizer(base_model_name)
 
         # Do not let other tqdm iterations interfere the progress reporting after training starts.
         # progress.track_tqdm = False # setting this dynamically is not working, determining if track_tqdm should be enabled based on GPU cores at start instead.
 
-        …
+        if not os.path.exists(output_dir):
+            os.makedirs(output_dir)
+
+        with open(os.path.join(output_dir, "info.json"), 'w') as info_json_file:
+            dataset_name = "N/A (from text input)"
+            if load_dataset_from == "Data Dir":
+                dataset_name = dataset_from_data_dir
+
+            info = {
+                'base_model': base_model_name,
+                'prompt_template': template,
+                'dataset_name': dataset_name,
+                'dataset_rows': len(train_data),
+                'timestamp': time.time(),
+
+                'max_seq_length': max_seq_length,
+                'train_on_inputs': train_on_inputs,
+
+                'micro_batch_size': micro_batch_size,
+                'gradient_accumulation_steps': gradient_accumulation_steps,
+                'epochs': epochs,
+                'learning_rate': learning_rate,
+
+                'evaluate_data_percentage': evaluate_data_percentage,
+
+                'lora_r': lora_r,
+                'lora_alpha': lora_alpha,
+                'lora_dropout': lora_dropout,
+                'lora_target_modules': lora_target_modules,
+            }
+            json.dump(info, info_json_file, indent=2)
+
+        if not should_training_progress_track_tqdm:
+            progress(0, desc="Train starting...")
+
+        train_output = Global.train_fn(
             base_model,  # base_model
             tokenizer,  # tokenizer
-            …
-            model_name),  # output_dir
+            output_dir,  # output_dir
             train_data,
             # 128,  # batch_size (is not used, use gradient_accumulation_steps instead)
             micro_batch_size,  # micro_batch_size

@@ -431,12 +476,12 @@ Train data (first 10):
             epochs,  # num_epochs
             learning_rate,  # learning_rate
             max_seq_length,  # cutoff_len
-            …
+            evaluate_data_count,  # val_set_size
             lora_r,  # lora_r
             lora_alpha,  # lora_alpha
             lora_dropout,  # lora_dropout
-            …
-            …
+            lora_target_modules,  # lora_target_modules
+            train_on_inputs,  # train_on_inputs
             False,  # group_by_length
             None,  # resume_from_checkpoint
             training_callbacks  # callbacks

@@ -445,12 +490,17 @@ Train data (first 10):
         logs_str = "\n".join([json.dumps(log)
                               for log in log_history]) or "None"
 
-        result_message = f"Training ended:\n{str(…
+        result_message = f"Training ended:\n{str(train_output)}\n\nLogs:\n{logs_str}"
         print(result_message)
+
+        del base_model
+        del tokenizer
+        clear_cache()
+
         return result_message
 
     except Exception as e:
-        raise gr.Error(e)
+        raise gr.Error(f"{e} (To dismiss this error, click the 'Abort' button)")
 
 
 def do_abort_training():

@@ -595,11 +645,20 @@ def finetune_ui():
             )
         )
 
-        …
+        with gr.Row():
+            max_seq_length = gr.Slider(
+                minimum=1, maximum=4096, value=512,
+                label="Max Sequence Length",
+                info="The maximum length of each sample text sequence. Sequences longer than this will be truncated.",
+                elem_id="finetune_max_seq_length"
+            )
+
+            train_on_inputs = gr.Checkbox(
+                label="Train on Inputs",
+                value=True,
+                info="If not enabled, inputs will be masked out in loss.",
+                elem_id="finetune_train_on_inputs"
+            )
 
         with gr.Row():
             micro_batch_size_default_value = 1

@@ -625,7 +684,7 @@ def finetune_ui():
             )
 
             epochs = gr.Slider(
-                minimum=1, maximum=100, step=1, value=…
+                minimum=1, maximum=100, step=1, value=10,
                 label="Epochs",
                 info="The number of times to iterate over the entire training dataset. A larger number of epochs may improve model performance but also increase the risk of overfitting.")

@@ -635,6 +694,12 @@ def finetune_ui():
                 info="The initial learning rate for the optimizer. A higher learning rate may speed up convergence but also cause instability or divergence. A lower learning rate may require more steps to reach optimal performance but also avoid overshooting or oscillating around local minima."
             )
 
+            evaluate_data_percentage = gr.Slider(
+                minimum=0, maximum=0.5, step=0.001, value=0.03,
+                label="Evaluation Data Percentage",
+                info="The percentage of data to be used for evaluation. This percentage of data will not be used for training and will be used to assess the performance of the model during the process."
+            )
+
         with gr.Column():
             lora_r = gr.Slider(
                 minimum=1, maximum=16, step=1, value=8,

@@ -654,6 +719,12 @@ def finetune_ui():
                 info="The dropout probability for LoRA, which controls the fraction of LoRA parameters that are set to zero during training. A larger lora_dropout increases the regularization effect of LoRA but also increases the risk of underfitting."
             )
 
+            lora_target_modules = gr.CheckboxGroup(
+                label="LoRA Target Modules",
+                choices=["q_proj", "k_proj", "v_proj", "o_proj"],
+                value=["q_proj", "v_proj"],
+            )
+
         with gr.Column():
             model_name = gr.Textbox(
                 lines=1, label="LoRA Model Name", value=random_name,

@@ -675,25 +746,28 @@ def finetune_ui():
             elem_id="finetune_confirm_stop_btn"
         )
 
-        …
-        "Training …
-        label="…
+        train_output = gr.Text(
+            "Training results will be shown here.",
+            label="Train Output",
             elem_id="finetune_training_status")
 
         train_progress = train_btn.click(
             fn=do_train,
             inputs=(dataset_inputs + [
                 max_seq_length,
+                evaluate_data_percentage,
                 micro_batch_size,
                 gradient_accumulation_steps,
                 epochs,
                 learning_rate,
+                train_on_inputs,
                 lora_r,
                 lora_alpha,
                 lora_dropout,
+                lora_target_modules,
                 model_name
             ]),
-            outputs=…
+            outputs=train_output
         )
 
         # controlled by JS, shows the confirm_abort_button

@@ -811,6 +885,12 @@ def finetune_ui():
                 document.getElementById('finetune_confirm_stop_btn').style.display =
                   'none';
               }, 5000);
+              document.getElementById('finetune_confirm_stop_btn').style['pointer-events'] =
+                'none';
+              setTimeout(function () {
+                document.getElementById('finetune_confirm_stop_btn').style['pointer-events'] =
+                  'inherit';
+              }, 300);
               document.getElementById('finetune_stop_btn').style.display = 'none';
               document.getElementById('finetune_confirm_stop_btn').style.display =
                 'block';
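The body of `UiTrainerCallback` falls outside the hunks shown above; the surrounding code only shows that it subclasses `TrainerCallback` and that `Global.should_stop_training` is reset before training and set by the abort button. A minimal sketch of how such a callback could stop the run (an assumption about the mechanism, not the commit's exact code):

```python
# Hypothetical sketch of an abortable TrainerCallback; the real
# UiTrainerCallback body is not included in this diff.
from transformers import TrainerCallback

from llama_lora.globals import Global


class AbortableTrainerCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        # Ask the Trainer to stop cleanly once the UI sets the flag.
        if Global.should_stop_training:
            control.should_training_stop = True
        return control
```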
llama_lora/ui/inference_ui.py
CHANGED
@@ -7,17 +7,30 @@ import transformers
|
|
7 |
from transformers import GenerationConfig
|
8 |
|
9 |
from ..globals import Global
|
10 |
-
from ..models import
|
11 |
from ..utils.data import (
|
12 |
get_available_template_names,
|
13 |
get_available_lora_model_names,
|
14 |
-
|
15 |
from ..utils.prompter import Prompter
|
16 |
from ..utils.callbacks import Iteratorize, Stream
|
17 |
|
18 |
device = get_device()
|
19 |
|
20 |
default_show_raw = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
|
23 |
def do_inference(
|
@@ -35,20 +48,25 @@ def do_inference(
|
|
35 |
show_raw=False,
|
36 |
progress=gr.Progress(track_tqdm=True),
|
37 |
):
|
|
|
|
|
38 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
variables = [variable_0, variable_1, variable_2, variable_3,
|
40 |
variable_4, variable_5, variable_6, variable_7]
|
41 |
prompter = Prompter(prompt_template)
|
42 |
prompt = prompter.generate_prompt(variables)
|
43 |
|
44 |
-
if lora_model_name is not None and "/" not in lora_model_name and lora_model_name != "None":
|
45 |
-
path_of_available_lora_model = get_path_of_available_lora_model(
|
46 |
-
lora_model_name)
|
47 |
-
if path_of_available_lora_model:
|
48 |
-
lora_model_name = path_of_available_lora_model
|
49 |
-
|
50 |
if Global.ui_dev_mode:
|
51 |
-
message = f"Hi, I’m currently in UI-development mode and do not have access to resources to process your request. However, this behavior is similar to what will actually happen, so you can try and see how it will work!\n\nBase model: {
|
52 |
print(message)
|
53 |
|
54 |
if stream_output:
|
@@ -66,18 +84,24 @@ def do_inference(
|
|
66 |
yield out
|
67 |
|
68 |
for partial_sentence in word_generator(message):
|
69 |
-
yield
|
|
|
|
|
|
|
|
|
|
|
70 |
time.sleep(0.05)
|
71 |
|
72 |
return
|
73 |
time.sleep(1)
|
74 |
-
yield
|
|
|
|
|
|
|
75 |
return
|
76 |
|
77 |
-
|
78 |
-
|
79 |
-
model = get_model_with_lora(lora_model_name)
|
80 |
-
tokenizer = get_tokenizer()
|
81 |
|
82 |
inputs = tokenizer(prompt, return_tensors="pt")
|
83 |
input_ids = inputs["input_ids"].to(device)
|
@@ -97,6 +121,19 @@ def do_inference(
|
|
97 |
"max_new_tokens": max_new_tokens,
|
98 |
}
|
99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
if stream_output:
|
101 |
# Stream the reply 1 token at a time.
|
102 |
# This is based on the trick of using 'stopping_criteria' to create an iterator,
|
@@ -128,29 +165,60 @@ def do_inference(
|
|
128 |
raw_output = None
|
129 |
if show_raw:
|
130 |
raw_output = str(output)
|
131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
return # early return for stream_output
|
133 |
|
134 |
# Without streaming
|
135 |
with torch.no_grad():
|
136 |
-
generation_output = model.generate(
|
137 |
-
input_ids=input_ids,
|
138 |
-
generation_config=generation_config,
|
139 |
-
return_dict_in_generate=True,
|
140 |
-
output_scores=True,
|
141 |
-
max_new_tokens=max_new_tokens,
|
142 |
-
)
|
143 |
s = generation_output.sequences[0]
|
144 |
output = tokenizer.decode(s)
|
145 |
raw_output = None
|
146 |
if show_raw:
|
147 |
raw_output = str(s)
|
148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
|
150 |
except Exception as e:
|
151 |
raise gr.Error(e)
|
152 |
|
153 |
|
|
|
|
|
|
|
|
|
|
|
154 |
def reload_selections(current_lora_model, current_prompt_template):
|
155 |
available_template_names = get_available_template_names()
|
156 |
available_template_names_with_none = available_template_names + ["None"]
|
@@ -172,7 +240,7 @@ def reload_selections(current_lora_model, current_prompt_template):
|
|
172 |
gr.Dropdown.update(choices=available_template_names_with_none, value=current_prompt_template))
|
173 |
|
174 |
|
175 |
-
def handle_prompt_template_change(prompt_template):
|
176 |
prompter = Prompter(prompt_template)
|
177 |
var_names = prompter.get_variable_names()
|
178 |
human_var_names = [' '.join(word.capitalize()
|
@@ -182,7 +250,36 @@ def handle_prompt_template_change(prompt_template):
|
|
182 |
while len(gr_updates) < 8:
|
183 |
gr_updates.append(gr.Textbox.update(
|
184 |
label="Not Used", visible=False))
|
185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
|
187 |
|
188 |
def update_prompt_preview(prompt_template,
|
@@ -200,12 +297,15 @@ def inference_ui():
|
|
200 |
|
201 |
with gr.Blocks() as inference_ui_blocks:
|
202 |
with gr.Row():
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
|
|
|
|
|
|
209 |
prompt_template = gr.Dropdown(
|
210 |
label="Prompt Template",
|
211 |
elem_id="inference_prompt_template",
|
@@ -278,7 +378,7 @@ def inference_ui():
|
|
278 |
)
|
279 |
|
280 |
num_beams = gr.Slider(
|
281 |
-
minimum=1, maximum=
|
282 |
label="Beams",
|
283 |
elem_id="inference_beams"
|
284 |
)
|
@@ -318,7 +418,7 @@ def inference_ui():
|
|
318 |
with gr.Column(elem_id="inference_output_group_container"):
|
319 |
with gr.Column(elem_id="inference_output_group"):
|
320 |
inference_output = gr.Textbox(
|
321 |
-
lines=
|
322 |
inference_output.style(show_copy_button=True)
|
323 |
with gr.Accordion(
|
324 |
"Raw Output",
|
@@ -346,11 +446,25 @@ def inference_ui():
|
|
346 |
)
|
347 |
things_that_might_timeout.append(reload_selections_event)
|
348 |
|
349 |
-
prompt_template_change_event = prompt_template.change(
|
350 |
-
|
|
|
|
|
|
|
|
|
351 |
things_that_might_timeout.append(prompt_template_change_event)
|
352 |
|
|
|
|
|
|
|
|
|
|
|
|
|
353 |
generate_event = generate_btn.click(
|
|
|
|
|
|
|
|
|
354 |
fn=do_inference,
|
355 |
inputs=[
|
356 |
lora_model,
|
@@ -369,8 +483,12 @@ def inference_ui():
|
|
369 |
outputs=[inference_output, inference_raw_output],
|
370 |
api_name="inference"
|
371 |
)
|
372 |
-
stop_btn.click(
|
373 |
-
|
|
|
|
|
|
|
|
|
374 |
|
375 |
update_prompt_preview_event = update_prompt_preview_btn.click(fn=update_prompt_preview, inputs=[prompt_template,
|
376 |
variable_0, variable_1, variable_2, variable_3,
|
@@ -543,9 +661,15 @@ def inference_ui():
|
|
543 |
return function (...args) {
|
544 |
const context = this;
|
545 |
clearTimeout(timeout);
|
546 |
-
|
|
|
|
|
|
|
|
|
|
|
547 |
func.apply(context, args);
|
548 |
-
}
|
|
|
549 |
};
|
550 |
}
|
551 |
|
@@ -580,5 +704,27 @@ def inference_ui():
|
|
580 |
});
|
581 |
}
|
582 |
}, 100);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
583 |
}
|
584 |
""")
|
|
|
7 |
from transformers import GenerationConfig
|
8 |
|
9 |
from ..globals import Global
|
10 |
+
from ..models import get_model, get_tokenizer, get_device
|
11 |
from ..utils.data import (
|
12 |
get_available_template_names,
|
13 |
get_available_lora_model_names,
|
14 |
+
get_info_of_available_lora_model)
|
15 |
from ..utils.prompter import Prompter
|
16 |
from ..utils.callbacks import Iteratorize, Stream
|
17 |
|
18 |
device = get_device()
|
19 |
|
20 |
default_show_raw = True
|
21 |
+
inference_output_lines = 12
|
22 |
+
|
23 |
+
|
24 |
+
def prepare_inference(lora_model_name, progress=gr.Progress(track_tqdm=True)):
|
25 |
+
base_model_name = Global.default_base_model_name
|
26 |
+
|
27 |
+
try:
|
28 |
+
get_tokenizer(base_model_name)
|
29 |
+
get_model(base_model_name, lora_model_name)
|
30 |
+
return ("", "")
|
31 |
+
|
32 |
+
except Exception as e:
|
33 |
+
raise gr.Error(e)
|
34 |
|
35 |
|
36 |
def do_inference(
|
|
|
48 |
show_raw=False,
|
49 |
progress=gr.Progress(track_tqdm=True),
|
50 |
):
|
51 |
+
base_model_name = Global.default_base_model_name
|
52 |
+
|
53 |
try:
|
54 |
+
if Global.generation_force_stopped_at is not None:
|
55 |
+
required_elapsed_time_after_forced_stop = 1
|
56 |
+
current_unix_time = time.time()
|
57 |
+
remaining_time = required_elapsed_time_after_forced_stop - \
|
58 |
+
(current_unix_time - Global.generation_force_stopped_at)
|
59 |
+
if remaining_time > 0:
|
60 |
+
time.sleep(remaining_time)
|
61 |
+
Global.generation_force_stopped_at = None
|
62 |
+
|
63 |
variables = [variable_0, variable_1, variable_2, variable_3,
|
64 |
variable_4, variable_5, variable_6, variable_7]
|
65 |
prompter = Prompter(prompt_template)
|
66 |
prompt = prompter.generate_prompt(variables)
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
if Global.ui_dev_mode:
|
69 |
+
message = f"Hi, I’m currently in UI-development mode and do not have access to resources to process your request. However, this behavior is similar to what will actually happen, so you can try and see how it will work!\n\nBase model: {base_model_name}\nLoRA model: {lora_model_name}\n\nThe following is your prompt:\n\n{prompt}"
|
70 |
print(message)
|
71 |
|
72 |
if stream_output:
|
|
|
84 |
yield out
|
85 |
|
86 |
for partial_sentence in word_generator(message):
|
87 |
+
yield (
|
88 |
+
gr.Textbox.update(
|
89 |
+
value=partial_sentence, lines=inference_output_lines),
|
90 |
+
json.dumps(
|
91 |
+
list(range(len(partial_sentence.split()))), indent=2)
|
92 |
+
)
|
93 |
time.sleep(0.05)
|
94 |
|
95 |
return
|
96 |
time.sleep(1)
|
97 |
+
yield (
|
98 |
+
gr.Textbox.update(value=message, lines=inference_output_lines),
|
99 |
+
json.dumps(list(range(len(message.split()))), indent=2)
|
100 |
+
)
|
101 |
return
|
102 |
|
103 |
+
tokenizer = get_tokenizer(base_model_name)
|
104 |
+
model = get_model(base_model_name, lora_model_name)
|
|
|
|
|
105 |
|
106 |
inputs = tokenizer(prompt, return_tensors="pt")
|
107 |
input_ids = inputs["input_ids"].to(device)
|
|
|
121 |
"max_new_tokens": max_new_tokens,
|
122 |
}
|
123 |
|
124 |
+
def ui_generation_stopping_criteria(input_ids, score, **kwargs):
|
125 |
+
if Global.should_stop_generating:
|
126 |
+
return True
|
127 |
+
return False
|
128 |
+
|
129 |
+
Global.should_stop_generating = False
|
130 |
+
generate_params.setdefault(
|
131 |
+
"stopping_criteria", transformers.StoppingCriteriaList()
|
132 |
+
)
|
133 |
+
generate_params["stopping_criteria"].append(
|
134 |
+
ui_generation_stopping_criteria
|
135 |
+
)
|
136 |
+
|
137 |
if stream_output:
|
138 |
# Stream the reply 1 token at a time.
|
139 |
# This is based on the trick of using 'stopping_criteria' to create an iterator,
|
|
|
165 |
raw_output = None
|
166 |
if show_raw:
|
167 |
raw_output = str(output)
|
168 |
+
response = prompter.get_response(decoded_output)
|
169 |
+
|
170 |
+
if Global.should_stop_generating:
|
171 |
+
return
|
172 |
+
|
173 |
+
yield (
|
174 |
+
gr.Textbox.update(
|
175 |
+
value=response, lines=inference_output_lines),
|
176 |
+
raw_output)
|
177 |
+
|
178 |
+
if Global.should_stop_generating:
|
179 |
+
# If the user stops the generation, and then clicks the
|
180 |
+
# generation button again, they may mysteriously landed
|
181 |
+
# here, in the previous, should-be-stopped generation
|
182 |
+
# function call, with the new generation function not be
|
183 |
+
# called at all. To workaround this, we yield a message
|
184 |
+
# and setting lines=1, and if the front-end JS detects
|
185 |
+
# that lines has been set to 1 (rows="1" in HTML),
|
186 |
+
# it will automatically click the generate button again
|
187 |
+
# (gr.Textbox.update() does not support updating
|
188 |
+
# elem_classes or elem_id).
|
189 |
+
# [WORKAROUND-UI01]
|
190 |
+
yield (
|
191 |
+
gr.Textbox.update(
|
192 |
+
value="Please retry", lines=1),
|
193 |
+
None)
|
194 |
return # early return for stream_output
|
195 |
196 |           # Without streaming
197 |           with torch.no_grad():
198 | +             generation_output = model.generate(**generate_params)
199 |           s = generation_output.sequences[0]
200 |           output = tokenizer.decode(s)
201 |           raw_output = None
202 |           if show_raw:
203 |               raw_output = str(s)
204 | +
205 | +         response = prompter.get_response(output)
206 | +         if Global.should_stop_generating:
207 | +             return
208 | +
209 | +         yield (
210 | +             gr.Textbox.update(value=response, lines=inference_output_lines),
211 | +             raw_output)
212 |
213 |       except Exception as e:
214 |           raise gr.Error(e)
215 |
216 |
217 | + def handle_stop_generate():
218 | +     Global.generation_force_stopped_at = time.time()
219 | +     Global.should_stop_generating = True
220 | +
221 | +
222 |   def reload_selections(current_lora_model, current_prompt_template):
223 |       available_template_names = get_available_template_names()
224 |       available_template_names_with_none = available_template_names + ["None"]

240 |           gr.Dropdown.update(choices=available_template_names_with_none, value=current_prompt_template))
241 |
242 |
243 | + def handle_prompt_template_change(prompt_template, lora_model):
244 |       prompter = Prompter(prompt_template)
245 |       var_names = prompter.get_variable_names()
246 |       human_var_names = [' '.join(word.capitalize()

250 |       while len(gr_updates) < 8:
251 |           gr_updates.append(gr.Textbox.update(
252 |               label="Not Used", visible=False))
253 | +
254 | +     model_prompt_template_message_update = gr.Markdown.update(
255 | +         "", visible=False)
256 | +     lora_mode_info = get_info_of_available_lora_model(lora_model)
257 | +     if lora_mode_info and isinstance(lora_mode_info, dict):
258 | +         model_prompt_template = lora_mode_info.get("prompt_template")
259 | +         if model_prompt_template and model_prompt_template != prompt_template:
260 | +             model_prompt_template_message_update = gr.Markdown.update(
261 | +                 f"This model was trained with prompt template `{model_prompt_template}`.", visible=True)
262 | +
263 | +     return [model_prompt_template_message_update] + gr_updates
264 | +
265 | +
266 | + def handle_lora_model_change(lora_model, prompt_template):
267 | +     lora_mode_info = get_info_of_available_lora_model(lora_model)
268 | +     if not lora_mode_info:
269 | +         return gr.Markdown.update("", visible=False), prompt_template
270 | +
271 | +     if not isinstance(lora_mode_info, dict):
272 | +         return gr.Markdown.update("", visible=False), prompt_template
273 | +
274 | +     model_prompt_template = lora_mode_info.get("prompt_template")
275 | +     if not model_prompt_template:
276 | +         return gr.Markdown.update("", visible=False), prompt_template
277 | +
278 | +     available_template_names = get_available_template_names()
279 | +     if model_prompt_template in available_template_names:
280 | +         return gr.Markdown.update("", visible=False), model_prompt_template
281 | +
282 | +     return gr.Markdown.update(f"Trained with prompt template `{model_prompt_template}`", visible=True), prompt_template
283 |
284 |
285 |   def update_prompt_preview(prompt_template,

297 |
298 |       with gr.Blocks() as inference_ui_blocks:
299 |           with gr.Row():
300 | +             with gr.Column(elem_id="inference_lora_model_group"):
301 | +                 model_prompt_template_message = gr.Markdown(
302 | +                     "", visible=False, elem_id="inference_lora_model_prompt_template_message")
303 | +                 lora_model = gr.Dropdown(
304 | +                     label="LoRA Model",
305 | +                     elem_id="inference_lora_model",
306 | +                     value="tloen/alpaca-lora-7b",
307 | +                     allow_custom_value=True,
308 | +                 )
309 |               prompt_template = gr.Dropdown(
310 |                   label="Prompt Template",
311 |                   elem_id="inference_prompt_template",

378 |                   )
379 |
380 |                   num_beams = gr.Slider(
381 | +                     minimum=1, maximum=5, value=2, step=1,
382 |                       label="Beams",
383 |                       elem_id="inference_beams"
384 |                   )

418 |           with gr.Column(elem_id="inference_output_group_container"):
419 |               with gr.Column(elem_id="inference_output_group"):
420 |                   inference_output = gr.Textbox(
421 | +                     lines=inference_output_lines, label="Output", elem_id="inference_output")
422 |                   inference_output.style(show_copy_button=True)
423 |                   with gr.Accordion(
424 |                       "Raw Output",

446 |           )
447 |           things_that_might_timeout.append(reload_selections_event)
448 |
449 | +         prompt_template_change_event = prompt_template.change(
450 | +             fn=handle_prompt_template_change,
451 | +             inputs=[prompt_template, lora_model],
452 | +             outputs=[
453 | +                 model_prompt_template_message,
454 | +                 variable_0, variable_1, variable_2, variable_3, variable_4, variable_5, variable_6, variable_7])
455 |           things_that_might_timeout.append(prompt_template_change_event)
456 |
457 | +         lora_model_change_event = lora_model.change(
458 | +             fn=handle_lora_model_change,
459 | +             inputs=[lora_model, prompt_template],
460 | +             outputs=[model_prompt_template_message, prompt_template])
461 | +         things_that_might_timeout.append(lora_model_change_event)
462 | +
463 |           generate_event = generate_btn.click(
464 | +             fn=prepare_inference,
465 | +             inputs=[lora_model],
466 | +             outputs=[inference_output, inference_raw_output],
467 | +         ).then(
468 |               fn=do_inference,
469 |               inputs=[
470 |                   lora_model,

483 |               outputs=[inference_output, inference_raw_output],
484 |               api_name="inference"
485 |           )
486 | +         stop_btn.click(
487 | +             fn=handle_stop_generate,
488 | +             inputs=None,
489 | +             outputs=None,
490 | +             cancels=[generate_event]
491 | +         )
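The wiring above stops a run in two complementary ways: `handle_stop_generate` flips the app-level flag (so the stopping criterion lets `generate()` finish cleanly), and Gradio's `cancels=` terminates the queued `generate_event` itself. A self-contained sketch of that button pattern, independent of this app (names and the dummy task are made up for illustration):

```python
import time

import gradio as gr


def slow_task():
    # Stand-in for a long-running, streaming generation job.
    for i in range(10):
        time.sleep(1)
        yield f"step {i}"


with gr.Blocks() as demo:
    out = gr.Textbox(label="Output")
    start = gr.Button("Start")
    stop = gr.Button("Stop")

    run_event = start.click(fn=slow_task, inputs=None, outputs=out)
    # cancels= aborts the queued/streaming event; an app-level flag (like
    # Global.should_stop_generating here) lets the worker exit cleanly too.
    stop.click(fn=None, inputs=None, outputs=None, cancels=[run_event])

demo.queue().launch()
```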
492 |
493 |           update_prompt_preview_event = update_prompt_preview_btn.click(fn=update_prompt_preview, inputs=[prompt_template,
494 |                                                                                                           variable_0, variable_1, variable_2, variable_3,

661 |       return function (...args) {
662 |         const context = this;
663 |         clearTimeout(timeout);
664 | +       const fn = () => {
665 | +         if (document.querySelector('#inference_preview_prompt > .wrap:not(.hide)')) {
666 | +           // Preview request is still loading, wait for 10ms and try again.
667 | +           timeout = setTimeout(fn, 10);
668 | +           return;
669 | +         }
670 |           func.apply(context, args);
671 | +       };
672 | +       timeout = setTimeout(fn, wait);
673 |       };
674 |     }
675 |

704 |         });
705 |       }
706 |     }, 100);
707 | +
708 | +   // [WORKAROUND-UI01]
709 | +   setTimeout(function () {
710 | +     const inference_output_textarea = document.querySelector(
711 | +       '#inference_output textarea'
712 | +     );
713 | +     if (!inference_output_textarea) return;
714 | +     const observer = new MutationObserver(function () {
715 | +       if (inference_output_textarea.getAttribute('rows') === '1') {
716 | +         setTimeout(function () {
717 | +           const inference_generate_btn = document.getElementById(
718 | +             'inference_generate_btn'
719 | +           );
720 | +           if (inference_generate_btn) inference_generate_btn.click();
721 | +         }, 10);
722 | +       }
723 | +     });
724 | +     observer.observe(inference_output_textarea, {
725 | +       attributes: true,
726 | +       attributeFilter: ['rows'],
727 | +     });
728 | +   }, 100);
729 |   }
730 |   """)
llama_lora/ui/main_page.py
CHANGED
@@ -1,7 +1,6 @@
  import gradio as gr

  from ..globals import Global
- from ..models import get_model_with_lora

  from .inference_ui import inference_ui
  from .finetune_ui import finetune_ui
@@ -30,8 +29,8 @@ def main_page():
          tokenizer_ui()
      info = []
      if Global.version:
-         info.append(f"LLaMA-LoRA `{Global.version}`")
-         info.append(f"Base model: `{Global.
+         info.append(f"LLaMA-LoRA Tuner `{Global.version}`")
+         info.append(f"Base model: `{Global.default_base_model_name}`")
      if Global.ui_show_sys_info:
          info.append(f"Data dir: `{Global.data_dir}`")
      gr.Markdown(f"""
@@ -134,6 +133,41 @@ def main_page_custom_css():
      /* text-transform: uppercase; */
  }

+ #inference_lora_model_group {
+     border-radius: var(--block-radius);
+     background: var(--block-background-fill);
+ }
+ #inference_lora_model_group #inference_lora_model {
+     background: transparent;
+ }
+ #inference_lora_model_prompt_template_message:not(.hidden) + #inference_lora_model {
+     padding-bottom: 28px;
+ }
+ #inference_lora_model_group > #inference_lora_model_prompt_template_message {
+     position: absolute;
+     bottom: 8px;
+     left: 20px;
+     z-index: 1;
+     font-size: 12px;
+     opacity: 0.7;
+ }
+ #inference_lora_model_group > #inference_lora_model_prompt_template_message p {
+     font-size: 12px;
+ }
+ #inference_lora_model_prompt_template_message > .wrap {
+     display: none;
+ }
+ #inference_lora_model > .wrap:first-child:not(.hide),
+ #inference_prompt_template > .wrap:first-child:not(.hide) {
+     opacity: 0.5;
+ }
+ #inference_lora_model_group, #inference_lora_model {
+     z-index: 60;
+ }
+ #inference_prompt_template {
+     z-index: 55;
+ }
+
  #inference_prompt_box > *:first-child {
      border-bottom-left-radius: 0;
      border-bottom-right-radius: 0;
@@ -193,6 +227,8 @@ def main_page_custom_css():
  #inference_raw_output > .wrap:first-child {
      /* allow users to select text while generation is still in progress */
      pointer-events: none;
+
+     padding: 12px !important;
  }

  /* position sticky */
@@ -266,12 +302,16 @@ def main_page_custom_css():
  }

  @media screen and (min-width: 640px) {
-     #inference_lora_model, #
+     #inference_lora_model, #inference_lora_model_group,
+     #finetune_template {
          border-top-right-radius: 0;
          border-bottom-right-radius: 0;
          border-right: 0;
          margin-right: -16px;
      }
+     #inference_lora_model_group #inference_lora_model {
+         box-shadow: var(--block-shadow);
+     }

      #inference_prompt_template {
          border-top-left-radius: 0;
@@ -301,7 +341,7 @@ def main_page_custom_css():
          height: 42px !important;
          min-width: 42px !important;
          width: 42px !important;
-         z-index:
+         z-index: 61;
      }
  }

@@ -388,6 +428,9 @@ def main_page_custom_css():
      white-space: pre-wrap;
  }

+ #finetune_max_seq_length {
+     flex: 2;
+ }

  @media screen and (max-width: 392px) {
      #inference_lora_model, #finetune_template {
llama_lora/ui/tokenizer_ui.py
CHANGED
@@ -7,11 +7,12 @@ from ..models import get_tokenizer


  def handle_decode(encoded_tokens_json):
+     base_model_name = Global.default_base_model_name
      try:
          encoded_tokens = json.loads(encoded_tokens_json)
          if Global.ui_dev_mode:
              return f"Not actually decoding tokens in UI dev mode.", gr.Markdown.update("", visible=False)
-         tokenizer = get_tokenizer()
+         tokenizer = get_tokenizer(base_model_name)
          decoded_tokens = tokenizer.decode(encoded_tokens)
          return decoded_tokens, gr.Markdown.update("", visible=False)
      except Exception as e:
@@ -19,10 +20,11 @@ def handle_decode(encoded_tokens_json):


  def handle_encode(decoded_tokens):
+     base_model_name = Global.default_base_model_name
      try:
          if Global.ui_dev_mode:
              return f"[\"Not actually encoding tokens in UI dev mode.\"]", gr.Markdown.update("", visible=False)
-         tokenizer = get_tokenizer()
+         tokenizer = get_tokenizer(base_model_name)
          result = tokenizer(decoded_tokens)
          encoded_tokens_json = json.dumps(result['input_ids'], indent=2)
          return encoded_tokens_json, gr.Markdown.update("", visible=False)
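Both handlers now resolve the tokenizer from `Global.default_base_model_name` instead of an implicit default. The round-trip they perform is plain Hugging Face tokenizer usage; a sketch with a stand-in model id (the app passes its configured base model instead of `"gpt2"`):

```python
import json

from transformers import AutoTokenizer

# Stand-in model id; the tokenizer UI uses Global.default_base_model_name.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

text = "Hello, world!"
encoded = tokenizer(text)["input_ids"]   # what handle_encode serializes to JSON
print(json.dumps(encoded, indent=2))
print(tokenizer.decode(encoded))         # what handle_decode shows back
```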
llama_lora/utils/data.py
CHANGED
@@ -52,6 +52,22 @@ def get_path_of_available_lora_model(name):
      return None


+ def get_info_of_available_lora_model(name):
+     try:
+         if "/" in name:
+             return None
+         path_of_available_lora_model = get_path_of_available_lora_model(
+             name)
+         if not path_of_available_lora_model:
+             return None
+
+         with open(os.path.join(path_of_available_lora_model, "info.json"), "r") as json_file:
+             return json.load(json_file)
+
+     except Exception as e:
+         return None
+
+
  def get_dataset_content(name):
      file_name = os.path.join(Global.data_dir, "datasets", name)
      if not os.path.exists(file_name):
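`get_info_of_available_lora_model` reads an `info.json` stored next to a local LoRA checkpoint; the inference UI only relies on its `prompt_template` key to keep the selected template in sync with the model. Illustrative usage with a hypothetical model directory (only the `prompt_template` key is confirmed by the code above; the other field is an assumption):

```python
import json
import os

lora_model_dir = "./data/lora_models/my-lora-model"  # hypothetical path

# Example of what info.json might contain, e.g.:
# {"base_model": "...", "prompt_template": "user_and_ai"}   # keys other than
# "prompt_template" are illustrative assumptions
with open(os.path.join(lora_model_dir, "info.json")) as f:
    info = json.load(f)

# handle_prompt_template_change() / handle_lora_model_change() look only at
# this key to decide whether to switch or warn about the prompt template.
print(info.get("prompt_template"))
```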
llama_lora/utils/lru_cache.py
ADDED
@@ -0,0 +1,31 @@
+ from collections import OrderedDict
+
+
+ class LRUCache:
+     def __init__(self, capacity=5):
+         self.cache = OrderedDict()
+         self.capacity = capacity
+
+     def get(self, key):
+         if key in self.cache:
+             # Move the accessed item to the end of the OrderedDict
+             self.cache.move_to_end(key)
+             return self.cache[key]
+         return None
+
+     def set(self, key, value):
+         if key in self.cache:
+             # If the key already exists, update its value
+             self.cache[key] = value
+         else:
+             # If the cache has reached its capacity, remove the least recently used item
+             if len(self.cache) >= self.capacity:
+                 self.cache.popitem(last=False)
+             self.cache[key] = value
+
+     def clear(self):
+         self.cache.clear()
+
+     def prepare_to_set(self):
+         if len(self.cache) >= self.capacity:
+             self.cache.popitem(last=False)
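Usage of the `LRUCache` above is straightforward: `set` inserts and evicts the least recently used entry once `capacity` is exceeded, and `get` refreshes an entry's recency. For example (caching loaded models is one plausible use in this codebase; the keys and values below are made up):

```python
from llama_lora.utils.lru_cache import LRUCache

cache = LRUCache(capacity=2)
cache.set("model-a", "loaded model A")
cache.set("model-b", "loaded model B")

cache.get("model-a")                     # marks "model-a" as most recently used
cache.set("model-c", "loaded model C")   # evicts "model-b", the least recently used

print(cache.get("model-b"))              # None
print(cache.get("model-a"))              # "loaded model A"
```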
requirements.lock.txt
ADDED
@@ -0,0 +1,120 @@
+ accelerate==0.18.0
+ aiofiles==23.1.0
+ aiohttp==3.8.4
+ aiosignal==1.3.1
+ altair==4.2.2
+ anyio==3.6.2
+ appdirs==1.4.4
+ asttokens==2.2.1
+ async-timeout==4.0.2
+ attrs==22.2.0
+ backcall==0.2.0
+ bitsandbytes==0.37.2
+ black==23.3.0
+ charset-normalizer==3.1.0
+ click==8.1.3
+ contourpy==1.0.7
+ cycler==0.11.0
+ datasets==2.11.0
+ decorator==5.1.1
+ dill==0.3.6
+ entrypoints==0.4
+ exceptiongroup==1.1.1
+ executing==1.2.0
+ fastapi==0.95.0
+ ffmpy==0.3.0
+ filelock==3.11.0
+ fire==0.5.0
+ fonttools==4.39.3
+ frozenlist==1.3.3
+ fsspec==2023.3.0
+ gradio==3.24.1
+ gradio_client==0.0.8
+ h11==0.14.0
+ httpcore==0.16.3
+ httpx==0.23.3
+ huggingface-hub==0.13.4
+ idna==3.4
+ importlib-metadata==6.2.0
+ importlib-resources==5.12.0
+ iniconfig==2.0.0
+ ipython==8.12.0
+ jedi==0.18.2
+ Jinja2==3.1.2
+ jsonschema==4.17.3
+ kiwisolver==1.4.4
+ linkify-it-py==2.0.0
+ llvmlite==0.39.1
+ loralib==0.1.1
+ markdown-it-py==2.2.0
+ MarkupSafe==2.1.2
+ matplotlib==3.7.1
+ matplotlib-inline==0.1.6
+ mdit-py-plugins==0.3.3
+ mdurl==0.1.2
+ mpmath==1.3.0
+ multidict==6.0.4
+ multiprocess==0.70.14
+ mypy-extensions==1.0.0
+ networkx==3.1
+ numba==0.56.4
+ numpy==1.23.5
+ nvidia-ml-py3==7.352.0
+ orjson==3.8.9
+ packaging==23.0
+ pandas==2.0.0
+ parso==0.8.3
+ pathspec==0.11.1
+ peft @ git+https://github.com/huggingface/peft.git@382b178911edff38c1ff619bbac2ba556bd2276b
+ pexpect==4.8.0
+ pickleshare==0.7.5
+ Pillow==9.3.0
+ pkgutil_resolve_name==1.3.10
+ platformdirs==3.2.0
+ pluggy==1.0.0
+ prompt-toolkit==3.0.38
+ psutil==5.9.4
+ ptyprocess==0.7.0
+ pure-eval==0.2.2
+ pyarrow==11.0.0
+ pydantic==1.10.7
+ pydub==0.25.1
+ Pygments==2.14.0
+ pyparsing==3.0.9
+ pyrsistent==0.19.3
+ pytest==7.2.2
+ python-dateutil==2.8.2
+ python-multipart==0.0.6
+ pytz==2023.3
+ PyYAML==6.0
+ Random-Word==1.0.11
+ regex==2023.3.23
+ requests==2.28.2
+ responses==0.18.0
+ rfc3986==1.5.0
+ semantic-version==2.10.0
+ sentencepiece==0.1.97
+ six==1.16.0
+ sniffio==1.3.0
+ stack-data==0.6.2
+ starlette==0.26.1
+ sympy==1.11.1
+ termcolor==2.2.0
+ tokenize-rt==5.0.0
+ tokenizers==0.13.3
+ tomli==2.0.1
+ toolz==0.12.0
+ torch==2.0.0
+ tqdm==4.65.0
+ traitlets==5.9.0
+ transformers @ git+https://github.com/huggingface/transformers.git@3f96e0b4e483c4c7d4ec9dcdc24b0b0cdf31ea5c
+ typing_extensions==4.5.0
+ tzdata==2023.3
+ uc-micro-py==1.0.1
+ urllib3==1.26.15
+ uvicorn==0.21.1
+ wcwidth==0.2.6
+ websockets==11.0.1
+ xxhash==3.2.0
+ yarl==1.8.2
+ zipp==3.15.0
templates/user_and_ai.json
ADDED
@@ -0,0 +1,7 @@
+ {
+     "description": "Unhelpful AI assistant.",
+     "variables": ["instruction"],
+     "prompt": "### User:\n{instruction}\n\n### AI:\n",
+     "default": "prompt",
+     "response_split": "### AI:"
+ }
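This template follows the same JSON schema as the project's other prompt templates: `prompt` is a format string over the declared `variables`, and `response_split` marks where the model's reply begins in the decoded output. A rough sketch of how such a template is typically applied (the project's `Prompter` class is not reproduced here; this just mirrors the JSON fields):

```python
import json

with open("templates/user_and_ai.json") as f:
    template = json.load(f)


def generate_prompt(instruction):
    # Fill the single declared variable into the prompt format string.
    return template["prompt"].format(instruction=instruction)


def get_response(full_output):
    # Everything after the response_split marker is treated as the model's reply.
    return full_output.split(template["response_split"])[-1].strip()


prompt = generate_prompt("What is a LoRA adapter?")
# => "### User:\nWhat is a LoRA adapter?\n\n### AI:\n"
```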