import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gc
import os
import datetime
import time
import spaces

# --- Configuration ---
# Full Hub repo id (assumption: the intended model lives under the
# naver-hyperclovax org); the bare model name alone will not resolve on the Hub.
MODEL_ID = "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B"
MAX_NEW_TOKENS = 512
CPU_THREAD_COUNT = 4  # Adjust as needed

# --- Optional: CPU thread settings ---
# torch.set_num_threads(CPU_THREAD_COUNT)
# os.environ["OMP_NUM_THREADS"] = str(CPU_THREAD_COUNT)
# os.environ["MKL_NUM_THREADS"] = str(CPU_THREAD_COUNT)

print("--- ν™˜κ²½ μ„€μ • ---")
print(f"PyTorch 버전: {torch.__version__}")
print(f"μ‹€ν–‰ μž₯치: {torch.device('cuda' if torch.cuda.is_available() else 'cpu')}")
print(f"Torch μŠ€λ ˆλ“œ: {torch.get_num_threads()}")

# --- λͺ¨λΈ 및 ν† ν¬λ‚˜μ΄μ € λ‘œλ”© ---
print(f"--- λͺ¨λΈ λ‘œλ”© 쀑: {MODEL_ID} ---")
print("첫 μ‹€ν–‰ μ‹œ λͺ‡ λΆ„ 정도 μ†Œμš”λ  수 μžˆμŠ΅λ‹ˆλ‹€...")

model = None
tokenizer = None
load_successful = False
stop_token_ids_list = []  # Initialize stop_token_ids_list

try:
    start_load_time = time.time()
    # μžμ›μ— 따라 device_map μ„€μ •
    device_map = "auto" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,
        trust_remote_code=True
    )
    
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype,
        device_map=device_map,
        trust_remote_code=True
    )
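    # Note (assumption: recent transformers): low_cpu_mem_usage=True could be
    # passed to from_pretrained above to reduce peak host RAM during loading.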
    
    model.eval()
    load_time = time.time() - start_load_time
    print(f"--- λͺ¨λΈ 및 ν† ν¬λ‚˜μ΄μ € λ‘œλ”© μ™„λ£Œ: {load_time:.2f}초 μ†Œμš” ---")
    load_successful = True

    # --- Stop token setup ---
    stop_token_strings = ["</s>", "<|endoftext|>"]
    temp_stop_ids = [tokenizer.convert_tokens_to_ids(token) for token in stop_token_strings]

    if tokenizer.eos_token_id is not None and tokenizer.eos_token_id not in temp_stop_ids:
        temp_stop_ids.append(tokenizer.eos_token_id)
    elif tokenizer.eos_token_id is None:
        print("Warning: tokenizer.eos_token_id is None; cannot add it to the stop tokens.")

    stop_token_ids_list = [tid for tid in temp_stop_ids if tid is not None]

    if not stop_token_ids_list:
        print("Warning: no stop token IDs found. Falling back to the default EOS if available; otherwise generation may not stop correctly.")
        if tokenizer.eos_token_id is not None:
            stop_token_ids_list = [tokenizer.eos_token_id]
        else:
            print("Error: no stop tokens found, including the default EOS. Generation may run indefinitely.")

    print(f"Stop token IDs in use: {stop_token_ids_list}")

except Exception as e:
    print(f"!!! λͺ¨λΈ λ‘œλ”© 였λ₯˜: {e}")
    if 'model' in locals() and model is not None: del model
    if 'tokenizer' in locals() and tokenizer is not None: del tokenizer
    gc.collect()
    raise gr.Error(f"Failed to load model {MODEL_ID}; the application cannot start. Error: {e}")

# --- System prompt definition ---
def get_system_prompt():
    """Returns the (Korean) system prompt with today's date filled in."""
    current_date = datetime.datetime.now().strftime("%Y-%m-%d (%A)")
    return (
        f"- μ˜€λŠ˜μ€ {current_date}μž…λ‹ˆλ‹€.\n"
        f"- μ‚¬μš©μžμ˜ μ§ˆλ¬Έμ— λŒ€ν•΄ μΉœμ ˆν•˜κ³  μžμ„Έν•˜κ²Œ ν•œκ΅­μ–΄λ‘œ λ‹΅λ³€ν•΄μ•Ό ν•©λ‹ˆλ‹€."
    )

# --- Warm-up function ---
def warmup_model():
    if not load_successful or model is None or tokenizer is None:
        print("μ›œμ—… κ±΄λ„ˆλ›°κΈ°: λͺ¨λΈμ΄ μ„±κ³΅μ μœΌλ‘œ λ‘œλ“œλ˜μ§€ μ•Šμ•˜μŠ΅λ‹ˆλ‹€.")
        return

    print("--- λͺ¨λΈ μ›œμ—… μ‹œμž‘ ---")
    try:
        start_warmup_time = time.time()
        warmup_message = "μ•ˆλ…•ν•˜μ„Έμš”"  # "Hello" — kept in Korean, as the model targets Korean chat
        
        # λͺ¨λΈμ— λ§žλŠ” ν˜•μ‹μœΌλ‘œ μž…λ ₯ ꡬ성
        system_prompt = get_system_prompt()
        
        # MiMo λͺ¨λΈμ˜ ν”„λ‘¬ν”„νŠΈ ν˜•μ‹μ— 맞게 μ‘°μ •
        prompt = f"Human: {warmup_message}\nAssistant:"
        
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        
        # Check whether the stop-token list is empty and handle it appropriately
        gen_kwargs = {
            "max_new_tokens": 10,
            "pad_token_id": tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id,
            "do_sample": False
        }
        
        if stop_token_ids_list:
            gen_kwargs["eos_token_id"] = stop_token_ids_list
        else:
            print("μ›œμ—… κ²½κ³ : 생성에 μ •μ˜λœ 쀑지 토큰이 μ—†μŠ΅λ‹ˆλ‹€.")

        with torch.no_grad():
            output_ids = model.generate(**inputs, **gen_kwargs)

        del inputs
        del output_ids
        gc.collect()
        warmup_time = time.time() - start_warmup_time
        print(f"--- λͺ¨λΈ μ›œμ—… μ™„λ£Œ: {warmup_time:.2f}초 μ†Œμš” ---")

    except Exception as e:
        print(f"!!! λͺ¨λΈ μ›œμ—… 쀑 였λ₯˜ λ°œμƒ: {e}")
    finally:
        gc.collect()

# --- Inference function ---
@spaces.GPU()
def predict(message, history):
    """
    HyperCLOVAX-SEED-Vision-Instruct-3B λͺ¨λΈμ„ μ‚¬μš©ν•˜μ—¬ 응닡을 μƒμ„±ν•©λ‹ˆλ‹€.
    'history'λŠ” Gradio 'messages' ν˜•μ‹μ„ κ°€μ •ν•©λ‹ˆλ‹€: List[Dict].
    """
    if model is None or tokenizer is None:
         return "였λ₯˜: λͺ¨λΈμ΄ λ‘œλ“œλ˜μ§€ μ•Šμ•˜μŠ΅λ‹ˆλ‹€."

    # Build the conversation history, accepting both tuple and messages formats
    history_text = ""
    if isinstance(history, list):
        for turn in history:
            if isinstance(turn, (tuple, list)) and len(turn) == 2:
                history_text += f"Human: {turn[0]}\nAssistant: {turn[1]}\n"
            elif isinstance(turn, dict) and "role" in turn and "content" in turn:
                role = "Human" if turn["role"] == "user" else "Assistant"
                history_text += f"{role}: {turn['content']}\n"

    # Assemble the prompt: system prompt, then history, then the current message
    prompt = f"{get_system_prompt()}\n{history_text}Human: {message}\nAssistant:"

    inputs = None
    output_ids = None

    try:
        # Prepare inputs
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        input_length = inputs.input_ids.shape[1]
        print(f"\nInput token count: {input_length}")

    except Exception as e:
        print(f"!!! Error while processing input: {e}")
        return f"Error: a problem occurred while processing the input. ({e})"

    try:
        print("응닡 생성 쀑...")
        generation_start_time = time.time()

        # 생성 인수 μ€€λΉ„, λΉ„μ–΄ μžˆλŠ” stop_token_ids_list 처리
        gen_kwargs = {
            "max_new_tokens": MAX_NEW_TOKENS,
            "pad_token_id": tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id,
            "do_sample": True,
            "temperature": 0.7,
            "top_p": 0.9,
            "repetition_penalty": 1.1
        }
        
        if stop_token_ids_list:
            gen_kwargs["eos_token_id"] = stop_token_ids_list
        else:
            print("Generation warning: no stop tokens defined.")

        with torch.no_grad():
            output_ids = model.generate(**inputs, **gen_kwargs)

        generation_time = time.time() - generation_start_time
        print(f"생성 μ™„λ£Œ: {generation_time:.2f}초 μ†Œμš”.")

    except Exception as e:
        print(f"!!! λͺ¨λΈ 생성 쀑 였λ₯˜ λ°œμƒ: {e}")
        if inputs is not None: del inputs
        if output_ids is not None: del output_ids
        gc.collect()
        return f"였λ₯˜: 응닡을 μƒμ„±ν•˜λŠ” 쀑 λ¬Έμ œκ°€ λ°œμƒν–ˆμŠ΅λ‹ˆλ‹€. ({e})"

    # Decode the response
    response = "Error: failed to generate a response."
    if output_ids is not None:
        try:
            new_tokens = output_ids[0, input_length:]
            response = tokenizer.decode(new_tokens, skip_special_tokens=True)
            print(f"Output token count: {len(new_tokens)}")
            del new_tokens
        except Exception as e:
            print(f"!!! Error while decoding the response: {e}")
            response = "Error: a problem occurred while decoding the response."

    # Free memory
    if inputs is not None: del inputs
    if output_ids is not None: del output_ids
    gc.collect()
    print("Memory cleanup complete.")

    return response.strip()
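
# Optional sketch (not wired in): streaming output. gr.ChatInterface can consume
# a generator, so a streaming variant of predict() could yield partial text via
# transformers' TextIteratorStreamer (assumes a recent transformers version):
#
# from threading import Thread
# from transformers import TextIteratorStreamer
#
# def predict_stream(message, history):
#     prompt = f"{get_system_prompt()}\nHuman: {message}\nAssistant:"
#     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
#     streamer = TextIteratorStreamer(
#         tokenizer, skip_prompt=True, skip_special_tokens=True
#     )
#     Thread(target=model.generate,
#            kwargs={**inputs, "streamer": streamer,
#                    "max_new_tokens": MAX_NEW_TOKENS}).start()
#     partial = ""
#     for chunk in streamer:
#         partial += chunk
#         yield partial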

# --- Gradio interface setup ---
print("--- Setting up Gradio interface ---")

# Example prompts (kept in Korean; the app targets Korean-speaking users)
examples = [
    ["μ•ˆλ…•ν•˜μ„Έμš”! μžκΈ°μ†Œκ°œ μ’€ ν•΄μ£Όμ„Έμš”."],
    ["인곡지λŠ₯κ³Ό λ¨Έμ‹ λŸ¬λ‹μ˜ 차이점은 λ¬΄μ—‡μΈκ°€μš”?"],
    ["λ”₯λŸ¬λ‹ λͺ¨λΈ ν•™μŠ΅ 과정을 λ‹¨κ³„λ³„λ‘œ μ•Œλ €μ£Όμ„Έμš”."],
    ["μ œμ£Όλ„ μ—¬ν–‰ κ³„νšμ„ μ„Έμš°κ³  μžˆλŠ”λ°, 3λ°• 4일 μΆ”μ²œ μ½”μŠ€ μ’€ μ•Œλ €μ£Όμ„Έμš”."],
]

# Let ChatInterface manage its own Chatbot component
demo = gr.ChatInterface(
    fn=predict,
    title="πŸ€– HyperCLOVAX-SEED-Vision-Instruct-3B",
    description=f"**Model:** {MODEL_ID}",
    examples=examples,
    cache_examples=False,
    theme=gr.themes.Soft(),
)
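
# Note: demo.queue() queues incoming requests so concurrent users are served in
# order; the @spaces.GPU() decorator on predict() requests a GPU per call
# (assumption: this Space targets ZeroGPU hardware, as `import spaces` suggests).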

# --- μ• ν”Œλ¦¬μΌ€μ΄μ…˜ μ‹€ν–‰ ---
if __name__ == "__main__":
    if load_successful:
        warmup_model()
    else:
        print("λͺ¨λΈ λ‘œλ”©μ— μ‹€νŒ¨ν•˜μ—¬ μ›œμ—…μ„ κ±΄λ„ˆλœλ‹ˆλ‹€.")

    print("--- Gradio μ•± μ‹€ν–‰ 쀑 ---")
    demo.queue().launch(
        # share=True # 곡개 링크λ₯Ό μ›ν•˜λ©΄ 주석 ν•΄μ œ
        # server_name="0.0.0.0" # 둜컬 λ„€νŠΈμ›Œν¬ 접근을 μ›ν•˜λ©΄ 주석 ν•΄μ œ
    )