Yuchan5386 committed on
Commit
7fc84c9
·
verified ·
1 Parent(s): 166a6d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -5
app.py CHANGED
@@ -51,6 +51,7 @@ def text_to_ids(text):
51
  def ids_to_text(ids):
52
  return sp.decode(ids)
53
 
 
54
  class RotaryPositionalEmbedding(layers.Layer):
55
  def __init__(self, dim):
56
  super().__init__()
@@ -94,7 +95,7 @@ class GEGLU(tf.keras.layers.Layer):
94
  return self.out(x_val * tf.nn.gelu(x_gate))
95
 
96
  class KeraLuxBlock(tf.keras.layers.Layer):
97
- def __init__(self, d_model, d_ff, num_heads=20, dropout_rate=0.1):
98
  super().__init__()
99
  self.ln1 = layers.LayerNormalization(epsilon=1e-5)
100
  self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model // num_heads)
@@ -136,7 +137,7 @@ class KeraLuxBlock(tf.keras.layers.Layer):
136
  return x
137
 
138
  class KeraLux(tf.keras.Model):
139
- def __init__(self, vocab_size, seq_len, d_model, d_ff, n_layers, num_heads=20, dropout_rate=0.1):
140
  super().__init__()
141
  self.token_embedding = layers.Embedding(vocab_size, d_model)
142
  # pos_embedding 제거
@@ -152,11 +153,17 @@ class KeraLux(tf.keras.Model):
152
  logits = tf.matmul(x, self.token_embedding.embeddings, transpose_b=True)
153
  return logits
154
 
155
- # 모델 생성 & 가중치 불러오기
156
- model = KeraLux(vocab_size=vocab_size, seq_len=max_len, d_model=160, d_ff=616, n_layers=6)
 
 
 
 
 
 
157
  dummy_input = tf.zeros((1, max_len), dtype=tf.int32) # 배치1, 시퀀스길이 max_len
158
  _ = model(dummy_input) # 모델이 빌드됨
159
- model.load_weights("KeraLux3.weights.h5")
160
  print("모델 가중치 로드 완료!")
161
 
162
  def decode_sp_tokens(tokens):
 
51
  def ids_to_text(ids):
52
  return sp.decode(ids)
53
 
54
+
55
  class RotaryPositionalEmbedding(layers.Layer):
56
  def __init__(self, dim):
57
  super().__init__()
 
95
  return self.out(x_val * tf.nn.gelu(x_gate))
96
 
97
  class KeraLuxBlock(tf.keras.layers.Layer):
98
+ def __init__(self, d_model, d_ff, num_heads=12, dropout_rate=0.1):
99
  super().__init__()
100
  self.ln1 = layers.LayerNormalization(epsilon=1e-5)
101
  self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model // num_heads)
 
137
  return x
138
 
139
  class KeraLux(tf.keras.Model):
140
+ def __init__(self, vocab_size, seq_len, d_model, d_ff, n_layers, num_heads=12, dropout_rate=0.1):
141
  super().__init__()
142
  self.token_embedding = layers.Embedding(vocab_size, d_model)
143
  # pos_embedding 제거
 
153
  logits = tf.matmul(x, self.token_embedding.embeddings, transpose_b=True)
154
  return logits
155
 
156
+ # 모델 생성
157
+ model = KeraLux(
158
+ vocab_size=vocab_size,
159
+ seq_len=max_len,
160
+ d_model=192,
161
+ d_ff=768,
162
+ n_layers=6
163
+ )
164
  dummy_input = tf.zeros((1, max_len), dtype=tf.int32) # 배치1, 시퀀스길이 max_len
165
  _ = model(dummy_input) # 모델이 빌드됨
166
+ model.load_weights("KeraLux.weights.h5")
167
  print("모델 가중치 로드 완료!")
168
 
169
  def decode_sp_tokens(tokens):