import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gc
import os
import datetime
import time
import spaces
# --- Configuration ---
# Full Hub repo id; the bare model name alone would not resolve on the Hub.
MODEL_ID = "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B"
MAX_NEW_TOKENS = 512
CPU_THREAD_COUNT = 4  # adjust as needed
# --- Optional: CPU thread settings ---
# torch.set_num_threads(CPU_THREAD_COUNT)
# os.environ["OMP_NUM_THREADS"] = str(CPU_THREAD_COUNT)
# os.environ["MKL_NUM_THREADS"] = str(CPU_THREAD_COUNT)
print("--- ν™˜κ²½ μ„€μ • ---")
print(f"PyTorch 버전: {torch.__version__}")
print(f"μ‹€ν–‰ μž₯치: {torch.device('cuda' if torch.cuda.is_available() else 'cpu')}")
print(f"Torch μŠ€λ ˆλ“œ: {torch.get_num_threads()}")
# --- λͺ¨λΈ 및 ν† ν¬λ‚˜μ΄μ € λ‘œλ”© ---
print(f"--- λͺ¨λΈ λ‘œλ”© 쀑: {MODEL_ID} ---")
print("첫 μ‹€ν–‰ μ‹œ λͺ‡ λΆ„ 정도 μ†Œμš”λ  수 μžˆμŠ΅λ‹ˆλ‹€...")
model = None
tokenizer = None
load_successful = False
stop_token_ids_list = []  # initialize the stop token id list
try:
start_load_time = time.time()
# μžμ›μ— 따라 device_map μ„€μ •
device_map = "auto" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
tokenizer = AutoTokenizer.from_pretrained(
MODEL_ID,
trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=dtype,
device_map=device_map,
trust_remote_code=True
)
model.eval()
load_time = time.time() - start_load_time
print(f"--- λͺ¨λΈ 및 ν† ν¬λ‚˜μ΄μ € λ‘œλ”© μ™„λ£Œ: {load_time:.2f}초 μ†Œμš” ---")
load_successful = True
    # --- Stop token setup ---
    stop_token_strings = ["</s>", "<|endoftext|>"]
    temp_stop_ids = [tokenizer.convert_tokens_to_ids(token) for token in stop_token_strings]
    if tokenizer.eos_token_id is not None and tokenizer.eos_token_id not in temp_stop_ids:
        temp_stop_ids.append(tokenizer.eos_token_id)
    elif tokenizer.eos_token_id is None:
        print("Warning: tokenizer.eos_token_id is None; it cannot be added to the stop tokens.")
    stop_token_ids_list = [tid for tid in temp_stop_ids if tid is not None]
    if not stop_token_ids_list:
        print("Warning: no stop token ids found. Falling back to the default EOS if available; otherwise generation may not stop correctly.")
        if tokenizer.eos_token_id is not None:
            stop_token_ids_list = [tokenizer.eos_token_id]
        else:
            print("Error: no stop tokens found, including the default EOS. Generation may run indefinitely.")
    print(f"Stop token ids to use: {stop_token_ids_list}")
except Exception as e:
    print(f"!!! Model loading error: {e}")
    if 'model' in locals() and model is not None: del model
    if 'tokenizer' in locals() and tokenizer is not None: del tokenizer
    gc.collect()
    raise gr.Error(f"Failed to load model {MODEL_ID}. The application cannot start. Error: {e}")
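# Note: on memory-constrained CPU Spaces, from_pretrained also accepts
# low_cpu_mem_usage=True, which streams weights in instead of materializing a
# full in-memory copy first. A minimal sketch of the same call with that flag:
#
#   model = AutoModelForCausalLM.from_pretrained(
#       MODEL_ID,
#       torch_dtype=dtype,
#       device_map=device_map,
#       trust_remote_code=True,
#       low_cpu_mem_usage=True,
#   )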
# --- System prompt definition ---
def get_system_prompt():
current_date = datetime.datetime.now().strftime("%Y-%m-%d (%A)")
    return (
        f"- Today is {current_date}.\n"
        f"- You must answer the user's question kindly and in detail, in Korean."
    )
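# Note: the prompts below use a plain "Human:/Assistant:" format. If the
# tokenizer ships a chat template (check tokenizer.chat_template; that it does
# is an assumption, not something this script verifies), the equivalent prompt
# could be built with the standard apply_chat_template API instead; a minimal
# sketch:
#
#   messages = [
#       {"role": "system", "content": get_system_prompt()},
#       {"role": "user", "content": "Hello"},
#   ]
#   prompt = tokenizer.apply_chat_template(
#       messages, tokenize=False, add_generation_prompt=True
#   )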
# --- Warm-up function ---
def warmup_model():
if not load_successful or model is None or tokenizer is None:
print("μ›œμ—… κ±΄λ„ˆλ›°κΈ°: λͺ¨λΈμ΄ μ„±κ³΅μ μœΌλ‘œ λ‘œλ“œλ˜μ§€ μ•Šμ•˜μŠ΅λ‹ˆλ‹€.")
return
print("--- λͺ¨λΈ μ›œμ—… μ‹œμž‘ ---")
try:
start_warmup_time = time.time()
        warmup_message = "Hello"
        # Build the input in the plain "Human:/Assistant:" format this app
        # uses, with the system prompt prepended
        system_prompt = get_system_prompt()
        prompt = f"{system_prompt}\nHuman: {warmup_message}\nAssistant:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        # Handle an empty stop token list gracefully
        gen_kwargs = {
            "max_new_tokens": 10,
            "pad_token_id": tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id,
            "do_sample": False
        }
        if stop_token_ids_list:
            gen_kwargs["eos_token_id"] = stop_token_ids_list
        else:
            print("Warm-up warning: no stop tokens defined for generation.")
with torch.no_grad():
output_ids = model.generate(**inputs, **gen_kwargs)
del inputs
del output_ids
gc.collect()
warmup_time = time.time() - start_warmup_time
print(f"--- λͺ¨λΈ μ›œμ—… μ™„λ£Œ: {warmup_time:.2f}초 μ†Œμš” ---")
except Exception as e:
print(f"!!! λͺ¨λΈ μ›œμ—… 쀑 였λ₯˜ λ°œμƒ: {e}")
finally:
gc.collect()
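# Note: torch.inference_mode() is a slightly stronger alternative to
# torch.no_grad() for pure inference (it also disables view/version tracking);
# the generate calls in this file could use it interchangeably:
#
#   with torch.inference_mode():
#       output_ids = model.generate(**inputs, **gen_kwargs)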
# --- Inference function ---
@spaces.GPU()
def predict(message, history):
"""
HyperCLOVAX-SEED-Vision-Instruct-3B λͺ¨λΈμ„ μ‚¬μš©ν•˜μ—¬ 응닡을 μƒμ„±ν•©λ‹ˆλ‹€.
'history'λŠ” Gradio 'messages' ν˜•μ‹μ„ κ°€μ •ν•©λ‹ˆλ‹€: List[Dict].
"""
if model is None or tokenizer is None:
return "였λ₯˜: λͺ¨λΈμ΄ λ‘œλ“œλ˜μ§€ μ•Šμ•˜μŠ΅λ‹ˆλ‹€."
    # Build the conversation history, accepting both tuple/list pairs and
    # 'messages'-style dicts from Gradio
    history_text = ""
    if isinstance(history, list):
        for turn in history:
            if isinstance(turn, (list, tuple)) and len(turn) == 2:
                history_text += f"Human: {turn[0]}\nAssistant: {turn[1]}\n"
            elif isinstance(turn, dict) and "role" in turn and "content" in turn:
                role = "Human" if turn["role"] == "user" else "Assistant"
                history_text += f"{role}: {turn['content']}\n"
    # Compose the prompt in the plain "Human:/Assistant:" format, with the
    # system prompt prepended
    prompt = f"{get_system_prompt()}\n{history_text}Human: {message}\nAssistant:"
    inputs = None
    output_ids = None
    try:
        # Prepare model inputs
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        input_length = inputs.input_ids.shape[1]
        print(f"\nInput tokens: {input_length}")
    except Exception as e:
        print(f"!!! Error while processing input: {e}")
        return f"Error: there was a problem processing the input. ({e})"
    try:
        print("Generating response...")
        generation_start_time = time.time()
        # Prepare generation kwargs; handle an empty stop_token_ids_list
        gen_kwargs = {
            "max_new_tokens": MAX_NEW_TOKENS,
            "pad_token_id": tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id,
            "do_sample": True,
            "temperature": 0.7,
            "top_p": 0.9,
            "repetition_penalty": 1.1
        }
        if stop_token_ids_list:
            gen_kwargs["eos_token_id"] = stop_token_ids_list
        else:
            print("Generation warning: no stop tokens defined.")
        with torch.no_grad():
            output_ids = model.generate(**inputs, **gen_kwargs)
        generation_time = time.time() - generation_start_time
        print(f"Generation finished in {generation_time:.2f}s.")
    except Exception as e:
        print(f"!!! Error during model generation: {e}")
        if inputs is not None: del inputs
        if output_ids is not None: del output_ids
        gc.collect()
        return f"Error: there was a problem generating the response. ({e})"
    # Decode the response
    response = "Error: failed to generate a response."
    if output_ids is not None:
        try:
            new_tokens = output_ids[0, input_length:]
            response = tokenizer.decode(new_tokens, skip_special_tokens=True)
            print(f"Output tokens: {len(new_tokens)}")
            del new_tokens
        except Exception as e:
            print(f"!!! Error while decoding the response: {e}")
            response = "Error: there was a problem decoding the response."
    # Clean up memory
    if inputs is not None: del inputs
    if output_ids is not None: del output_ids
    gc.collect()
    print("Memory cleanup done.")
    return response.strip()
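# Quick local smoke test (commented out; predict() is normally invoked by
# Gradio, and the @spaces.GPU decorator assumes a Spaces runtime):
#
#   print(predict("Hello! Please introduce yourself.", []))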
# --- Gradio interface setup ---
print("--- Setting up Gradio interface ---")
examples = [
    ["Hello! Please introduce yourself."],
    ["What is the difference between artificial intelligence and machine learning?"],
    ["Explain, step by step, how a deep learning model is trained."],
    ["I'm planning a trip to Jeju Island. Can you suggest a 3-night, 4-day itinerary?"],
]
# ChatInterface manages its own Chatbot component
demo = gr.ChatInterface(
fn=predict,
title="πŸ€– HyperCLOVAX-SEED-Text-Instruct-0.5B",
description=(
f"**λͺ¨λΈ:** {MODEL_ID}\n"
),
examples=examples,
cache_examples=False,
theme=gr.themes.Soft(),
)
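# Note: gr.ChatInterface also accepts a generator fn for token streaming. A
# minimal sketch of a streaming variant of predict() using transformers'
# TextIteratorStreamer (predict_stream and its internals are illustrative,
# not part of this app):
#
#   from threading import Thread
#   from transformers import TextIteratorStreamer
#
#   def predict_stream(message, history):
#       prompt = f"{get_system_prompt()}\nHuman: {message}\nAssistant:"
#       inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
#       streamer = TextIteratorStreamer(
#           tokenizer, skip_prompt=True, skip_special_tokens=True
#       )
#       thread = Thread(
#           target=model.generate,
#           kwargs=dict(**inputs, streamer=streamer, max_new_tokens=MAX_NEW_TOKENS),
#       )
#       thread.start()
#       partial = ""
#       for chunk in streamer:   # yields decoded text as tokens arrive
#           partial += chunk
#           yield partial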
# --- μ• ν”Œλ¦¬μΌ€μ΄μ…˜ μ‹€ν–‰ ---
if __name__ == "__main__":
if load_successful:
warmup_model()
else:
print("λͺ¨λΈ λ‘œλ”©μ— μ‹€νŒ¨ν•˜μ—¬ μ›œμ—…μ„ κ±΄λ„ˆλœλ‹ˆλ‹€.")
print("--- Gradio μ•± μ‹€ν–‰ 쀑 ---")
    demo.queue().launch(
        # share=True  # uncomment for a public link
        # server_name="0.0.0.0"  # uncomment for local network access
    )