Spaces: Running on Zero

Ruurd committed on
Commit 332db3a · 1 Parent(s): 0e0704b

Improve code structure

Files changed (6)
  1. .gitignore +86 -0
  2. app.py +21 -53
  3. infer.py +222 -0
  4. llama_diffusion_model.py +0 -93
  5. model_config.py +25 -0
  6. requirements.txt +1 -0
.gitignore ADDED
@@ -0,0 +1,86 @@
+ # Compiled source #
+ ###################
+ *.com
+ *.class
+ *.dll
+ *.exe
+ *.o
+ *.so
+ *.obj
+ *.pyc
+ *.pyo
+ *.pyd
+ *.out
+
+ # Packages #
+ ############
+ *.7z
+ *.dmg
+ *.gz
+ *.iso
+ *.jar
+ *.rar
+ *.tar
+ *.zip
+
+ # Logs and databases #
+ ######################
+ *.log
+ *.sql
+ *.sqlite
+
+ # OS generated files #
+ ######################
+ .DS_Store
+ Thumbs.db
+ ehthumbs.db
+ Icon?
+ ._*
+
+ # Editor directories and files #
+ ###############################
+ .vscode/
+ .idea/
+ *.sublime-workspace
+ *.sublime-project
+
+ # Node.js #
+ ############
+ node_modules/
+ npm-debug.log*
+ yarn-debug.log*
+ yarn-error.log*
+
+ # Python #
+ ##########
+ __pycache__/
+ *.py[cod]
+ *.egg
+ *.egg-info/
+ dist/
+ build/
+
+ # C/C++ #
+ ##########
+ *.dSYM/
+ *.swp
+
+ # Rust #
+ ########
+ target/
+
+ # Go #
+ ######
+ bin/
+ vendor/
+
+ # Backup files #
+ ################
+ *~
+ *.bak
+ *.tmp
+
+ # Environment files #
+ #####################
+ .env
+ .env.*
app.py CHANGED
@@ -7,11 +7,25 @@ from transformers import AutoTokenizer
  import os
  import importlib
  from huggingface_hub import hf_hub_download
- from llama_diffusion_model import CustomTransformerModel, CustomTransformerConfig, disable_dropout
  import spaces
+ from dotenv import load_dotenv
+ from infer import (
+     load_trained_model,
+     find_answer_start,
+     get_noising_schedule,
+     noisify_answer,
+     generate_diffusion_text,
+     filter_logits
+ )
+
+ # Load .env only when running locally
+ if os.getenv("HF_TOKEN") is None:
+     load_dotenv()

  hf_token = os.getenv("HF_TOKEN")

+ if hf_token is None:
+     raise ValueError("HF_TOKEN is not set")

  # --- Load tokenizer ---
  tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B", use_fast=True, token=hf_token)
@@ -19,52 +33,6 @@ vocab_size = len(tokenizer)
  eos_token_id = tokenizer.eos_token_id
  mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
  assistant_marker_ids = tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>", add_special_tokens=False)
- # assistant_marker_ids = tokenizer.encode("Assistant:", add_special_tokens=False)
- # def load_model():
- #     ckpt_path = hf_hub_download(
- #         repo_id="ruurd/tini_bi_m",
- #         filename="diffusion-model.pth",
- #         token=os.getenv("HF_TOKEN")
- #     )
-
- #     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- #     model = torch.load(ckpt_path, map_location=device)
- #     model = disable_dropout(model)
- #     model.to(device)
- #     model.eval()
- #     return model
-
- def load_model():
-     ckpt_path = hf_hub_download(
-         repo_id="ruurd/tini_model",
-         filename="diffusion-model-8B.pt",
-         token=os.getenv("HF_TOKEN"),
-         # revision="1ffb916dd34f442f87cf06dda74b96f86eaf1d15",
-     )
-
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-     # Step 1: Create model from scratch
-     model = CustomTransformerModel(CustomTransformerConfig())
-
-     # Step 2: Load state_dict from full checkpoint
-     full_model = torch.load(ckpt_path, map_location=device)
-
-     # This handles both full model or just state_dict
-     try:
-         state_dict = full_model.state_dict()
-     except AttributeError:
-         state_dict = full_model  # already a state_dict
-
-     # Step 3: Load weights (might print mismatches)
-     missing, unexpected = model.load_state_dict(state_dict, strict=False)
-     print("Missing keys:", missing)
-     print("Unexpected keys:", unexpected)
-
-     model = disable_dropout(model)
-     model.to(device)
-     model.eval()
-     return model

  rng = np.random.default_rng()
@@ -204,11 +172,6 @@ def format_chat_prompt(question):
          f"{question}\n"
          "<|start_header_id|>assistant<|end_header_id|>\n"
      )
- # def format_chat_prompt(question):
- #     return(
- #         f"User:{question}\nAssistant:"
- #     )
-

  # --- Inference Wrapper ---
  def diffusion_chat(question, max_it, pause_length, sharpness,
@@ -332,7 +295,12 @@ def diffusion_chat(question, max_it, pause_length, sharpness,

  # --- Gradio Interface ---
  print("Loading model...")
- model = load_model()
+ ckpt_path = hf_hub_download(
+     repo_id="ruurd/tini_model",
+     filename="diffusion-model.pth",
+     token=os.getenv("HF_TOKEN")
+ )
+ model = load_trained_model(checkpoint_path=ckpt_path)
  print("✅ Model loaded.")

  demo = gr.Interface(
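
For reference, the new loading flow can be exercised locally with a minimal sketch like the one below. It assumes HF_TOKEN grants access to the ruurd/tini_model repo and unpacks the (model, tokenizer) pair that infer.load_trained_model returns:

    import os
    from dotenv import load_dotenv
    from huggingface_hub import hf_hub_download
    from infer import load_trained_model

    # Pick up HF_TOKEN from a local .env file when it is not already set
    if os.getenv("HF_TOKEN") is None:
        load_dotenv()

    ckpt_path = hf_hub_download(
        repo_id="ruurd/tini_model",
        filename="diffusion-model.pth",
        token=os.getenv("HF_TOKEN"),
    )
    # load_trained_model (defined in infer.py) returns both the model and a tokenizer
    model, tokenizer = load_trained_model(checkpoint_path=ckpt_path)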
infer.py ADDED
@@ -0,0 +1,222 @@
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+ import time
+ import random
+ import importlib
+ import torch.nn as nn
+ import os
+
+ from transformers import AutoTokenizer
+
+ rng = np.random.default_rng()
+
+ def disable_dropout(model):
+     for name, module in model.named_modules():
+         if isinstance(module, nn.Dropout):
+             setattr(model, name, nn.Identity())  # Replace Dropout with Identity
+     return model
+
+ def load_trained_model(checkpoint_path: str, base_model_name: str = "meta-llama/Llama-3.2-3B"):
+     # Load tokenizer + config from saved dir
+     hf_token = os.getenv("HF_TOKEN")
+
+     tokenizer = AutoTokenizer.from_pretrained(base_model_name,
+                                               use_fast=True,
+                                               token=hf_token,
+                                               torch_dtype=torch.float32)
+
+     # Step 5: Load the model safely
+     model = torch.load(checkpoint_path, map_location=torch.device('cpu'), weights_only=False)
+
+     # Disable dropout
+     model = disable_dropout(model)
+
+     print("✅ Model successfully loaded from checkpoint:", checkpoint_path)
+
+     # Move to correct device
+     device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+     # model = model.to(torch.float32)
+     model.to(device)
+     model.eval()
+
+     return model, tokenizer
+
+ def filter_logits(logits, top_k=0, top_p=1.0, temperature=1.0):
+     """
+     Vectorized top-k and/or top-p (nucleus) filtering with temperature scaling.
+     Accepts logits of shape (seq_len, vocab_size) or (1, seq_len, vocab_size),
+     and returns logits in the same shape.
+     """
+     original_shape = logits.shape
+     if logits.dim() == 3:
+         logits = logits.squeeze(0)  # shape: (seq_len, vocab_size)
+
+     logits = logits.clone()
+
+     # --- Temperature scaling ---
+     if temperature != 1.0:
+         logits = logits / temperature
+
+     # --- Top-k filtering ---
+     if top_k > 0 and top_k < logits.size(-1):
+         topk_vals, _ = torch.topk(logits, top_k, dim=-1)
+         thresholds = topk_vals[:, -1].unsqueeze(-1)
+         logits = torch.where(logits < thresholds, torch.full_like(logits, float("-inf")), logits)
+
+     # --- Top-p filtering ---
+     if top_p > 0.0 and top_p < 1.0:
+         sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+         probs = torch.softmax(sorted_logits, dim=-1)
+         cum_probs = probs.cumsum(dim=-1)
+
+         mask = cum_probs > top_p
+         mask[:, 0] = False  # always keep top token
+
+         scatter_mask = torch.zeros_like(logits, dtype=torch.bool).scatter(dim=-1, index=sorted_indices, src=mask)
+         logits = torch.where(scatter_mask, torch.full_like(logits, float("-inf")), logits)
+
+     # Restore original shape
+     if original_shape[0] == 1:
+         logits = logits.unsqueeze(0)
+
+     return logits
+
+
+ def decode_tokens_safe(tokenizer, token_ids):
+     return tokenizer.decode(token_ids, skip_special_tokens=True).replace("\n", " ")
+
+ def find_answer_start(input_ids, marker_ids):
+     for i in range(len(input_ids) - len(marker_ids) + 1):
+         if input_ids[i:i + len(marker_ids)] == marker_ids:
+             return i + len(marker_ids)
+     return None
+
+ def noisify_answer(input_ids, answer_start, threshold=1.0, is_unmasked=None, mask_token_id=128002):
+     noised = input_ids.copy()
+     total_len = len(input_ids)
+     candidates = [
+         i for i in range(answer_start, total_len)
+         if is_unmasked is None or not is_unmasked[i]
+     ]
+     num_to_add = int(threshold * total_len)
+     if num_to_add > 0 and len(candidates) > 0:
+         newly_masked = rng.choice(candidates, size=min(num_to_add, len(candidates)), replace=False)
+         for idx in newly_masked:
+             noised[idx] = mask_token_id
+     return noised
+
+ def get_noising_schedule(i, max_it, sharpness=5.0):
+     x = i / max_it
+     return (np.exp(-sharpness * x) - np.exp(-sharpness)) / (1 - np.exp(-sharpness))
+
+ import torch.nn.functional as F
+
+ def generate_diffusion_text(model, input_ids, answer_start, top_k=0, top_p=1.0, temperature=1.0,
+                             eos_token_id=None, eos_boost=0.0):
+     model.eval()
+     with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16):
+         input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
+         logits = model(input_ids=input_tensor)["logits"]  # (1, seq_len, vocab_size)
+
+         # Optionally boost or suppress EOS token
+         if eos_token_id is not None and eos_boost != 0.0:
+             logits[:, :, eos_token_id] += eos_boost
+
+         # Filter and sample
+         filtered_logits = filter_logits(logits, top_k=top_k, top_p=top_p, temperature=temperature)
+         probs = F.softmax(filtered_logits, dim=-1).squeeze()  # (seq_len, vocab_size)
+         probs = torch.clamp(probs, min=1e-8, max=1.0)
+         sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
+         confidences = probs.gather(1, sampled.unsqueeze(-1)).squeeze(-1)
+
+         return input_ids[:answer_start] + sampled[answer_start:].tolist(), confidences
+
+
+ def calculate_answer_perplexity(prompt, answer, model_name='gpt2-large'):
+     from transformers import AutoTokenizer, AutoModelForCausalLM
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModelForCausalLM.from_pretrained(model_name).eval()
+     device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+     model.to(device)
+
+     full_input = prompt + answer
+     enc = tokenizer(full_input, return_tensors="pt")
+     input_ids = enc.input_ids.to(device)
+
+     with torch.no_grad():
+         labels = input_ids.clone()
+         prompt_len = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
+         labels[0, :prompt_len] = -100
+         loss = model(input_ids, labels=labels).loss
+     return torch.exp(loss).item()
+
+
+ def generate_answer(question: str, model, tokenizer, max_it=16, noise_start=0.5,
+                     noising_sharpness=5.0, max_length=256, top_k=100, top_p=1.0,
+                     temperature=1.0, eos_token_id=None, eos_boost=0.0) -> str:
+
+     if eos_token_id is None:
+         eos_token_id = tokenizer.eos_token_id
+     # Format prompt with LLaMA 3 chat template
+     prompt = (
+         "<|begin_of_text|>\n"
+         "<|start_header_id|>system<|end_header_id|>\n"
+         "You are a helpful assistant.\n"
+         "<|eot_id|>\n"
+         "<|start_header_id|>user<|end_header_id|>\n"
+         f"{question.strip()}\n"
+         "<|start_header_id|>assistant<|end_header_id|>\n"
+     )
+     input_ids = tokenizer.encode(prompt, add_special_tokens=False)
+     marker = tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>\n", add_special_tokens=False)
+
+     def find_answer_start(ids, marker):
+         for i in range(len(ids) - len(marker) + 1):
+             if ids[i:i+len(marker)] == marker:
+                 return i + len(marker)
+         return None
+
+     answer_start = find_answer_start(input_ids, marker)
+     if answer_start is None:
+         raise ValueError("Assistant marker not found in prompt.")
+
+     # Pad to max length
+     pad_token = tokenizer.eos_token_id
+     mask_token = tokenizer.encode("MASK", add_special_tokens=False)[0]
+     input_ids = input_ids[:max_length]
+     if len(input_ids) < max_length:
+         input_ids += [mask_token] * (max_length - len(input_ids))
+
+     ori_tokens = input_ids
+     current_tokens = noisify_answer(ori_tokens, answer_start, threshold=1.0, mask_token_id=mask_token)
+
+     last_tokens = []
+     for step in range(max_it):
+         # Generate a new prediction
+         current_tokens, confidence_scores = generate_diffusion_text(
+             model, current_tokens, answer_start,
+             top_k=top_k, top_p=top_p, temperature=temperature,
+             eos_token_id=eos_token_id, eos_boost=eos_boost
+         )
+
+         # Display for debugging / tracking
+         display_diffusion_output(
+             step, max_it, question,
+             ori_tokens, current_tokens, confidence_scores,
+             answer_start, tokenizer
+         )
+
+         # Early stopping
+         last_tokens.append(current_tokens)
+         if len(last_tokens) > 4:
+             last_tokens.pop(0)
+             if all(t == last_tokens[0] for t in last_tokens):
+                 break
+
+         # Re-apply noise for next iteration
+         if step < max_it - 1:
+             threshold = noise_start * get_noising_schedule(step, max_it, sharpness=noising_sharpness)
+             current_tokens = noisify_answer(current_tokens, answer_start, threshold=threshold, mask_token_id=mask_token)
+
+     return tokenizer.decode(current_tokens[answer_start:], skip_special_tokens=True).strip()
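
A rough usage sketch of the helpers above (values are illustrative only; model and tokenizer are assumed to come from load_trained_model, so the end-to-end call is left commented out):

    import torch
    from infer import filter_logits, get_noising_schedule, generate_answer

    # The exponential schedule starts at 1.0 (fully re-noised) and decays toward 0.0
    for i in (0, 4, 8, 12, 15):
        print(i, round(get_noising_schedule(i, max_it=16, sharpness=5.0), 3))

    # filter_logits keeps only the top-k / nucleus tokens by pushing the rest to -inf
    logits = torch.randn(4, 128256)               # (seq_len, vocab_size)
    kept = filter_logits(logits, top_k=100, top_p=0.9, temperature=0.8)
    print(torch.isfinite(kept).sum(dim=-1))       # at most 100 finite logits per position

    # End-to-end generation, once a checkpoint is loaded:
    # answer = generate_answer("What is diffusion-based text generation?", model, tokenizer)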
llama_diffusion_model.py DELETED
@@ -1,93 +0,0 @@
- import torch
- import torch.nn as nn
- from torch.amp import autocast
- from transformers import AutoModelForCausalLM, PreTrainedModel, PretrainedConfig
- from transformers.models.llama.modeling_llama import LlamaAttention
- from peft import LoraConfig, get_peft_model
- import os
- from typing import Optional, Tuple
-
- hf_token = os.getenv("HF_TOKEN")
-
- class CustomTransformerConfig(PretrainedConfig):
-     def __init__(self, vocab_size=128256, hidden_size=4096, num_layers=32, num_heads=32, prediction_chunk=256, dropout=0,
-                  max_position_embeddings=4096, masking_type="bidirectional", **kwargs):
-         super().__init__(**kwargs)
-         self.vocab_size = vocab_size
-         self.hidden_size = hidden_size
-         self.num_layers = num_layers
-         self.num_heads = num_heads
-         self.dropout = dropout
-         self.prediction_chunk = prediction_chunk
-         self.max_position_embeddings = max_position_embeddings
-         self.input_size = prediction_chunk
-         self.masking_type = masking_type
-
- class CustomTransformerModel(PreTrainedModel):
-     config_class = CustomTransformerConfig
-
-     def __init__(self, config):
-         super().__init__(config)
-         self.llama = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct", torch_dtype=torch.float16, device_map="auto", token=hf_token)
-         self.llama.resize_token_embeddings(config.vocab_size)
-
-         for param in self.llama.parameters():
-             param.requires_grad = False
-         for param in self.llama.lm_head.parameters():
-             param.requires_grad = True
-
-         lora_config = LoraConfig(
-             r=512, lora_alpha=512, lora_dropout=0.0,
-             target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
-             bias="none", task_type=None
-         )
-
-         self.llama = get_peft_model(self.llama, lora_config)
-         self.llama.print_trainable_parameters()
-
-     def forward(self, input_ids, labels=None, **kwargs):
-         batch_size, seq_len = input_ids.shape
-         assert seq_len == self.config.prediction_chunk, f"Expected input length {self.config.prediction_chunk}, got {seq_len}"
-
-         # Build attention mask
-         device = input_ids.device
-
-         masking_type = getattr(self.config, "masking_type", "bidirectional")
-         if masking_type == 'bidirectional':
-             base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
-         elif masking_type == 'bidirectional_masked':
-             base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
-             base_mask.fill_diagonal_(False)
-         elif masking_type == 'unidirectional':
-             base_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
-         else:
-             raise ValueError(f"Unknown masking type: {self.config.masking_type}")
-
-         attention_mask = base_mask.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, seq_len, seq_len).clone()
-         attention_mask = attention_mask.to(dtype=torch.float32)  # required for SDPA and Flash attention
-
-         with autocast("cuda", dtype=torch.float16):
-             outputs = self.llama(
-                 input_ids,
-                 attention_mask=attention_mask,
-                 output_hidden_states=True,
-                 use_cache=False,
-                 **kwargs
-             )
-
-         logits = outputs.logits[:, :, :self.config.vocab_size].view(batch_size, seq_len, self.config.vocab_size)
-
-         loss = None
-         if labels is not None:
-             assert labels.shape == (batch_size, seq_len), f"Labels shape mismatch: expected ({batch_size}, {seq_len}), got {labels.shape}"
-             loss_fct = nn.CrossEntropyLoss()
-             loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
-
-         return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}
-
- def disable_dropout(model):
-     for name, module in model.named_modules():
-         if isinstance(module, nn.Dropout):
-             setattr(model, name, nn.Identity())
-     return model
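
For reference, the three masking_type options handled in the deleted forward() differ only in the base attention mask they build; a standalone sketch with a toy sequence length (names mirror the deleted code):

    import torch

    seq_len = 4
    # bidirectional: every position may attend to every position
    bidirectional = torch.ones(seq_len, seq_len, dtype=torch.bool)
    # bidirectional_masked: same, except a position cannot attend to itself
    bidirectional_masked = torch.ones(seq_len, seq_len, dtype=torch.bool)
    bidirectional_masked.fill_diagonal_(False)
    # unidirectional: standard causal (lower-triangular) mask
    unidirectional = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    print(bidirectional_masked.int())
    print(unidirectional.int())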
model_config.py ADDED
@@ -0,0 +1,25 @@
+ from transformers import PretrainedConfig
+
+ class CustomTransformerConfig(PretrainedConfig):
+     def __init__(
+         self,
+         vocab_size=128256,
+         hidden_size=4096,
+         num_layers=32,
+         num_heads=32,
+         prediction_chunk=256,
+         dropout=0,
+         max_position_embeddings=4096,
+         masking_type="bidirectional",
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.num_layers = num_layers
+         self.num_heads = num_heads
+         self.dropout = dropout
+         self.prediction_chunk = prediction_chunk
+         self.max_position_embeddings = max_position_embeddings
+         self.input_size = prediction_chunk  # alias
+         self.masking_type = masking_type
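
A quick instantiation sketch (defaults come straight from model_config.py; only masking_type is overridden for illustration):

    from model_config import CustomTransformerConfig

    config = CustomTransformerConfig(masking_type="bidirectional_masked")
    print(config.prediction_chunk, config.input_size)  # 256 256 (input_size aliases prediction_chunk)
    print(config.masking_type)                         # bidirectional_masked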
requirements.txt CHANGED
@@ -5,3 +5,4 @@ peft>=0.15.1
  accelerate>=0.24.1
  gradio>=4.10.0
  numpy
+ python-dotenv