Commit 45501f4 (unverified) committed by zetavg · 2 Parent(s): 91cb7fd a1771a7

Merge branch 'main' into hf-ui-demo

LLaMA_LoRA.ipynb CHANGED
@@ -6,7 +6,6 @@
6
  "provenance": [],
7
  "private_outputs": true,
8
  "toc_visible": true,
9
- "authorship_tag": "ABX9TyMHMc4PwWLbRlhFol+WRzoT",
10
  "include_colab_link": true
11
  },
12
  "kernelspec": {
@@ -27,13 +26,13 @@
27
  "colab_type": "text"
28
  },
29
  "source": [
30
- "<a href=\"https://colab.research.google.com/github/zetavg/LLaMA-LoRA/blob/main/LLaMA_LoRA.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
31
  ]
32
  },
33
  {
34
  "cell_type": "markdown",
35
  "source": [
36
- "# 🦙🎛️ LLaMA-LoRA\n",
37
  "\n",
38
  "TL;DR: **Runtime > Run All** (`⌘/Ctrl+F9`). Takes about 5 minutes to start. You will be promped to authorize Google Drive access."
39
  ],
@@ -55,6 +54,26 @@
55
  "execution_count": null,
56
  "outputs": []
57
  },
58
  {
59
  "cell_type": "markdown",
60
  "source": [
@@ -72,9 +91,9 @@
72
  "# @title Git/Project { display-mode: \"form\", run: \"auto\" }\n",
73
  "# @markdown Project settings.\n",
74
  "\n",
75
- "# @markdown The URL of the LLaMA-LoRA project<br>&nbsp;&nbsp;(default: `https://github.com/zetavg/llama-lora.git`):\n",
76
- "llama_lora_project_url = \"https://github.com/zetavg/llama-lora.git\" # @param {type:\"string\"}\n",
77
- "# @markdown The branch to use for LLaMA-LoRA project:\n",
78
  "llama_lora_project_branch = \"main\" # @param {type:\"string\"}\n",
79
  "\n",
80
  "# # @markdown Forces the local directory to be updated by the remote branch:\n",
@@ -97,7 +116,7 @@
97
  "# @markdown You can customize the location of the stored data here.\n",
98
  "\n",
99
  "# @markdown The folder in Google Drive where Colab Notebook data are stored<br />&nbsp;&nbsp;**(WARNING: The content of this folder will be modified by this notebook)**:\n",
100
- "google_drive_folder = \"Colab Data/LLaMA LoRA\" # @param {type:\"string\"}\n",
101
  "# google_drive_colab_data_folder = \"Colab Notebooks/Notebook Data\"\n",
102
  "\n",
103
  "# Where Google Drive will be mounted in the Colab runtime.\n",
@@ -220,7 +239,7 @@
220
  "source": [
221
  "![ ! -d llama_lora ] && git clone -b {llama_lora_project_branch} --filter=tree:0 {llama_lora_project_url} llama_lora\n",
222
  "!cd llama_lora && git add --all && git stash && git fetch origin {llama_lora_project_branch} && git checkout {llama_lora_project_branch} && git reset origin/{llama_lora_project_branch} --hard\n",
223
- "![ ! -f llama-lora-requirements-installed ] && cd llama_lora && pip install -r requirements.txt && touch ../llama-lora-requirements-installed"
224
  ],
225
  "metadata": {
226
  "id": "JGYz2VDoAzC8"
@@ -262,7 +281,7 @@
262
  "\n",
263
  "# Set Configs\n",
264
  "from llama_lora.llama_lora.globals import Global\n",
265
- "Global.base_model = base_model\n",
266
  "data_dir_realpath = !realpath ./data\n",
267
  "Global.data_dir = data_dir_realpath[0]\n",
268
  "Global.load_8bit = True\n",
@@ -270,12 +289,11 @@
270
  "# Prepare Data Dir\n",
271
  "import os\n",
272
  "from llama_lora.llama_lora.utils.data import init_data_dir\n",
273
- "init_data_dir()\n",
274
  "\n",
275
  "# Load the Base Model\n",
276
- "from llama_lora.llama_lora.models import load_base_model\n",
277
- "load_base_model()\n",
278
- "print(f\"Base model loaded: '{Global.base_model}'.\")"
279
  ],
280
  "metadata": {
281
  "id": "Yf6g248ylteP"
 
6
  "provenance": [],
7
  "private_outputs": true,
8
  "toc_visible": true,
 
9
  "include_colab_link": true
10
  },
11
  "kernelspec": {
 
26
  "colab_type": "text"
27
  },
28
  "source": [
29
+ "<a href=\"https://colab.research.google.com/github/zetavg/LLaMA-LoRA-Tuner/blob/main/LLaMA_LoRA.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
30
  ]
31
  },
32
  {
33
  "cell_type": "markdown",
34
  "source": [
35
+ "# 🦙🎛️ LLaMA-LoRA Tuner\n",
36
  "\n",
37
  "TL;DR: **Runtime > Run All** (`⌘/Ctrl+F9`). Takes about 5 minutes to start. You will be promped to authorize Google Drive access."
38
  ],
 
54
  "execution_count": null,
55
  "outputs": []
56
  },
57
+ {
58
+ "cell_type": "code",
59
+ "source": [
60
+ "# @title A small workaround { display-mode: \"form\" }\n",
61
+ "# @markdown Don't panic if you see an error here. Just click the `RESTART RUNTIME` button in the output below, then Run All again.\n",
62
+ "# @markdown The error will disappear on the next run.\n",
63
+ "!pip install Pillow==9.3.0\n",
64
+ "import PIL\n",
65
+ "major, minor = map(float, PIL.__version__.split(\".\")[:2])\n",
66
+ "version_float = major + minor / 10**len(str(minor))\n",
67
+ "print(version_float)\n",
68
+ "if version_float < 9.003:\n",
69
+ " raise Exception(\"Restart the runtime by clicking the 'RESTART RUNTIME' button above (or Runtime > Restart Runtime).\")"
70
+ ],
71
+ "metadata": {
72
+ "id": "XcJ4WO3KhOX1"
73
+ },
74
+ "execution_count": null,
75
+ "outputs": []
76
+ },
77
  {
78
  "cell_type": "markdown",
79
  "source": [
 
91
  "# @title Git/Project { display-mode: \"form\", run: \"auto\" }\n",
92
  "# @markdown Project settings.\n",
93
  "\n",
94
+ "# @markdown The URL of the LLaMA-LoRA-Tuner project<br>&nbsp;&nbsp;(default: `https://github.com/zetavg/LLaMA-LoRA-Tuner.git`):\n",
95
+ "llama_lora_project_url = \"https://github.com/zetavg/LLaMA-LoRA-Tuner.git\" # @param {type:\"string\"}\n",
96
+ "# @markdown The branch to use for LLaMA-LoRA-Tuner project:\n",
97
  "llama_lora_project_branch = \"main\" # @param {type:\"string\"}\n",
98
  "\n",
99
  "# # @markdown Forces the local directory to be updated by the remote branch:\n",
 
116
  "# @markdown You can customize the location of the stored data here.\n",
117
  "\n",
118
  "# @markdown The folder in Google Drive where Colab Notebook data are stored<br />&nbsp;&nbsp;**(WARNING: The content of this folder will be modified by this notebook)**:\n",
119
+ "google_drive_folder = \"Colab Data/LLaMA-LoRA Tuner\" # @param {type:\"string\"}\n",
120
  "# google_drive_colab_data_folder = \"Colab Notebooks/Notebook Data\"\n",
121
  "\n",
122
  "# Where Google Drive will be mounted in the Colab runtime.\n",
 
239
  "source": [
240
  "![ ! -d llama_lora ] && git clone -b {llama_lora_project_branch} --filter=tree:0 {llama_lora_project_url} llama_lora\n",
241
  "!cd llama_lora && git add --all && git stash && git fetch origin {llama_lora_project_branch} && git checkout {llama_lora_project_branch} && git reset origin/{llama_lora_project_branch} --hard\n",
242
+ "![ ! -f llama-lora-requirements-installed ] && cd llama_lora && pip install -r requirements.lock.txt && touch ../llama-lora-requirements-installed"
243
  ],
244
  "metadata": {
245
  "id": "JGYz2VDoAzC8"
 
281
  "\n",
282
  "# Set Configs\n",
283
  "from llama_lora.llama_lora.globals import Global\n",
284
+ "Global.default_base_model_name = base_model\n",
285
  "data_dir_realpath = !realpath ./data\n",
286
  "Global.data_dir = data_dir_realpath[0]\n",
287
  "Global.load_8bit = True\n",
 
289
  "# Prepare Data Dir\n",
290
  "import os\n",
291
  "from llama_lora.llama_lora.utils.data import init_data_dir\n",
292
+ "init_data_dir()",
293
  "\n",
294
  "# Load the Base Model\n",
295
+ "from llama_lora.llama_lora.models import prepare_base_model\n",
296
+ "prepare_base_model()\n"
 
297
  ],
298
  "metadata": {
299
  "id": "Yf6g248ylteP"
README.md CHANGED
@@ -20,16 +20,19 @@ git push -f hf-ui-demo hf-ui-demo:main
20
 
21
  ---
22
 
23
- # 🦙🎛️ LLaMA-LoRA
24
 
25
- <a href="https://colab.research.google.com/github/zetavg/LLaMA-LoRA/blob/main/LLaMA_LoRA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
26
 
27
  Making evaluating and fine-tuning LLaMA models with low-rank adaptation (LoRA) easy.
28
 
29
 
30
  ## Features
31
 
32
- * [1-click up and running in Google Colab](https://colab.research.google.com/github/zetavg/LLaMA-LoRA/blob/main/LLaMA_LoRA.ipynb).
33
  * Loads and stores data in Google Drive.
34
  * Evaluate various LLaMA LoRA models stored in your folder or from Hugging Face.<br /><a href="https://youtu.be/A3kb4VkDWyY"><img width="640px" src="https://user-images.githubusercontent.com/3784687/230272844-09f7a35b-46bf-4101-b15d-4ddf243b8bef.gif" /></a>
35
  * Fine-tune LLaMA models with different prompt templates and training dataset format.<br /><a href="https://youtu.be/5Db9U8PsaUk"><img width="640px" src="https://user-images.githubusercontent.com/3784687/230277315-9a91d983-1690-4594-9d54-912eda8963ee.gif" /></a>
@@ -47,7 +50,7 @@ There are various ways to run this app:
47
 
48
  ### Run On Google Colab
49
 
50
- Open [this Colab Notebook](https://colab.research.google.com/github/zetavg/LLaMA-LoRA/blob/main/LLaMA_LoRA.ipynb) and select **Runtime > Run All** (`⌘/Ctrl+F9`).
51
 
52
  You will be prompted to authorize Google Drive access, as Google Drive will be used to store your data. See the "Config"/"Google Drive" section for settings and more info.
53
 
@@ -58,10 +61,10 @@ After approximately 5 minutes of running, you will see the public URL in the out
58
  After following the [installation guide of SkyPilot](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html), create a `.yaml` to define a task for running the app:
59
 
60
  ```yaml
61
- # llama-lora-multitool.yaml
62
 
63
  resources:
64
- accelerators: A10:1 # 1x NVIDIA A10 GPU
65
  cloud: lambda # Optional; if left out, SkyPilot will automatically pick the cheapest cloud.
66
 
67
  file_mounts:
@@ -69,27 +72,27 @@ file_mounts:
69
  # (to store training datasets and trained models)
70
  # See https://skypilot.readthedocs.io/en/latest/reference/storage.html for details.
71
  /data:
72
- name: llama-lora-multitool-data # Make sure this name is unique or you own this bucket. If it does not exists, SkyPilot will try to create a bucket with this name.
73
- store: gcs # Could be either of [s3, gcs]
74
  mode: MOUNT
75
 
76
- # Clone the LLaMA-LoRA repo and install its dependencies.
77
  setup: |
78
- git clone https://github.com/zetavg/LLaMA-LoRA.git llama_lora
79
- cd llama_lora && pip install -r requirements.txt
80
  cd ..
81
  echo 'Dependencies installed.'
82
 
83
  # Start the app.
84
  run: |
85
  echo 'Starting...'
86
- python llama_lora/app.py --data_dir='/data' --base_model='decapoda-research/llama-7b-hf' --share
87
  ```
88
 
89
  Then launch a cluster to run the task:
90
 
91
  ```
92
- sky launch -c llama-lora-multitool llama-lora-multitool.yaml
93
  ```
94
 
95
  `-c ...` is an optional flag to specify a cluster name. If not specified, SkyPilot will automatically generate one.
@@ -106,13 +109,13 @@ When you are done, run `sky stop <cluster_name>` to stop the cluster. To termina
106
  <summary>Prepare environment with conda</summary>
107
 
108
  ```bash
109
- conda create -y -n llama-lora-multitool python=3.8
110
- conda activate llama-lora-multitool
111
  ```
112
  </details>
113
 
114
  ```bash
115
- pip install -r requirements.txt
116
  python app.py --data_dir='./data' --base_model='decapoda-research/llama-7b-hf' --share
117
  ```
118
 
 
20
 
21
  ---
22
 
 
23
 
24
+ # 🦙🎛️ LLaMA-LoRA Tuner
25
+
26
+ <a href="https://colab.research.google.com/github/zetavg/LLaMA-LoRA-Tuner/blob/main/LLaMA_LoRA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
27
 
28
  Making evaluating and fine-tuning LLaMA models with low-rank adaptation (LoRA) easy.
29
 
30
 
31
  ## Features
32
 
33
+ **[See a demo on Hugging Face](https://huggingface.co/spaces/zetavg/LLaMA-LoRA-UI-Demo)** *Only serves UI demonstration. To try training or text generation, [run on Colab](#run-on-google-colab).*
34
+
35
+ * **[1-click up and running in Google Colab](#run-on-google-colab)** with a standard GPU runtime.
36
  * Loads and stores data in Google Drive.
37
  * Evaluate various LLaMA LoRA models stored in your folder or from Hugging Face.<br /><a href="https://youtu.be/A3kb4VkDWyY"><img width="640px" src="https://user-images.githubusercontent.com/3784687/230272844-09f7a35b-46bf-4101-b15d-4ddf243b8bef.gif" /></a>
38
  * Fine-tune LLaMA models with different prompt templates and training dataset format.<br /><a href="https://youtu.be/5Db9U8PsaUk"><img width="640px" src="https://user-images.githubusercontent.com/3784687/230277315-9a91d983-1690-4594-9d54-912eda8963ee.gif" /></a>
 
50
 
51
  ### Run On Google Colab
52
 
53
+ Open [this Colab Notebook](https://colab.research.google.com/github/zetavg/LLaMA-LoRA-Tuner/blob/main/LLaMA_LoRA.ipynb) and select **Runtime > Run All** (`⌘/Ctrl+F9`).
54
 
55
  You will be prompted to authorize Google Drive access, as Google Drive will be used to store your data. See the "Config"/"Google Drive" section for settings and more info.
56
 
 
61
  After following the [installation guide of SkyPilot](https://skypilot.readthedocs.io/en/latest/getting-started/installation.html), create a `.yaml` to define a task for running the app:
62
 
63
  ```yaml
64
+ # llama-lora-tuner.yaml
65
 
66
  resources:
67
+ accelerators: A10:1 # 1x NVIDIA A10 GPU, about US$ 0.6 / hr on Lambda Cloud.
68
  cloud: lambda # Optional; if left out, SkyPilot will automatically pick the cheapest cloud.
69
 
70
  file_mounts:
 
72
  # (to store training datasets and trained models)
73
  # See https://skypilot.readthedocs.io/en/latest/reference/storage.html for details.
74
  /data:
75
+ name: llama-lora-tuner-data # Make sure this name is unique or you own this bucket. If it does not exist, SkyPilot will try to create a bucket with this name.
76
+ store: s3 # Could be either of [s3, gcs]
77
  mode: MOUNT
78
 
79
+ # Clone the LLaMA-LoRA Tuner repo and install its dependencies.
80
  setup: |
81
+ git clone https://github.com/zetavg/LLaMA-LoRA-Tuner.git llama_lora_tuner
82
+ cd llama_lora_tuner && pip install -r requirements.lock.txt
83
  cd ..
84
  echo 'Dependencies installed.'
85
 
86
  # Start the app.
87
  run: |
88
  echo 'Starting...'
89
+ python llama_lora_tuner/app.py --data_dir='/data' --base_model='decapoda-research/llama-7b-hf' --share
90
  ```
91
 
92
  Then launch a cluster to run the task:
93
 
94
  ```
95
+ sky launch -c llama-lora-tuner llama-lora-tuner.yaml
96
  ```
97
 
98
  `-c ...` is an optional flag to specify a cluster name. If not specified, SkyPilot will automatically generate one.
 
109
  <summary>Prepare environment with conda</summary>
110
 
111
  ```bash
112
+ conda create -y python=3.8 -n llama-lora-tuner
113
+ conda activate llama-lora-tuner
114
  ```
115
  </details>
116
 
117
  ```bash
118
+ pip install -r requirements.lock.txt
119
  python app.py --data_dir='./data' --base_model='decapoda-research/llama-7b-hf' --share
120
  ```
121
 
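For reference, the README's local-run command maps onto the `main()` entry point patched in the app.py diff below. A hedged sketch of driving it programmatically; the keyword names follow the visible `main()` signature and the README's flags, and this is an illustration rather than a documented API.

```python
# Sketch: programmatic equivalent of
#   python app.py --data_dir='./data' --base_model='decapoda-research/llama-7b-hf' --share
# Assumes app.py exposes main() directly (it is wrapped with python-fire for the CLI).
from app import main

main(
    base_model="decapoda-research/llama-7b-hf",
    data_dir="./data",
    share=True,
)
```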
app.py CHANGED
@@ -16,6 +16,7 @@ def main(
16
  # Allows to listen on all interfaces by providing '0.0.0.0'.
17
  server_name: str = "127.0.0.1",
18
  share: bool = False,
 
19
  ui_show_sys_info: bool = True,
20
  ui_dev_mode: bool = False,
21
  ):
@@ -29,7 +30,7 @@ def main(
29
  data_dir
30
  ), "Please specify a --data_dir, e.g. --data_dir='./data'"
31
 
32
- Global.base_model = base_model
33
  Global.data_dir = os.path.abspath(data_dir)
34
  Global.load_8bit = load_8bit
35
 
 
16
  # Allows to listen on all interfaces by providing '0.0.0.0'.
17
  server_name: str = "127.0.0.1",
18
  share: bool = False,
19
+ skip_loading_base_model: bool = False,
20
  ui_show_sys_info: bool = True,
21
  ui_dev_mode: bool = False,
22
  ):
 
30
  data_dir
31
  ), "Please specify a --data_dir, e.g. --data_dir='./data'"
32
 
33
+ Global.default_base_model_name = base_model
34
  Global.data_dir = os.path.abspath(data_dir)
35
  Global.load_8bit = load_8bit
36
 
llama_lora/globals.py CHANGED
@@ -6,18 +6,17 @@ from typing import Any, Dict, List, Optional, Tuple, Union
6
  from numba import cuda
7
  import nvidia_smi
8
 
 
9
  from .lib.finetune import train
10
 
11
 
12
  class Global:
13
  version = None
14
 
15
- base_model: str = ""
16
  data_dir: str = ""
17
  load_8bit: bool = False
18
 
19
- loaded_tokenizer: Any = None
20
- loaded_base_model: Any = None
21
 
22
  # Functions
23
  train_fn: Any = train
@@ -25,8 +24,15 @@ class Global:
25
  # Training Control
26
  should_stop_training = False
27
 
28
  # Model related
29
- model_has_been_used = False
30
 
31
  # GPU Info
32
  gpu_cc = None # GPU compute capability
@@ -35,7 +41,7 @@ class Global:
35
  gpu_total_memory = None
36
 
37
  # UI related
38
- ui_title: str = "LLaMA-LoRA"
39
  ui_emoji: str = "🦙🎛️"
40
  ui_subtitle: str = "Toolkit for evaluating and fine-tuning LLaMA models with low-rank adaptation (LoRA)."
41
  ui_show_sys_info: bool = True
 
6
  from numba import cuda
7
  import nvidia_smi
8
 
9
+ from .utils.lru_cache import LRUCache
10
  from .lib.finetune import train
11
 
12
 
13
  class Global:
14
  version = None
15
 
 
16
  data_dir: str = ""
17
  load_8bit: bool = False
18
 
19
+ default_base_model_name: str = ""
 
20
 
21
  # Functions
22
  train_fn: Any = train
 
24
  # Training Control
25
  should_stop_training = False
26
 
27
+ # Generation Control
28
+ should_stop_generating = False
29
+ generation_force_stopped_at = None
30
+
31
  # Model related
32
+ loaded_models = LRUCache(1)
33
+ loaded_tokenizers = LRUCache(1)
34
+ new_base_model_that_is_ready_to_be_used = None
35
+ name_of_new_base_model_that_is_ready_to_be_used = None
36
 
37
  # GPU Info
38
  gpu_cc = None # GPU compute capability
 
41
  gpu_total_memory = None
42
 
43
  # UI related
44
+ ui_title: str = "LLaMA-LoRA Tuner"
45
  ui_emoji: str = "🦙🎛️"
46
  ui_subtitle: str = "Toolkit for evaluating and fine-tuning LLaMA models with low-rank adaptation (LoRA)."
47
  ui_show_sys_info: bool = True
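The new `Global.loaded_models` / `Global.loaded_tokenizers` caches come from `llama_lora/utils/lru_cache.py`, which is not included in this diff. Below is a minimal sketch of a cache with the interface this commit uses (`get`, `set`, `clear`, and `prepare_to_set`, which `models.py` calls to evict before loading the next model); the implementation details are assumptions.

```python
# Sketch of the LRUCache interface used by Global.loaded_models above.
# The real llama_lora/utils/lru_cache.py is not shown in this diff, so the
# behavior here is an assumption.
from collections import OrderedDict


class LRUCache:
    def __init__(self, capacity: int = 1):
        self.capacity = capacity
        self.cache = OrderedDict()

    def get(self, key):
        if key not in self.cache:
            return None
        self.cache.move_to_end(key)  # mark as most recently used
        return self.cache[key]

    def set(self, key, value):
        if key in self.cache:
            self.cache.move_to_end(key)
        self.cache[key] = value
        while len(self.cache) > self.capacity:
            self.cache.popitem(last=False)  # evict least recently used

    def prepare_to_set(self):
        # Evict ahead of an upcoming set() so memory can be reclaimed
        # before the next model is loaded.
        while len(self.cache) >= self.capacity:
            self.cache.popitem(last=False)

    def clear(self):
        self.cache.clear()
```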
llama_lora/lib/finetune.py CHANGED
@@ -2,6 +2,8 @@ import os
2
  import sys
3
  from typing import Any, List
4
 
 
 
5
  import fire
6
  import torch
7
  import transformers
@@ -47,6 +49,10 @@ def train(
47
  # logging
48
  callbacks: List[Any] = []
49
  ):
50
  device_map = "auto"
51
  world_size = int(os.environ.get("WORLD_SIZE", 1))
52
  ddp = world_size != 1
@@ -202,6 +208,12 @@ def train(
202
  ),
203
  callbacks=callbacks,
204
  )
205
  model.config.use_cache = False
206
 
207
  old_state_dict = model.state_dict
@@ -214,9 +226,16 @@ def train(
214
  if torch.__version__ >= "2" and sys.platform != "win32":
215
  model = torch.compile(model)
216
 
217
- result = trainer.train(resume_from_checkpoint=resume_from_checkpoint)
218
 
219
  model.save_pretrained(output_dir)
220
  print(f"Model saved to {output_dir}.")
221
 
222
- return result
2
  import sys
3
  from typing import Any, List
4
 
5
+ import json
6
+
7
  import fire
8
  import torch
9
  import transformers
 
49
  # logging
50
  callbacks: List[Any] = []
51
  ):
52
+ if os.path.exists(output_dir):
53
+ if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
54
+ raise ValueError(f"The output directory already exists and is not empty. ({output_dir})")
55
+
56
  device_map = "auto"
57
  world_size = int(os.environ.get("WORLD_SIZE", 1))
58
  ddp = world_size != 1
 
208
  ),
209
  callbacks=callbacks,
210
  )
211
+
212
+ if not os.path.exists(output_dir):
213
+ os.makedirs(output_dir)
214
+ with open(os.path.join(output_dir, "trainer_args.json"), 'w') as trainer_args_json_file:
215
+ json.dump(trainer.args.to_dict(), trainer_args_json_file, indent=2)
216
+
217
  model.config.use_cache = False
218
 
219
  old_state_dict = model.state_dict
 
226
  if torch.__version__ >= "2" and sys.platform != "win32":
227
  model = torch.compile(model)
228
 
229
+ train_output = trainer.train(resume_from_checkpoint=resume_from_checkpoint)
230
 
231
  model.save_pretrained(output_dir)
232
  print(f"Model saved to {output_dir}.")
233
 
234
+ with open(os.path.join(output_dir, "trainer_log_history.jsonl"), 'w') as trainer_log_history_jsonl_file:
235
+ trainer_log_history = "\n".join([json.dumps(line) for line in trainer.state.log_history])
236
+ trainer_log_history_jsonl_file.write(trainer_log_history)
237
+
238
+ with open(os.path.join(output_dir, "train_output.json"), 'w') as train_output_json_file:
239
+ json.dump(train_output, train_output_json_file, indent=2)
240
+
241
+ return train_output
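The patched `train()` now persists `trainer_args.json`, `trainer_log_history.jsonl`, and `train_output.json` next to the adapter weights. A small sketch of reading them back after a run; the model directory path is a hypothetical example.

```python
# Sketch: inspecting the artifacts written by the patched train() above.
import json
import os

output_dir = "./data/lora_models/my-lora-model"  # hypothetical path

with open(os.path.join(output_dir, "trainer_args.json")) as f:
    trainer_args = json.load(f)  # TrainingArguments.to_dict()

with open(os.path.join(output_dir, "trainer_log_history.jsonl")) as f:
    log_history = [json.loads(line) for line in f if line.strip()]

print(trainer_args.get("num_train_epochs"), len(log_history))
```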
llama_lora/models.py CHANGED
@@ -3,9 +3,8 @@ import sys
3
  import gc
4
 
5
  import torch
6
- import transformers
7
  from peft import PeftModel
8
- from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
9
 
10
  from .globals import Global
11
 
@@ -23,84 +22,140 @@ def get_device():
23
  pass
24
 
25
 
26
- device = get_device()
27
-
28
-
29
- def get_base_model():
30
- load_base_model()
31
- return Global.loaded_base_model
32
 
33
 
34
- def get_model_with_lora(lora_weights: str = "tloen/alpaca-lora-7b"):
35
- Global.model_has_been_used = True
36
 
37
  if device == "cuda":
38
- model = PeftModel.from_pretrained(
39
- get_base_model(),
40
- lora_weights,
41
  torch_dtype=torch.float16,
42
- device_map={'': 0}, # ? https://github.com/tloen/alpaca-lora/issues/21
43
  )
44
  elif device == "mps":
45
- model = PeftModel.from_pretrained(
46
- get_base_model(),
47
- lora_weights,
48
  device_map={"": device},
49
  torch_dtype=torch.float16,
50
  )
51
  else:
52
- model = PeftModel.from_pretrained(
53
- get_base_model(),
54
- lora_weights,
55
- device_map={"": device},
56
  )
57
 
58
- model.config.pad_token_id = get_tokenizer().pad_token_id = 0
59
  model.config.bos_token_id = 1
60
  model.config.eos_token_id = 2
61
 
62
- if not Global.load_8bit:
63
- model.half() # seems to fix bugs for some users.
64
-
65
- model.eval()
66
- if torch.__version__ >= "2" and sys.platform != "win32":
67
- model = torch.compile(model)
68
  return model
69
 
70
 
71
- def get_tokenizer():
72
- load_base_model()
73
- return Global.loaded_tokenizer
74
 
 
75
 
76
- def load_base_model():
77
  if Global.ui_dev_mode:
78
  return
79
 
80
- if Global.loaded_tokenizer is None:
81
- Global.loaded_tokenizer = LlamaTokenizer.from_pretrained(
82
- Global.base_model
83
- )
84
- if Global.loaded_base_model is None:
85
  if device == "cuda":
86
- Global.loaded_base_model = LlamaForCausalLM.from_pretrained(
87
- Global.base_model,
88
- load_in_8bit=Global.load_8bit,
89
  torch_dtype=torch.float16,
90
- # device_map="auto",
91
- device_map={'': 0}, # ? https://github.com/tloen/alpaca-lora/issues/21
92
  )
93
  elif device == "mps":
94
- Global.loaded_base_model = LlamaForCausalLM.from_pretrained(
95
- Global.base_model,
 
96
  device_map={"": device},
97
  torch_dtype=torch.float16,
98
  )
99
  else:
100
- Global.loaded_base_model = LlamaForCausalLM.from_pretrained(
101
- Global.base_model, device_map={"": device}, low_cpu_mem_usage=True
102
  )
103
 
104
 
105
  def clear_cache():
106
  gc.collect()
@@ -111,17 +166,6 @@ def clear_cache():
111
 
112
 
113
  def unload_models():
114
- del Global.loaded_base_model
115
- Global.loaded_base_model = None
116
-
117
- del Global.loaded_tokenizer
118
- Global.loaded_tokenizer = None
119
-
120
  clear_cache()
121
-
122
- Global.model_has_been_used = False
123
-
124
-
125
- def unload_models_if_already_used():
126
- if Global.model_has_been_used:
127
- unload_models()
 
3
  import gc
4
 
5
  import torch
6
+ from transformers import LlamaForCausalLM, LlamaTokenizer
7
  from peft import PeftModel
 
8
 
9
  from .globals import Global
10
 
 
22
  pass
23
 
24
 
25
+ def get_new_base_model(base_model_name):
26
+ if Global.ui_dev_mode:
27
+ return
 
29
+ if Global.new_base_model_that_is_ready_to_be_used:
30
+ if Global.name_of_new_base_model_that_is_ready_to_be_used == base_model_name:
31
+ model = Global.new_base_model_that_is_ready_to_be_used
32
+ Global.new_base_model_that_is_ready_to_be_used = None
33
+ Global.name_of_new_base_model_that_is_ready_to_be_used = None
34
+ return model
35
+ else:
36
+ Global.new_base_model_that_is_ready_to_be_used = None
37
+ Global.name_of_new_base_model_that_is_ready_to_be_used = None
38
+ clear_cache()
39
 
40
+ device = get_device()
 
41
 
42
  if device == "cuda":
43
+ model = LlamaForCausalLM.from_pretrained(
44
+ base_model_name,
45
+ load_in_8bit=Global.load_8bit,
46
  torch_dtype=torch.float16,
47
+ # device_map="auto",
48
+ # ? https://github.com/tloen/alpaca-lora/issues/21
49
+ device_map={'': 0},
50
  )
51
  elif device == "mps":
52
+ model = LlamaForCausalLM.from_pretrained(
53
+ base_model_name,
 
54
  device_map={"": device},
55
  torch_dtype=torch.float16,
56
  )
57
  else:
58
+ model = LlamaForCausalLM.from_pretrained(
59
+ base_model_name, device_map={"": device}, low_cpu_mem_usage=True
60
  )
61
 
62
+ model.config.pad_token_id = get_tokenizer(base_model_name).pad_token_id = 0
63
  model.config.bos_token_id = 1
64
  model.config.eos_token_id = 2
65
 
  return model
67
 
68
 
69
+ def get_tokenizer(base_model_name):
70
+ if Global.ui_dev_mode:
71
+ return
72
+
73
+ loaded_tokenizer = Global.loaded_tokenizers.get(base_model_name)
74
+ if loaded_tokenizer:
75
+ return loaded_tokenizer
76
+
77
+ tokenizer = LlamaTokenizer.from_pretrained(base_model_name)
78
+ Global.loaded_tokenizers.set(base_model_name, tokenizer)
79
 
80
+ return tokenizer
81
 
82
+
83
+ def get_model(
84
+ base_model_name,
85
+ peft_model_name=None):
86
  if Global.ui_dev_mode:
87
  return
88
 
89
+ if peft_model_name == "None":
90
+ peft_model_name = None
91
+
92
+ model_key = base_model_name
93
+ if peft_model_name:
94
+ model_key = f"{base_model_name}//{peft_model_name}"
95
+
96
+ loaded_model = Global.loaded_models.get(model_key)
97
+ if loaded_model:
98
+ return loaded_model
99
+
100
+ peft_model_name_or_path = peft_model_name
101
+
102
+ lora_models_directory_path = os.path.join(Global.data_dir, "lora_models")
103
+ possible_lora_model_path = os.path.join(
104
+ lora_models_directory_path, peft_model_name)
105
+ if os.path.isdir(possible_lora_model_path):
106
+ peft_model_name_or_path = possible_lora_model_path
107
+
108
+ Global.loaded_models.prepare_to_set()
109
+ clear_cache()
110
+
111
+ model = get_new_base_model(base_model_name)
112
+
113
+ if peft_model_name:
114
+ device = get_device()
115
+
116
  if device == "cuda":
117
+ model = PeftModel.from_pretrained(
118
+ model,
119
+ peft_model_name_or_path,
120
  torch_dtype=torch.float16,
121
+ # ? https://github.com/tloen/alpaca-lora/issues/21
122
+ device_map={'': 0},
123
  )
124
  elif device == "mps":
125
+ model = PeftModel.from_pretrained(
126
+ model,
127
+ peft_model_name_or_path,
128
  device_map={"": device},
129
  torch_dtype=torch.float16,
130
  )
131
  else:
132
+ model = PeftModel.from_pretrained(
133
+ model,
134
+ peft_model_name_or_path,
135
+ device_map={"": device},
136
  )
137
 
138
+ model.config.pad_token_id = get_tokenizer(base_model_name).pad_token_id = 0
139
+ model.config.bos_token_id = 1
140
+ model.config.eos_token_id = 2
141
+
142
+ if not Global.load_8bit:
143
+ model.half() # seems to fix bugs for some users.
144
+
145
+ model.eval()
146
+ if torch.__version__ >= "2" and sys.platform != "win32":
147
+ model = torch.compile(model)
148
+
149
+ Global.loaded_models.set(model_key, model)
150
+ clear_cache()
151
+
152
+ return model
153
+
154
+
155
+ def prepare_base_model(base_model_name=Global.default_base_model_name):
156
+ Global.new_base_model_that_is_ready_to_be_used = get_new_base_model(base_model_name)
157
+ Global.name_of_new_base_model_that_is_ready_to_be_used = base_model_name
158
+
159
 
160
  def clear_cache():
161
  gc.collect()
 
166
 
167
 
168
  def unload_models():
169
+ Global.loaded_models.clear()
170
+ Global.loaded_tokenizers.clear()
171
  clear_cache()
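A usage sketch for the reworked cache-aware API in `models.py` above. The model names are the defaults referenced elsewhere in this commit, and the import path is the one the Colab notebook uses after cloning the repo into `llama_lora/`; treat the assertion as an illustration of the caching behavior, not a test.

```python
# Sketch: how callers use the new API in llama_lora/models.py.
from llama_lora.llama_lora.models import get_model, get_tokenizer

base_model_name = "decapoda-research/llama-7b-hf"
lora_model_name = "tloen/alpaca-lora-7b"

tokenizer = get_tokenizer(base_model_name)           # cached in Global.loaded_tokenizers
model = get_model(base_model_name, lora_model_name)  # cached under "base//peft"

# Requesting the same pair again returns the cached objects; requesting a
# different adapter evicts the previous one, since both caches are LRUCache(1).
same_model = get_model(base_model_name, lora_model_name)
assert same_model is model
```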
llama_lora/ui/finetune_ui.py CHANGED
@@ -10,8 +10,8 @@ from transformers import TrainerCallback
10
 
11
  from ..globals import Global
12
  from ..models import (
13
- get_base_model, get_tokenizer,
14
- clear_cache, unload_models_if_already_used)
15
  from ..utils.data import (
16
  get_available_template_names,
17
  get_available_dataset_names,
@@ -258,22 +258,30 @@ def do_train(
258
  dataset_plain_text_data_separator,
259
  # Training Options
260
  max_seq_length,
 
261
  micro_batch_size,
262
  gradient_accumulation_steps,
263
  epochs,
264
  learning_rate,
 
265
  lora_r,
266
  lora_alpha,
267
  lora_dropout,
 
268
  model_name,
269
  progress=gr.Progress(track_tqdm=should_training_progress_track_tqdm),
270
  ):
271
  try:
272
- clear_cache()
273
- # If model has been used in inference, we need to unload it first.
274
- # Otherwise, we'll get a 'Function MmBackward0 returned an invalid
275
- # gradient at index 1 - expected device meta but got cuda:0' error.
276
- unload_models_if_already_used()
277
 
278
  prompter = Prompter(template)
279
  variable_names = prompter.get_variable_names()
@@ -319,6 +327,7 @@ def do_train(
319
  data = process_json_dataset(data)
320
 
321
  data_count = len(data)
 
322
 
323
  train_data = [
324
  {
@@ -356,13 +365,16 @@ def do_train(
356
 
357
  Train options: {json.dumps({
358
  'max_seq_length': max_seq_length,
 
359
  'micro_batch_size': micro_batch_size,
360
  'gradient_accumulation_steps': gradient_accumulation_steps,
361
  'epochs': epochs,
362
  'learning_rate': learning_rate,
 
363
  'lora_r': lora_r,
364
  'lora_alpha': lora_alpha,
365
  'lora_dropout': lora_dropout,
 
366
  'model_name': model_name,
367
  }, indent=2)}
368
 
@@ -373,6 +385,9 @@ Train data (first 10):
373
  time.sleep(2)
374
  return message
375
 
  log_history = []
377
 
378
  class UiTrainerCallback(TrainerCallback):
@@ -409,21 +424,51 @@ Train data (first 10):
409
 
410
  Global.should_stop_training = False
411
 
412
- # Do this again right before training to make sure the model is not used in inference.
413
- unload_models_if_already_used()
414
- clear_cache()
415
-
416
- base_model = get_base_model()
417
- tokenizer = get_tokenizer()
418
 
419
  # Do not let other tqdm iterations interfere the progress reporting after training starts.
420
  # progress.track_tqdm = False # setting this dynamically is not working, determining if track_tqdm should be enabled based on GPU cores at start instead.
421
 
422
- results = Global.train_fn(
423
  base_model, # base_model
424
  tokenizer, # tokenizer
425
- os.path.join(Global.data_dir, "lora_models",
426
- model_name), # output_dir
427
  train_data,
428
  # 128, # batch_size (is not used, use gradient_accumulation_steps instead)
429
  micro_batch_size, # micro_batch_size
@@ -431,12 +476,12 @@ Train data (first 10):
431
  epochs, # num_epochs
432
  learning_rate, # learning_rate
433
  max_seq_length, # cutoff_len
434
- 0, # val_set_size
435
  lora_r, # lora_r
436
  lora_alpha, # lora_alpha
437
  lora_dropout, # lora_dropout
438
- ["q_proj", "v_proj"], # lora_target_modules
439
- True, # train_on_inputs
440
  False, # group_by_length
441
  None, # resume_from_checkpoint
442
  training_callbacks # callbacks
@@ -445,12 +490,17 @@ Train data (first 10):
445
  logs_str = "\n".join([json.dumps(log)
446
  for log in log_history]) or "None"
447
 
448
- result_message = f"Training ended:\n{str(results)}\n\nLogs:\n{logs_str}"
449
  print(result_message)
450
  return result_message
451
 
452
  except Exception as e:
453
- raise gr.Error(e)
454
 
455
 
456
  def do_abort_training():
@@ -595,11 +645,20 @@ def finetune_ui():
595
  )
596
  )
597
 
598
- max_seq_length = gr.Slider(
599
- minimum=1, maximum=4096, value=512,
600
- label="Max Sequence Length",
601
- info="The maximum length of each sample text sequence. Sequences longer than this will be truncated."
602
- )
603
 
604
  with gr.Row():
605
  micro_batch_size_default_value = 1
@@ -625,7 +684,7 @@ def finetune_ui():
625
  )
626
 
627
  epochs = gr.Slider(
628
- minimum=1, maximum=100, step=1, value=3,
629
  label="Epochs",
630
  info="The number of times to iterate over the entire training dataset. A larger number of epochs may improve model performance but also increase the risk of overfitting.")
631
 
@@ -635,6 +694,12 @@ def finetune_ui():
635
  info="The initial learning rate for the optimizer. A higher learning rate may speed up convergence but also cause instability or divergence. A lower learning rate may require more steps to reach optimal performance but also avoid overshooting or oscillating around local minima."
636
  )
637
 
638
  with gr.Column():
639
  lora_r = gr.Slider(
640
  minimum=1, maximum=16, step=1, value=8,
@@ -654,6 +719,12 @@ def finetune_ui():
654
  info="The dropout probability for LoRA, which controls the fraction of LoRA parameters that are set to zero during training. A larger lora_dropout increases the regularization effect of LoRA but also increases the risk of underfitting."
655
  )
656
 
657
  with gr.Column():
658
  model_name = gr.Textbox(
659
  lines=1, label="LoRA Model Name", value=random_name,
@@ -675,25 +746,28 @@ def finetune_ui():
675
  elem_id="finetune_confirm_stop_btn"
676
  )
677
 
678
- training_status = gr.Text(
679
- "Training status will be shown here.",
680
- label="Training Status/Results",
681
  elem_id="finetune_training_status")
682
 
683
  train_progress = train_btn.click(
684
  fn=do_train,
685
  inputs=(dataset_inputs + [
686
  max_seq_length,
 
687
  micro_batch_size,
688
  gradient_accumulation_steps,
689
  epochs,
690
  learning_rate,
 
691
  lora_r,
692
  lora_alpha,
693
  lora_dropout,
 
694
  model_name
695
  ]),
696
- outputs=training_status
697
  )
698
 
699
  # controlled by JS, shows the confirm_abort_button
@@ -811,6 +885,12 @@ def finetune_ui():
811
  document.getElementById('finetune_confirm_stop_btn').style.display =
812
  'none';
813
  }, 5000);
814
  document.getElementById('finetune_stop_btn').style.display = 'none';
815
  document.getElementById('finetune_confirm_stop_btn').style.display =
816
  'block';
 
10
 
11
  from ..globals import Global
12
  from ..models import (
13
+ get_new_base_model, get_tokenizer,
14
+ clear_cache, unload_models)
15
  from ..utils.data import (
16
  get_available_template_names,
17
  get_available_dataset_names,
 
258
  dataset_plain_text_data_separator,
259
  # Training Options
260
  max_seq_length,
261
+ evaluate_data_percentage,
262
  micro_batch_size,
263
  gradient_accumulation_steps,
264
  epochs,
265
  learning_rate,
266
+ train_on_inputs,
267
  lora_r,
268
  lora_alpha,
269
  lora_dropout,
270
+ lora_target_modules,
271
  model_name,
272
  progress=gr.Progress(track_tqdm=should_training_progress_track_tqdm),
273
  ):
274
  try:
275
+ base_model_name = Global.default_base_model_name
276
+ output_dir = os.path.join(Global.data_dir, "lora_models", model_name)
277
+ if os.path.exists(output_dir):
278
+ if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
279
+ raise ValueError(f"The output directory already exists and is not empty. ({output_dir})")
280
+
281
+ if not should_training_progress_track_tqdm:
282
+ progress(0, desc="Preparing train data...")
283
+
284
+ unload_models() # Need RAM for training
285
 
286
  prompter = Prompter(template)
287
  variable_names = prompter.get_variable_names()
 
327
  data = process_json_dataset(data)
328
 
329
  data_count = len(data)
330
+ evaluate_data_count = math.ceil(data_count * evaluate_data_percentage)
331
 
332
  train_data = [
333
  {
 
365
 
366
  Train options: {json.dumps({
367
  'max_seq_length': max_seq_length,
368
+ 'val_set_size': evaluate_data_count,
369
  'micro_batch_size': micro_batch_size,
370
  'gradient_accumulation_steps': gradient_accumulation_steps,
371
  'epochs': epochs,
372
  'learning_rate': learning_rate,
373
+ 'train_on_inputs': train_on_inputs,
374
  'lora_r': lora_r,
375
  'lora_alpha': lora_alpha,
376
  'lora_dropout': lora_dropout,
377
+ 'lora_target_modules': lora_target_modules,
378
  'model_name': model_name,
379
  }, indent=2)}
380
 
 
385
  time.sleep(2)
386
  return message
387
 
388
+ if not should_training_progress_track_tqdm:
389
+ progress(0, desc="Preparing model for training...")
390
+
391
  log_history = []
392
 
393
  class UiTrainerCallback(TrainerCallback):
 
424
 
425
  Global.should_stop_training = False
426
 
427
+ base_model = get_new_base_model(base_model_name)
428
+ tokenizer = get_tokenizer(base_model_name)
 
 
 
 
429
 
430
  # Do not let other tqdm iterations interfere the progress reporting after training starts.
431
  # progress.track_tqdm = False # setting this dynamically is not working, determining if track_tqdm should be enabled based on GPU cores at start instead.
432
 
433
+ if not os.path.exists(output_dir):
434
+ os.makedirs(output_dir)
435
+
436
+ with open(os.path.join(output_dir, "info.json"), 'w') as info_json_file:
437
+ dataset_name = "N/A (from text input)"
438
+ if load_dataset_from == "Data Dir":
439
+ dataset_name = dataset_from_data_dir
440
+
441
+ info = {
442
+ 'base_model': base_model_name,
443
+ 'prompt_template': template,
444
+ 'dataset_name': dataset_name,
445
+ 'dataset_rows': len(train_data),
446
+ 'timestamp': time.time(),
447
+
448
+ 'max_seq_length': max_seq_length,
449
+ 'train_on_inputs': train_on_inputs,
450
+
451
+ 'micro_batch_size': micro_batch_size,
452
+ 'gradient_accumulation_steps': gradient_accumulation_steps,
453
+ 'epochs': epochs,
454
+ 'learning_rate': learning_rate,
455
+
456
+ 'evaluate_data_percentage': evaluate_data_percentage,
457
+
458
+ 'lora_r': lora_r,
459
+ 'lora_alpha': lora_alpha,
460
+ 'lora_dropout': lora_dropout,
461
+ 'lora_target_modules': lora_target_modules,
462
+ }
463
+ json.dump(info, info_json_file, indent=2)
464
+
465
+ if not should_training_progress_track_tqdm:
466
+ progress(0, desc="Train starting...")
467
+
468
+ train_output = Global.train_fn(
469
  base_model, # base_model
470
  tokenizer, # tokenizer
471
+ output_dir, # output_dir
 
472
  train_data,
473
  # 128, # batch_size (is not used, use gradient_accumulation_steps instead)
474
  micro_batch_size, # micro_batch_size
 
476
  epochs, # num_epochs
477
  learning_rate, # learning_rate
478
  max_seq_length, # cutoff_len
479
+ evaluate_data_count, # val_set_size
480
  lora_r, # lora_r
481
  lora_alpha, # lora_alpha
482
  lora_dropout, # lora_dropout
483
+ lora_target_modules, # lora_target_modules
484
+ train_on_inputs, # train_on_inputs
485
  False, # group_by_length
486
  None, # resume_from_checkpoint
487
  training_callbacks # callbacks
 
490
  logs_str = "\n".join([json.dumps(log)
491
  for log in log_history]) or "None"
492
 
493
+ result_message = f"Training ended:\n{str(train_output)}\n\nLogs:\n{logs_str}"
494
  print(result_message)
495
+
496
+ del base_model
497
+ del tokenizer
498
+ clear_cache()
499
+
500
  return result_message
501
 
502
  except Exception as e:
503
+ raise gr.Error(f"{e} (To dismiss this error, click the 'Abort' button)")
504
 
505
 
506
  def do_abort_training():
 
645
  )
646
  )
647
 
648
+ with gr.Row():
649
+ max_seq_length = gr.Slider(
650
+ minimum=1, maximum=4096, value=512,
651
+ label="Max Sequence Length",
652
+ info="The maximum length of each sample text sequence. Sequences longer than this will be truncated.",
653
+ elem_id="finetune_max_seq_length"
654
+ )
655
+
656
+ train_on_inputs = gr.Checkbox(
657
+ label="Train on Inputs",
658
+ value=True,
659
+ info="If not enabled, inputs will be masked out in loss.",
660
+ elem_id="finetune_train_on_inputs"
661
+ )
662
 
663
  with gr.Row():
664
  micro_batch_size_default_value = 1
 
684
  )
685
 
686
  epochs = gr.Slider(
687
+ minimum=1, maximum=100, step=1, value=10,
688
  label="Epochs",
689
  info="The number of times to iterate over the entire training dataset. A larger number of epochs may improve model performance but also increase the risk of overfitting.")
690
 
 
694
  info="The initial learning rate for the optimizer. A higher learning rate may speed up convergence but also cause instability or divergence. A lower learning rate may require more steps to reach optimal performance but also avoid overshooting or oscillating around local minima."
695
  )
696
 
697
+ evaluate_data_percentage = gr.Slider(
698
+ minimum=0, maximum=0.5, step=0.001, value=0.03,
699
+ label="Evaluation Data Percentage",
700
+ info="The percentage of data to be used for evaluation. This percentage of data will not be used for training and will be used to assess the performance of the model during the process."
701
+ )
702
+
703
  with gr.Column():
704
  lora_r = gr.Slider(
705
  minimum=1, maximum=16, step=1, value=8,
 
719
  info="The dropout probability for LoRA, which controls the fraction of LoRA parameters that are set to zero during training. A larger lora_dropout increases the regularization effect of LoRA but also increases the risk of underfitting."
720
  )
721
 
722
+ lora_target_modules = gr.CheckboxGroup(
723
+ label="LoRA Target Modules",
724
+ choices=["q_proj", "k_proj", "v_proj", "o_proj"],
725
+ value=["q_proj", "v_proj"],
726
+ )
727
+
728
  with gr.Column():
729
  model_name = gr.Textbox(
730
  lines=1, label="LoRA Model Name", value=random_name,
 
746
  elem_id="finetune_confirm_stop_btn"
747
  )
748
 
749
+ train_output = gr.Text(
750
+ "Training results will be shown here.",
751
+ label="Train Output",
752
  elem_id="finetune_training_status")
753
 
754
  train_progress = train_btn.click(
755
  fn=do_train,
756
  inputs=(dataset_inputs + [
757
  max_seq_length,
758
+ evaluate_data_percentage,
759
  micro_batch_size,
760
  gradient_accumulation_steps,
761
  epochs,
762
  learning_rate,
763
+ train_on_inputs,
764
  lora_r,
765
  lora_alpha,
766
  lora_dropout,
767
+ lora_target_modules,
768
  model_name
769
  ]),
770
+ outputs=train_output
771
  )
772
 
773
  # controlled by JS, shows the confirm_abort_button
 
885
  document.getElementById('finetune_confirm_stop_btn').style.display =
886
  'none';
887
  }, 5000);
888
+ document.getElementById('finetune_confirm_stop_btn').style['pointer-events'] =
889
+ 'none';
890
+ setTimeout(function () {
891
+ document.getElementById('finetune_confirm_stop_btn').style['pointer-events'] =
892
+ 'inherit';
893
+ }, 300);
894
  document.getElementById('finetune_stop_btn').style.display = 'none';
895
  document.getElementById('finetune_confirm_stop_btn').style.display =
896
  'block';
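The new "Evaluation Data Percentage" slider feeds `val_set_size` through a single rounding step inside `do_train()`. A worked example with an illustrative dataset size:

```python
# Worked example of the evaluate_data_percentage -> val_set_size mapping
# added to do_train() above. The dataset size is illustrative.
import math

data_count = 1000                # rows in the loaded training dataset
evaluate_data_percentage = 0.03  # the slider's default value

# Passed to train() as val_set_size; these rows are held out for evaluation.
evaluate_data_count = math.ceil(data_count * evaluate_data_percentage)
print(f"{evaluate_data_count} of {data_count} rows held out for evaluation")
```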
llama_lora/ui/inference_ui.py CHANGED
@@ -7,17 +7,30 @@ import transformers
7
  from transformers import GenerationConfig
8
 
9
  from ..globals import Global
10
- from ..models import get_base_model, get_model_with_lora, get_tokenizer, get_device
11
  from ..utils.data import (
12
  get_available_template_names,
13
  get_available_lora_model_names,
14
- get_path_of_available_lora_model)
15
  from ..utils.prompter import Prompter
16
  from ..utils.callbacks import Iteratorize, Stream
17
 
18
  device = get_device()
19
 
20
  default_show_raw = True
21
 
22
 
23
  def do_inference(
@@ -35,20 +48,25 @@ def do_inference(
35
  show_raw=False,
36
  progress=gr.Progress(track_tqdm=True),
37
  ):
38
  try:
39
  variables = [variable_0, variable_1, variable_2, variable_3,
40
  variable_4, variable_5, variable_6, variable_7]
41
  prompter = Prompter(prompt_template)
42
  prompt = prompter.generate_prompt(variables)
43
 
44
- if lora_model_name is not None and "/" not in lora_model_name and lora_model_name != "None":
45
- path_of_available_lora_model = get_path_of_available_lora_model(
46
- lora_model_name)
47
- if path_of_available_lora_model:
48
- lora_model_name = path_of_available_lora_model
49
-
50
  if Global.ui_dev_mode:
51
- message = f"Hi, I’m currently in UI-development mode and do not have access to resources to process your request. However, this behavior is similar to what will actually happen, so you can try and see how it will work!\n\nBase model: {Global.base_model}\nLoRA model: {lora_model_name}\n\nThe following text is your prompt:\n\n{prompt}"
52
  print(message)
53
 
54
  if stream_output:
@@ -66,18 +84,24 @@ def do_inference(
66
  yield out
67
 
68
  for partial_sentence in word_generator(message):
69
- yield partial_sentence, json.dumps(list(range(len(partial_sentence.split()))), indent=2)
70
  time.sleep(0.05)
71
 
72
  return
73
  time.sleep(1)
74
- yield message, json.dumps(list(range(len(message.split()))), indent=2)
75
  return
76
 
77
- model = get_base_model()
78
- if not lora_model_name == "None" and lora_model_name is not None:
79
- model = get_model_with_lora(lora_model_name)
80
- tokenizer = get_tokenizer()
81
 
82
  inputs = tokenizer(prompt, return_tensors="pt")
83
  input_ids = inputs["input_ids"].to(device)
@@ -97,6 +121,19 @@ def do_inference(
97
  "max_new_tokens": max_new_tokens,
98
  }
99
 
100
  if stream_output:
101
  # Stream the reply 1 token at a time.
102
  # This is based on the trick of using 'stopping_criteria' to create an iterator,
@@ -128,29 +165,60 @@ def do_inference(
128
  raw_output = None
129
  if show_raw:
130
  raw_output = str(output)
131
- yield prompter.get_response(decoded_output), raw_output
132
  return # early return for stream_output
133
 
134
  # Without streaming
135
  with torch.no_grad():
136
- generation_output = model.generate(
137
- input_ids=input_ids,
138
- generation_config=generation_config,
139
- return_dict_in_generate=True,
140
- output_scores=True,
141
- max_new_tokens=max_new_tokens,
142
- )
143
  s = generation_output.sequences[0]
144
  output = tokenizer.decode(s)
145
  raw_output = None
146
  if show_raw:
147
  raw_output = str(s)
148
- yield prompter.get_response(output), raw_output
149
 
150
  except Exception as e:
151
  raise gr.Error(e)
152
 
153
 
 
154
  def reload_selections(current_lora_model, current_prompt_template):
155
  available_template_names = get_available_template_names()
156
  available_template_names_with_none = available_template_names + ["None"]
@@ -172,7 +240,7 @@ def reload_selections(current_lora_model, current_prompt_template):
172
  gr.Dropdown.update(choices=available_template_names_with_none, value=current_prompt_template))
173
 
174
 
175
- def handle_prompt_template_change(prompt_template):
176
  prompter = Prompter(prompt_template)
177
  var_names = prompter.get_variable_names()
178
  human_var_names = [' '.join(word.capitalize()
@@ -182,7 +250,36 @@ def handle_prompt_template_change(prompt_template):
182
  while len(gr_updates) < 8:
183
  gr_updates.append(gr.Textbox.update(
184
  label="Not Used", visible=False))
185
- return gr_updates
186
 
187
 
188
  def update_prompt_preview(prompt_template,
@@ -200,12 +297,15 @@ def inference_ui():
200
 
201
  with gr.Blocks() as inference_ui_blocks:
202
  with gr.Row():
203
- lora_model = gr.Dropdown(
204
- label="LoRA Model",
205
- elem_id="inference_lora_model",
206
- value="tloen/alpaca-lora-7b",
207
- allow_custom_value=True,
208
- )
209
  prompt_template = gr.Dropdown(
210
  label="Prompt Template",
211
  elem_id="inference_prompt_template",
@@ -278,7 +378,7 @@ def inference_ui():
278
  )
279
 
280
  num_beams = gr.Slider(
281
- minimum=1, maximum=4, value=1, step=1,
282
  label="Beams",
283
  elem_id="inference_beams"
284
  )
@@ -318,7 +418,7 @@ def inference_ui():
318
  with gr.Column(elem_id="inference_output_group_container"):
319
  with gr.Column(elem_id="inference_output_group"):
320
  inference_output = gr.Textbox(
321
- lines=12, label="Output", elem_id="inference_output")
322
  inference_output.style(show_copy_button=True)
323
  with gr.Accordion(
324
  "Raw Output",
@@ -346,11 +446,25 @@ def inference_ui():
346
  )
347
  things_that_might_timeout.append(reload_selections_event)
348
 
349
- prompt_template_change_event = prompt_template.change(fn=handle_prompt_template_change, inputs=[prompt_template], outputs=[
350
- variable_0, variable_1, variable_2, variable_3, variable_4, variable_5, variable_6, variable_7])
351
  things_that_might_timeout.append(prompt_template_change_event)
352
 
353
  generate_event = generate_btn.click(
354
  fn=do_inference,
355
  inputs=[
356
  lora_model,
@@ -369,8 +483,12 @@ def inference_ui():
369
  outputs=[inference_output, inference_raw_output],
370
  api_name="inference"
371
  )
372
- stop_btn.click(fn=None, inputs=None, outputs=None,
373
- cancels=[generate_event])
374
 
375
  update_prompt_preview_event = update_prompt_preview_btn.click(fn=update_prompt_preview, inputs=[prompt_template,
376
  variable_0, variable_1, variable_2, variable_3,
@@ -543,9 +661,15 @@ def inference_ui():
543
  return function (...args) {
544
  const context = this;
545
  clearTimeout(timeout);
546
- timeout = setTimeout(() => {
547
  func.apply(context, args);
548
- }, wait);
 
549
  };
550
  }
551
 
@@ -580,5 +704,27 @@ def inference_ui():
580
  });
581
  }
582
  }, 100);
583
  }
584
  """)
 
7
  from transformers import GenerationConfig
8
 
9
  from ..globals import Global
10
+ from ..models import get_model, get_tokenizer, get_device
11
  from ..utils.data import (
12
  get_available_template_names,
13
  get_available_lora_model_names,
14
+ get_info_of_available_lora_model)
15
  from ..utils.prompter import Prompter
16
  from ..utils.callbacks import Iteratorize, Stream
17
 
18
  device = get_device()
19
 
20
  default_show_raw = True
21
+ inference_output_lines = 12
22
+
23
+
24
+ def prepare_inference(lora_model_name, progress=gr.Progress(track_tqdm=True)):
25
+ base_model_name = Global.default_base_model_name
26
+
27
+ try:
28
+ get_tokenizer(base_model_name)
29
+ get_model(base_model_name, lora_model_name)
30
+ return ("", "")
31
+
32
+ except Exception as e:
33
+ raise gr.Error(e)
34
 
35
 
36
  def do_inference(
 
48
  show_raw=False,
49
  progress=gr.Progress(track_tqdm=True),
50
  ):
51
+ base_model_name = Global.default_base_model_name
52
+
53
  try:
54
+ if Global.generation_force_stopped_at is not None:
55
+ required_elapsed_time_after_forced_stop = 1
56
+ current_unix_time = time.time()
57
+ remaining_time = required_elapsed_time_after_forced_stop - \
58
+ (current_unix_time - Global.generation_force_stopped_at)
59
+ if remaining_time > 0:
60
+ time.sleep(remaining_time)
61
+ Global.generation_force_stopped_at = None
62
+
63
  variables = [variable_0, variable_1, variable_2, variable_3,
64
  variable_4, variable_5, variable_6, variable_7]
65
  prompter = Prompter(prompt_template)
66
  prompt = prompter.generate_prompt(variables)
67
 
 
68
  if Global.ui_dev_mode:
69
+ message = f"Hi, I’m currently in UI-development mode and do not have access to resources to process your request. However, this behavior is similar to what will actually happen, so you can try and see how it will work!\n\nBase model: {base_model_name}\nLoRA model: {lora_model_name}\n\nThe following is your prompt:\n\n{prompt}"
70
  print(message)
71
 
72
  if stream_output:
 
84
  yield out
85
 
86
  for partial_sentence in word_generator(message):
87
+ yield (
88
+ gr.Textbox.update(
89
+ value=partial_sentence, lines=inference_output_lines),
90
+ json.dumps(
91
+ list(range(len(partial_sentence.split()))), indent=2)
92
+ )
93
  time.sleep(0.05)
94
 
95
  return
96
  time.sleep(1)
97
+ yield (
98
+ gr.Textbox.update(value=message, lines=inference_output_lines),
99
+ json.dumps(list(range(len(message.split()))), indent=2)
100
+ )
101
  return
102
 
103
+ tokenizer = get_tokenizer(base_model_name)
104
+ model = get_model(base_model_name, lora_model_name)
 
 
106
  inputs = tokenizer(prompt, return_tensors="pt")
107
  input_ids = inputs["input_ids"].to(device)
 
121
  "max_new_tokens": max_new_tokens,
122
  }
123
 
124
+ def ui_generation_stopping_criteria(input_ids, score, **kwargs):
125
+ if Global.should_stop_generating:
126
+ return True
127
+ return False
128
+
129
+ Global.should_stop_generating = False
130
+ generate_params.setdefault(
131
+ "stopping_criteria", transformers.StoppingCriteriaList()
132
+ )
133
+ generate_params["stopping_criteria"].append(
134
+ ui_generation_stopping_criteria
135
+ )
136
+
137
  if stream_output:
138
  # Stream the reply 1 token at a time.
139
  # This is based on the trick of using 'stopping_criteria' to create an iterator,
 
165
  raw_output = None
166
  if show_raw:
167
  raw_output = str(output)
168
+ response = prompter.get_response(decoded_output)
169
+
170
+ if Global.should_stop_generating:
171
+ return
172
+
173
+ yield (
174
+ gr.Textbox.update(
175
+ value=response, lines=inference_output_lines),
176
+ raw_output)
177
+
178
+ if Global.should_stop_generating:
179
+ # If the user stops the generation, and then clicks the
180
+ # generation button again, they may mysteriously landed
181
+ # here, in the previous, should-be-stopped generation
182
+ # function call, with the new generation function not be
183
+ # called at all. To workaround this, we yield a message
184
+ # and setting lines=1, and if the front-end JS detects
185
+ # that lines has been set to 1 (rows="1" in HTML),
186
+ # it will automatically click the generate button again
187
+ # (gr.Textbox.update() does not support updating
188
+ # elem_classes or elem_id).
189
+ # [WORKAROUND-UI01]
190
+ yield (
191
+ gr.Textbox.update(
192
+ value="Please retry", lines=1),
193
+ None)
194
  return # early return for stream_output
195
 
196
  # Without streaming
197
  with torch.no_grad():
198
+ generation_output = model.generate(**generate_params)
199
  s = generation_output.sequences[0]
200
  output = tokenizer.decode(s)
201
  raw_output = None
202
  if show_raw:
203
  raw_output = str(s)
204
+
205
+ response = prompter.get_response(output)
206
+ if Global.should_stop_generating:
207
+ return
208
+
209
+ yield (
210
+ gr.Textbox.update(value=response, lines=inference_output_lines),
211
+ raw_output)
212
 
213
  except Exception as e:
214
  raise gr.Error(e)
215
 
216
 
217
+ def handle_stop_generate():
218
+ Global.generation_force_stopped_at = time.time()
219
+ Global.should_stop_generating = True
220
+
221
+
222
  def reload_selections(current_lora_model, current_prompt_template):
223
  available_template_names = get_available_template_names()
224
  available_template_names_with_none = available_template_names + ["None"]
 
240
  gr.Dropdown.update(choices=available_template_names_with_none, value=current_prompt_template))
241
 
242
 
243
+ def handle_prompt_template_change(prompt_template, lora_model):
244
  prompter = Prompter(prompt_template)
245
  var_names = prompter.get_variable_names()
246
  human_var_names = [' '.join(word.capitalize()
 
250
  while len(gr_updates) < 8:
251
  gr_updates.append(gr.Textbox.update(
252
  label="Not Used", visible=False))
253
+
254
+ model_prompt_template_message_update = gr.Markdown.update(
255
+ "", visible=False)
256
+ lora_mode_info = get_info_of_available_lora_model(lora_model)
257
+ if lora_mode_info and isinstance(lora_mode_info, dict):
258
+ model_prompt_template = lora_mode_info.get("prompt_template")
259
+ if model_prompt_template and model_prompt_template != prompt_template:
260
+ model_prompt_template_message_update = gr.Markdown.update(
261
+ f"This model was trained with prompt template `{model_prompt_template}`.", visible=True)
262
+
263
+ return [model_prompt_template_message_update] + gr_updates
264
+
265
+
266
+ def handle_lora_model_change(lora_model, prompt_template):
267
+ lora_mode_info = get_info_of_available_lora_model(lora_model)
268
+ if not lora_mode_info:
269
+ return gr.Markdown.update("", visible=False), prompt_template
270
+
271
+ if not isinstance(lora_mode_info, dict):
272
+ return gr.Markdown.update("", visible=False), prompt_template
273
+
274
+ model_prompt_template = lora_mode_info.get("prompt_template")
275
+ if not model_prompt_template:
276
+ return gr.Markdown.update("", visible=False), prompt_template
277
+
278
+ available_template_names = get_available_template_names()
279
+ if model_prompt_template in available_template_names:
280
+ return gr.Markdown.update("", visible=False), model_prompt_template
281
+
282
+ return gr.Markdown.update(f"Trained with prompt template `{model_prompt_template}`", visible=True), prompt_template
283
 
284
 
285
  def update_prompt_preview(prompt_template,
 
297
 
298
  with gr.Blocks() as inference_ui_blocks:
299
  with gr.Row():
300
+ with gr.Column(elem_id="inference_lora_model_group"):
301
+ model_prompt_template_message = gr.Markdown(
302
+ "", visible=False, elem_id="inference_lora_model_prompt_template_message")
303
+ lora_model = gr.Dropdown(
304
+ label="LoRA Model",
305
+ elem_id="inference_lora_model",
306
+ value="tloen/alpaca-lora-7b",
307
+ allow_custom_value=True,
308
+ )
309
  prompt_template = gr.Dropdown(
310
  label="Prompt Template",
311
  elem_id="inference_prompt_template",
 
378
  )
379
 
380
  num_beams = gr.Slider(
381
+ minimum=1, maximum=5, value=2, step=1,
382
  label="Beams",
383
  elem_id="inference_beams"
384
  )
 
418
  with gr.Column(elem_id="inference_output_group_container"):
419
  with gr.Column(elem_id="inference_output_group"):
420
  inference_output = gr.Textbox(
421
+ lines=inference_output_lines, label="Output", elem_id="inference_output")
422
  inference_output.style(show_copy_button=True)
423
  with gr.Accordion(
424
  "Raw Output",
 
446
  )
447
  things_that_might_timeout.append(reload_selections_event)
448
 
449
+ prompt_template_change_event = prompt_template.change(
450
+ fn=handle_prompt_template_change,
451
+ inputs=[prompt_template, lora_model],
452
+ outputs=[
453
+ model_prompt_template_message,
454
+ variable_0, variable_1, variable_2, variable_3, variable_4, variable_5, variable_6, variable_7])
455
  things_that_might_timeout.append(prompt_template_change_event)
456
 
457
+ lora_model_change_event = lora_model.change(
458
+ fn=handle_lora_model_change,
459
+ inputs=[lora_model, prompt_template],
460
+ outputs=[model_prompt_template_message, prompt_template])
461
+ things_that_might_timeout.append(lora_model_change_event)
462
+
463
  generate_event = generate_btn.click(
464
+ fn=prepare_inference,
465
+ inputs=[lora_model],
466
+ outputs=[inference_output, inference_raw_output],
467
+ ).then(
468
  fn=do_inference,
469
  inputs=[
470
  lora_model,
 
483
  outputs=[inference_output, inference_raw_output],
484
  api_name="inference"
485
  )
486
+ stop_btn.click(
487
+ fn=handle_stop_generate,
488
+ inputs=None,
489
+ outputs=None,
490
+ cancels=[generate_event]
491
+ )
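
The stop button wired up above combines two mechanisms: `handle_stop_generate` sets `Global.should_stop_generating` so the streaming generator can bail out on its own, and `cancels=[generate_event]` tells Gradio to abort the queued event. Below is a minimal, self-contained sketch of the same pattern; the `fake_generate` function and `should_stop` flag are stand-ins rather than the project's actual inference code, and it assumes a Gradio 3.x setup with queueing enabled.

```python
import time

import gradio as gr

should_stop = {"value": False}  # stand-in for Global.should_stop_generating


def fake_generate(prompt):
    # Hypothetical stand-in for do_inference: stream partial output.
    should_stop["value"] = False
    text = ""
    for word in (prompt or "hello world").split():
        if should_stop["value"]:
            return
        text += word + " "
        time.sleep(0.5)
        yield text


def handle_stop():
    # Mirrors handle_stop_generate: flip a flag the generator checks,
    # while cancels=[...] below also aborts the queued event itself.
    should_stop["value"] = True


with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    output = gr.Textbox(label="Output")
    generate_btn = gr.Button("Generate")
    stop_btn = gr.Button("Stop")

    generate_event = generate_btn.click(
        fn=fake_generate, inputs=[prompt], outputs=[output])
    stop_btn.click(fn=handle_stop, inputs=None, outputs=None,
                   cancels=[generate_event])

demo.queue().launch()
```
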
492
 
493
  update_prompt_preview_event = update_prompt_preview_btn.click(fn=update_prompt_preview, inputs=[prompt_template,
494
  variable_0, variable_1, variable_2, variable_3,
 
661
  return function (...args) {
662
  const context = this;
663
  clearTimeout(timeout);
664
+ const fn = () => {
665
+ if (document.querySelector('#inference_preview_prompt > .wrap:not(.hide)')) {
666
+ // Preview request is still loading; wait 10ms and try again.
667
+ timeout = setTimeout(fn, 10);
668
+ return;
669
+ }
670
  func.apply(context, args);
671
+ };
672
+ timeout = setTimeout(fn, wait);
673
  };
674
  }
675
 
 
704
  });
705
  }
706
  }, 100);
707
+
708
+ // [WORKAROUND-UI01]
709
+ setTimeout(function () {
710
+ const inference_output_textarea = document.querySelector(
711
+ '#inference_output textarea'
712
+ );
713
+ if (!inference_output_textarea) return;
714
+ const observer = new MutationObserver(function () {
715
+ if (inference_output_textarea.getAttribute('rows') === '1') {
716
+ setTimeout(function () {
717
+ const inference_generate_btn = document.getElementById(
718
+ 'inference_generate_btn'
719
+ );
720
+ if (inference_generate_btn) inference_generate_btn.click();
721
+ }, 10);
722
+ }
723
+ });
724
+ observer.observe(inference_output_textarea, {
725
+ attributes: true,
726
+ attributeFilter: ['rows'],
727
+ });
728
+ }, 100);
729
  }
730
  """)
llama_lora/ui/main_page.py CHANGED
@@ -1,7 +1,6 @@
1
  import gradio as gr
2
 
3
  from ..globals import Global
4
- from ..models import get_model_with_lora
5
 
6
  from .inference_ui import inference_ui
7
  from .finetune_ui import finetune_ui
@@ -30,8 +29,8 @@ def main_page():
30
  tokenizer_ui()
31
  info = []
32
  if Global.version:
33
- info.append(f"LLaMA-LoRA `{Global.version}`")
34
- info.append(f"Base model: `{Global.base_model}`")
35
  if Global.ui_show_sys_info:
36
  info.append(f"Data dir: `{Global.data_dir}`")
37
  gr.Markdown(f"""
@@ -134,6 +133,41 @@ def main_page_custom_css():
134
  /* text-transform: uppercase; */
135
  }
136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  #inference_prompt_box > *:first-child {
138
  border-bottom-left-radius: 0;
139
  border-bottom-right-radius: 0;
@@ -193,6 +227,8 @@ def main_page_custom_css():
193
  #inference_raw_output > .wrap:first-child {
194
  /* allow users to select text while generation is still in progress */
195
  pointer-events: none;
 
 
196
  }
197
 
198
  /* position sticky */
@@ -266,12 +302,16 @@ def main_page_custom_css():
266
  }
267
 
268
  @media screen and (min-width: 640px) {
269
- #inference_lora_model, #finetune_template {
 
270
  border-top-right-radius: 0;
271
  border-bottom-right-radius: 0;
272
  border-right: 0;
273
  margin-right: -16px;
274
  }
 
 
 
275
 
276
  #inference_prompt_template {
277
  border-top-left-radius: 0;
@@ -301,7 +341,7 @@ def main_page_custom_css():
301
  height: 42px !important;
302
  min-width: 42px !important;
303
  width: 42px !important;
304
- z-index: 1;
305
  }
306
  }
307
 
@@ -388,6 +428,9 @@ def main_page_custom_css():
388
  white-space: pre-wrap;
389
  }
390
 
 
 
 
391
 
392
  @media screen and (max-width: 392px) {
393
  #inference_lora_model, #finetune_template {
 
1
  import gradio as gr
2
 
3
  from ..globals import Global
 
4
 
5
  from .inference_ui import inference_ui
6
  from .finetune_ui import finetune_ui
 
29
  tokenizer_ui()
30
  info = []
31
  if Global.version:
32
+ info.append(f"LLaMA-LoRA Tuner `{Global.version}`")
33
+ info.append(f"Base model: `{Global.default_base_model_name}`")
34
  if Global.ui_show_sys_info:
35
  info.append(f"Data dir: `{Global.data_dir}`")
36
  gr.Markdown(f"""
 
133
  /* text-transform: uppercase; */
134
  }
135
 
136
+ #inference_lora_model_group {
137
+ border-radius: var(--block-radius);
138
+ background: var(--block-background-fill);
139
+ }
140
+ #inference_lora_model_group #inference_lora_model {
141
+ background: transparent;
142
+ }
143
+ #inference_lora_model_prompt_template_message:not(.hidden) + #inference_lora_model {
144
+ padding-bottom: 28px;
145
+ }
146
+ #inference_lora_model_group > #inference_lora_model_prompt_template_message {
147
+ position: absolute;
148
+ bottom: 8px;
149
+ left: 20px;
150
+ z-index: 1;
151
+ font-size: 12px;
152
+ opacity: 0.7;
153
+ }
154
+ #inference_lora_model_group > #inference_lora_model_prompt_template_message p {
155
+ font-size: 12px;
156
+ }
157
+ #inference_lora_model_prompt_template_message > .wrap {
158
+ display: none;
159
+ }
160
+ #inference_lora_model > .wrap:first-child:not(.hide),
161
+ #inference_prompt_template > .wrap:first-child:not(.hide) {
162
+ opacity: 0.5;
163
+ }
164
+ #inference_lora_model_group, #inference_lora_model {
165
+ z-index: 60;
166
+ }
167
+ #inference_prompt_template {
168
+ z-index: 55;
169
+ }
170
+
171
  #inference_prompt_box > *:first-child {
172
  border-bottom-left-radius: 0;
173
  border-bottom-right-radius: 0;
 
227
  #inference_raw_output > .wrap:first-child {
228
  /* allow users to select text while generation is still in progress */
229
  pointer-events: none;
230
+
231
+ padding: 12px !important;
232
  }
233
 
234
  /* position sticky */
 
302
  }
303
 
304
  @media screen and (min-width: 640px) {
305
+ #inference_lora_model, #inference_lora_model_group,
306
+ #finetune_template {
307
  border-top-right-radius: 0;
308
  border-bottom-right-radius: 0;
309
  border-right: 0;
310
  margin-right: -16px;
311
  }
312
+ #inference_lora_model_group #inference_lora_model {
313
+ box-shadow: var(--block-shadow);
314
+ }
315
 
316
  #inference_prompt_template {
317
  border-top-left-radius: 0;
 
341
  height: 42px !important;
342
  min-width: 42px !important;
343
  width: 42px !important;
344
+ z-index: 61;
345
  }
346
  }
347
 
 
428
  white-space: pre-wrap;
429
  }
430
 
431
+ #finetune_max_seq_length {
432
+ flex: 2;
433
+ }
434
 
435
  @media screen and (max-width: 392px) {
436
  #inference_lora_model, #finetune_template {
llama_lora/ui/tokenizer_ui.py CHANGED
@@ -7,11 +7,12 @@ from ..models import get_tokenizer
7
 
8
 
9
  def handle_decode(encoded_tokens_json):
 
10
  try:
11
  encoded_tokens = json.loads(encoded_tokens_json)
12
  if Global.ui_dev_mode:
13
  return f"Not actually decoding tokens in UI dev mode.", gr.Markdown.update("", visible=False)
14
- tokenizer = get_tokenizer()
15
  decoded_tokens = tokenizer.decode(encoded_tokens)
16
  return decoded_tokens, gr.Markdown.update("", visible=False)
17
  except Exception as e:
@@ -19,10 +20,11 @@ def handle_decode(encoded_tokens_json):
19
 
20
 
21
  def handle_encode(decoded_tokens):
 
22
  try:
23
  if Global.ui_dev_mode:
24
  return f"[\"Not actually encoding tokens in UI dev mode.\"]", gr.Markdown.update("", visible=False)
25
- tokenizer = get_tokenizer()
26
  result = tokenizer(decoded_tokens)
27
  encoded_tokens_json = json.dumps(result['input_ids'], indent=2)
28
  return encoded_tokens_json, gr.Markdown.update("", visible=False)
 
7
 
8
 
9
  def handle_decode(encoded_tokens_json):
10
+ base_model_name = Global.default_base_model_name
11
  try:
12
  encoded_tokens = json.loads(encoded_tokens_json)
13
  if Global.ui_dev_mode:
14
  return f"Not actually decoding tokens in UI dev mode.", gr.Markdown.update("", visible=False)
15
+ tokenizer = get_tokenizer(base_model_name)
16
  decoded_tokens = tokenizer.decode(encoded_tokens)
17
  return decoded_tokens, gr.Markdown.update("", visible=False)
18
  except Exception as e:
 
20
 
21
 
22
  def handle_encode(decoded_tokens):
23
+ base_model_name = Global.default_base_model_name
24
  try:
25
  if Global.ui_dev_mode:
26
  return f"[\"Not actually encoding tokens in UI dev mode.\"]", gr.Markdown.update("", visible=False)
27
+ tokenizer = get_tokenizer(base_model_name)
28
  result = tokenizer(decoded_tokens)
29
  encoded_tokens_json = json.dumps(result['input_ids'], indent=2)
30
  return encoded_tokens_json, gr.Markdown.update("", visible=False)
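
`handle_decode` and `handle_encode` now resolve the tokenizer from `Global.default_base_model_name` instead of a single global tokenizer. The round trip the Tokenizer tab exposes boils down to the following sketch, using `gpt2` as a stand-in model name; any Hugging Face tokenizer behaves the same way.

```python
import json

from transformers import AutoTokenizer

# Stand-in model name; the app resolves this from Global.default_base_model_name.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# What handle_encode produces for the UI: a JSON array of token ids.
encoded_tokens_json = json.dumps(tokenizer("Hello, LoRA!")["input_ids"], indent=2)

# What handle_decode does with that JSON: parse it and decode back to text.
decoded_tokens = tokenizer.decode(json.loads(encoded_tokens_json))

print(encoded_tokens_json)
print(decoded_tokens)
```
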
llama_lora/utils/data.py CHANGED
@@ -52,6 +52,22 @@ def get_path_of_available_lora_model(name):
52
  return None
53
 
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  def get_dataset_content(name):
56
  file_name = os.path.join(Global.data_dir, "datasets", name)
57
  if not os.path.exists(file_name):
 
52
  return None
53
 
54
 
55
+ def get_info_of_available_lora_model(name):
56
+ try:
57
+ if "/" in name:
58
+ return None
59
+ path_of_available_lora_model = get_path_of_available_lora_model(
60
+ name)
61
+ if not path_of_available_lora_model:
62
+ return None
63
+
64
+ with open(os.path.join(path_of_available_lora_model, "info.json"), "r") as json_file:
65
+ return json.load(json_file)
66
+
67
+ except Exception as e:
68
+ return None
69
+
70
+
71
  def get_dataset_content(name):
72
  file_name = os.path.join(Global.data_dir, "datasets", name)
73
  if not os.path.exists(file_name):
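
The new `get_info_of_available_lora_model` helper returns `None` for Hugging Face hub names (anything containing `/`) and otherwise reads `info.json` from the model's local directory; the inference UI above only looks at its `prompt_template` key. A rough sketch of that lookup follows, with illustrative file contents rather than a guaranteed schema.

```python
import json
import os


def read_lora_model_info(model_dir):
    # Same idea as get_info_of_available_lora_model: a missing or
    # unreadable info.json simply means "no metadata available".
    info_path = os.path.join(model_dir, "info.json")
    if not os.path.exists(info_path):
        return None
    try:
        with open(info_path, "r") as f:
            return json.load(f)
    except (OSError, json.JSONDecodeError):
        return None


# Illustrative contents of a trained model's info.json; only
# "prompt_template" is read by handle_prompt_template_change and
# handle_lora_model_change.
example_info = {
    "prompt_template": "user_and_ai",
}
```
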
llama_lora/utils/lru_cache.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+
3
+
4
+ class LRUCache:
5
+ def __init__(self, capacity=5):
6
+ self.cache = OrderedDict()
7
+ self.capacity = capacity
8
+
9
+ def get(self, key):
10
+ if key in self.cache:
11
+ # Move the accessed item to the end of the OrderedDict
12
+ self.cache.move_to_end(key)
13
+ return self.cache[key]
14
+ return None
15
+
16
+ def set(self, key, value):
17
+ if key in self.cache:
18
+ # If the key already exists, update its value
19
+ self.cache[key] = value
20
+ else:
21
+ # If the cache has reached its capacity, remove the least recently used item
22
+ if len(self.cache) >= self.capacity:
23
+ self.cache.popitem(last=False)
24
+ self.cache[key] = value
25
+
26
+ def clear(self):
27
+ self.cache.clear()
28
+
29
+ def prepare_to_set(self):
30
+ if len(self.cache) >= self.capacity:
31
+ self.cache.popitem(last=False)
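
`LRUCache` is a small `OrderedDict`-based cache with `get`/`set`/`clear` plus a `prepare_to_set` helper that evicts ahead of an insertion. A quick usage sketch (the import path is written as if the package is importable as `llama_lora`; adjust to your setup):

```python
from llama_lora.utils.lru_cache import LRUCache

cache = LRUCache(capacity=2)
cache.set("a", 1)
cache.set("b", 2)

cache.get("a")      # touching "a" makes it the most recently used entry
cache.set("c", 3)   # capacity exceeded: "b" (least recently used) is evicted

assert cache.get("b") is None
assert cache.get("a") == 1 and cache.get("c") == 3
```

One plausible use, given the `get_tokenizer(base_model_name)` change above, is keeping a handful of loaded tokenizers or models around keyed by base model name.
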
requirements.lock.txt ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.18.0
2
+ aiofiles==23.1.0
3
+ aiohttp==3.8.4
4
+ aiosignal==1.3.1
5
+ altair==4.2.2
6
+ anyio==3.6.2
7
+ appdirs==1.4.4
8
+ asttokens==2.2.1
9
+ async-timeout==4.0.2
10
+ attrs==22.2.0
11
+ backcall==0.2.0
12
+ bitsandbytes==0.37.2
13
+ black==23.3.0
14
+ charset-normalizer==3.1.0
15
+ click==8.1.3
16
+ contourpy==1.0.7
17
+ cycler==0.11.0
18
+ datasets==2.11.0
19
+ decorator==5.1.1
20
+ dill==0.3.6
21
+ entrypoints==0.4
22
+ exceptiongroup==1.1.1
23
+ executing==1.2.0
24
+ fastapi==0.95.0
25
+ ffmpy==0.3.0
26
+ filelock==3.11.0
27
+ fire==0.5.0
28
+ fonttools==4.39.3
29
+ frozenlist==1.3.3
30
+ fsspec==2023.3.0
31
+ gradio==3.24.1
32
+ gradio_client==0.0.8
33
+ h11==0.14.0
34
+ httpcore==0.16.3
35
+ httpx==0.23.3
36
+ huggingface-hub==0.13.4
37
+ idna==3.4
38
+ importlib-metadata==6.2.0
39
+ importlib-resources==5.12.0
40
+ iniconfig==2.0.0
41
+ ipython==8.12.0
42
+ jedi==0.18.2
43
+ Jinja2==3.1.2
44
+ jsonschema==4.17.3
45
+ kiwisolver==1.4.4
46
+ linkify-it-py==2.0.0
47
+ llvmlite==0.39.1
48
+ loralib==0.1.1
49
+ markdown-it-py==2.2.0
50
+ MarkupSafe==2.1.2
51
+ matplotlib==3.7.1
52
+ matplotlib-inline==0.1.6
53
+ mdit-py-plugins==0.3.3
54
+ mdurl==0.1.2
55
+ mpmath==1.3.0
56
+ multidict==6.0.4
57
+ multiprocess==0.70.14
58
+ mypy-extensions==1.0.0
59
+ networkx==3.1
60
+ numba==0.56.4
61
+ numpy==1.23.5
62
+ nvidia-ml-py3==7.352.0
63
+ orjson==3.8.9
64
+ packaging==23.0
65
+ pandas==2.0.0
66
+ parso==0.8.3
67
+ pathspec==0.11.1
68
+ peft @ git+https://github.com/huggingface/peft.git@382b178911edff38c1ff619bbac2ba556bd2276b
69
+ pexpect==4.8.0
70
+ pickleshare==0.7.5
71
+ Pillow==9.3.0
72
+ pkgutil_resolve_name==1.3.10
73
+ platformdirs==3.2.0
74
+ pluggy==1.0.0
75
+ prompt-toolkit==3.0.38
76
+ psutil==5.9.4
77
+ ptyprocess==0.7.0
78
+ pure-eval==0.2.2
79
+ pyarrow==11.0.0
80
+ pydantic==1.10.7
81
+ pydub==0.25.1
82
+ Pygments==2.14.0
83
+ pyparsing==3.0.9
84
+ pyrsistent==0.19.3
85
+ pytest==7.2.2
86
+ python-dateutil==2.8.2
87
+ python-multipart==0.0.6
88
+ pytz==2023.3
89
+ PyYAML==6.0
90
+ Random-Word==1.0.11
91
+ regex==2023.3.23
92
+ requests==2.28.2
93
+ responses==0.18.0
94
+ rfc3986==1.5.0
95
+ semantic-version==2.10.0
96
+ sentencepiece==0.1.97
97
+ six==1.16.0
98
+ sniffio==1.3.0
99
+ stack-data==0.6.2
100
+ starlette==0.26.1
101
+ sympy==1.11.1
102
+ termcolor==2.2.0
103
+ tokenize-rt==5.0.0
104
+ tokenizers==0.13.3
105
+ tomli==2.0.1
106
+ toolz==0.12.0
107
+ torch==2.0.0
108
+ tqdm==4.65.0
109
+ traitlets==5.9.0
110
+ transformers @ git+https://github.com/huggingface/transformers.git@3f96e0b4e483c4c7d4ec9dcdc24b0b0cdf31ea5c
111
+ typing_extensions==4.5.0
112
+ tzdata==2023.3
113
+ uc-micro-py==1.0.1
114
+ urllib3==1.26.15
115
+ uvicorn==0.21.1
116
+ wcwidth==0.2.6
117
+ websockets==11.0.1
118
+ xxhash==3.2.0
119
+ yarl==1.8.2
120
+ zipp==3.15.0
templates/user_and_ai.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "description": "Unhelpful AI assistant.",
3
+ "variables": ["instruction"],
4
+ "prompt": "### User:\n{instruction}\n\n### AI:\n",
5
+ "default": "prompt",
6
+ "response_split": "### AI:"
7
+ }
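
The new `user_and_ai` template has a single `instruction` variable and splits the model's reply on `### AI:`. Filling it in and extracting a response amounts to plain string work, sketched below with `str.replace`; the project's own `Prompter` class may do this differently.

```python
import json

with open("templates/user_and_ai.json") as f:
    template = json.load(f)

instruction = "Name three uses of LoRA fine-tuning."  # example input
prompt = template["prompt"].replace("{instruction}", instruction)
# prompt is now:
# ### User:
# Name three uses of LoRA fine-tuning.
#
# ### AI:

# Pretend model output (prompt echoed back plus a completion):
model_output = prompt + " 1. ... 2. ... 3. ..."
response = model_output.split(template["response_split"])[-1].strip()
print(response)  # "1. ... 2. ... 3. ..."
```
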