import os
import threading
import torch
import torch._dynamo
torch._dynamo.config.suppress_errors = True
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer,
)
import gradio as gr
import spaces
# 필요한 경우 Bitnet 지원을 위한 transformers 설치
# Hugging Face Spaces에서는 Dockerfile 등을 통해 미리 설치하는 것이 더 일반적입니다.
# 로컬에서 테스트 시에는 필요할 수 있습니다.
# print("Installing required transformers branch...")
# try:
# os.system("pip install git+https://github.com/shumingma/transformers.git -q")
# print("transformers branch installed.")
# except Exception as e:
# print(f"Error installing transformers branch: {e}")
# print("Proceeding with potentially default transformers version.")
# os.system("pip install accelerate bitsandbytes -q") # bitsandbytes, accelerate도 필요할 수 있습니다.
model_id = "microsoft/bitnet-b1.58-2B-4T"
# 모델 및 토크나이저 로드
print(f"Loading model: {model_id}")
try:
tokenizer = AutoTokenizer.from_pretrained(model_id)
# device_map="auto"는 여러 GPU 또는 CPU로 모델을 자동으로 분산 로드합니다.
# bfloat16은 모델 가중치에 사용되는 데이터 타입입니다.
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
device_map="auto",
# load_in_8bit=True # Bitnet은 1.58bit이므로 8bit 로딩이 의미 없을 수 있습니다.
)
print(f"Model loaded successfully on device: {model.device}")
except Exception as e:
print(f"Error loading model: {e}")
# 모델 로딩 실패 시 더미 모델 사용 또는 오류 처리
class DummyModel:
def generate(self, **kwargs):
# 더미 응답 생성
input_ids = kwargs.get('input_ids')
streamer = kwargs.get('streamer')
if streamer:
# 간단한 더미 응답 스트리밍
dummy_response = "모델 로딩에 실패하여 더미 응답을 제공합니다. 설정/경로를 확인하세요."
for char in dummy_response:
streamer.put(char)
streamer.end()
model = DummyModel()
tokenizer = AutoTokenizer.from_pretrained("gpt2") # 더미 토크나이저
print("Using dummy model due to loading failure.")
@spaces.GPU # Hugging Face Spaces에서 GPU 사용을 지정합니다.
def respond(
message: str,
history: list[tuple[str, str]],
system_message: str,
max_tokens: int,
temperature: float,
top_p: float,
):
"""
Generate a chat response using streaming with TextIteratorStreamer.
Args:
message: User's current message.
history: List of (user, assistant) tuples from previous turns.
system_message: Initial system prompt guiding the assistant.
max_tokens: Maximum number of tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling probability.
Yields:
The growing response text as new tokens are generated.
"""
# 더미 모델 사용 시 스트리밍 오류 방지
if isinstance(model, DummyModel):
yield "모델 로딩에 실패하여 응답을 생성할 수 없습니다."
return
messages = [{"role": "system", "content": system_message}]
for user_msg, bot_msg in history:
if user_msg:
messages.append({"role": "user", "content": user_msg})
if bot_msg:
messages.append({"role": "assistant", "content": bot_msg})
messages.append({"role": "user", "content": message})
try:
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
streamer = TextIteratorStreamer(
tokenizer, skip_prompt=True, skip_special_tokens=True
)
generate_kwargs = dict(
**inputs,
streamer=streamer,
max_new_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
do_sample=True,
# Bitnet 모델에 필요한 추가 인자 설정 (모델 문서 확인 필요)
# 예를 들어, quantize_config 등
)
# 쓰레드에서 모델 생성 실행
thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()
# 스트리머로부터 텍스트를 읽어와 yield
response = ""
for new_text in streamer:
# yield 하기 전에 불필요한 공백/토큰 제거 또는 처리 가능
response += new_text
yield response
except Exception as e:
print(f"Error during response generation: {e}")
yield f"응답 생성 중 오류가 발생했습니다: {e}"
# --- 디자인 개선을 위한 CSS 코드 ---
css_styles = """
/* 전체 페이지 배경 및 기본 폰트 설정 */
body {
font-family: 'Segoe UI', 'Roboto', 'Arial', sans-serif;
line-height: 1.6;
margin: 0;
padding: 20px; /* 앱 주변 여백 추가 */
background-color: #f4f7f6; /* 부드러운 배경색 */
}
/* 메인 앱 컨테이너 스타일 */
.gradio-container {
max-width: 900px; /* 중앙 정렬 및 최대 너비 제한 */
margin: 20px auto;
border-radius: 12px; /* 둥근 모서리 */
overflow: hidden; /* 자식 요소들이 모서리를 넘지 않도록 */
box-shadow: 0 8px 16px rgba(0, 0, 0, 0.1); /* 그림자 효과 */
background-color: #ffffff; /* 앱 내용 영역 배경색 */
}
/* 타이틀 및 설명 영역 (ChatInterface의 기본 타이틀/설명) */
/* 이 영역은 ChatInterface 구조에 따라 정확한 클래스 이름이 다를 수 있으나,
.gradio-container 내부의 첫 블록이나 H1/P 태그를 타겟할 수 있습니다.
테마와 함께 사용하면 대부분 잘 처리됩니다. 여기서는 추가적인 패딩 등만 고려 */
.gradio-container > .gradio-block:first-child {
padding: 20px 20px 10px 20px; /* 상단 패딩 조정 */
}
/* 채팅 박스 영역 스타일 */
.gradio-chatbox {
/* 테마에 의해 스타일링되지만, 추가적인 내부 패딩 등 조정 가능 */
padding: 15px;
background-color: #fefefe; /* 채팅 영역 배경색 */
border-radius: 8px; /* 채팅 영역 내부 모서리 */
border: 1px solid #e0e0e0; /* 경계선 */
}
/* 채팅 메시지 스타일 */
.gradio-chatmessage {
margin-bottom: 12px;
padding: 10px 15px;
border-radius: 20px; /* 둥근 메시지 모서리 */
max-width: 75%; /* 메시지 너비 제한 */
word-wrap: break-word; /* 긴 단어 줄바꿈 */
white-space: pre-wrap; /* 공백 및 줄바꿈 유지 */
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05); /* 메시지에 약간의 그림자 */
}
/* 사용자 메시지 스타일 */
.gradio-chatmessage.user {
background-color: #007bff; /* 파란색 계열 */
color: white;
margin-left: auto; /* 오른쪽 정렬 */
border-bottom-right-radius: 2px; /* 오른쪽 아래 모서리 각지게 */
}
/* 봇 메시지 스타일 */
.gradio-chatmessage.bot {
background-color: #e9ecef; /* 밝은 회색 */
color: #333; /* 어두운 텍스트 */
margin-right: auto; /* 왼쪽 정렬 */
border-bottom-left-radius: 2px; /* 왼쪽 아래 모서리 각지게 */
}
/* 입력창 및 버튼 영역 스타일 */
.gradio-input-box {
padding: 15px;
border-top: 1px solid #eee; /* 위쪽 경계선 */
background-color: #f8f9fa; /* 입력 영역 배경색 */
}
/* 입력 텍스트 에어리어 스타일 */
.gradio-input-box textarea {
border-radius: 8px;
padding: 10px;
border: 1px solid #ccc;
resize: none !important; /* 입력창 크기 조절 비활성화 (선택 사항) */
min-height: 50px; /* 최소 높이 */
max-height: 150px; /* 최대 높이 */
overflow-y: auto; /* 내용 넘칠 경우 스크롤 */
}
/* 스크롤바 스타일 (선택 사항) */
.gradio-input-box textarea::-webkit-scrollbar {
width: 8px;
}
.gradio-input-box textarea::-webkit-scrollbar-thumb {
background-color: #ccc;
border-radius: 4px;
}
.gradio-input-box textarea::-webkit-scrollbar-track {
background-color: #f1f1f1;
}
/* 버튼 스타일 */
.gradio-button {
border-radius: 8px;
padding: 10px 20px;
font-weight: bold;
transition: background-color 0.2s ease, opacity 0.2s ease; /* 호버 애니메이션 */
border: none; /* 기본 테두리 제거 */
cursor: pointer;
}
.gradio-button:not(.clear-button) { /* Send 버튼 */
background-color: #28a745; /* 초록색 */
color: white;
}
.gradio-button:not(.clear-button):hover {
background-color: #218838;
}
.gradio-button:disabled { /* 비활성화된 버튼 */
opacity: 0.6;
cursor: not-allowed;
}
.gradio-button.clear-button { /* Clear 버튼 */
background-color: #dc3545; /* 빨간색 */
color: white;
}
.gradio-button.clear-button:hover {
background-color: #c82333;
}
/* Additional inputs (추가 설정) 영역 스타일 */
/* 이 영역은 보통 아코디언 형태로 되어 있으며, .gradio-accordion 클래스를 가집니다. */
.gradio-accordion {
border-radius: 12px; /* 외부 컨테이너와 동일한 모서리 */
margin-top: 15px; /* 채팅 영역과의 간격 */
border: 1px solid #ddd; /* 경계선 */
box-shadow: none; /* 내부 그림자 제거 */
}
/* 아코디언 헤더 (라벨) 스타일 */
.gradio-accordion .label {
font-weight: bold;
color: #007bff; /* 파란색 계열 */
padding: 15px; /* 헤더 패딩 */
background-color: #e9ecef; /* 헤더 배경색 */
border-bottom: 1px solid #ddd; /* 헤더 아래 경계선 */
border-top-left-radius: 11px; /* 상단 모서리 */
border-top-right-radius: 11px;
}
/* 아코디언 내용 영역 스타일 */
.gradio-accordion .wrap {
padding: 15px; /* 내용 패딩 */
background-color: #fefefe; /* 내용 배경색 */
border-bottom-left-radius: 11px; /* 하단 모서리 */
border-bottom-right-radius: 11px;
}
/* 추가 설정 내 개별 입력 컴포넌트 스타일 (슬라이더, 텍스트박스 등) */
.gradio-slider, .gradio-textbox, .gradio-number {
margin-bottom: 10px; /* 각 입력 요소 아래 간격 */
padding: 8px; /* 내부 패딩 */
border: 1px solid #e0e0e0; /* 경계선 */
border-radius: 8px; /* 둥근 모서리 */
background-color: #fff; /* 배경색 */
}
/* 입력 필드 라벨 스타일 */
.gradio-label {
font-weight: normal; /* 라벨 폰트 굵기 */
margin-bottom: 5px; /* 라벨과 입력 필드 간 간격 */
color: #555; /* 라벨 색상 */
display: block; /* 라벨을 블록 요소로 만들어 위로 올림 */
}
/* 슬라이더 트랙 및 핸들 스타일 (더 세밀한 조정 가능) */
/* 예: .gradio-slider input[type="range"]::-webkit-slider-thumb {} */
/* 마크다운/HTML 컴포넌트 내 스타일 */
.gradio-markdown, .gradio-html {
padding: 10px 0; /* 상하 패딩 */
}
"""
# --- 디자인 개선을 위한 CSS 코드 끝 ---
# Gradio 인터페이스 설정
demo = gr.ChatInterface(
fn=respond,
# 타이틀 및 설명에 HTML 태그 사용 예시 (
태그 사용)
title="
This chat application is powered by Microsoft's SOTA Bitnet-b1.58-2B-4T and designed for natural and fast conversations.
", examples=[ [ "Hello! How are you?", "You are a helpful AI assistant for everyday tasks.", 512, 0.7, 0.95, ], [ "Can you code a snake game in Python?", "You are a helpful AI assistant for coding.", 2048, 0.7, 0.95, ], ], additional_inputs=[ gr.Textbox( value="You are a helpful AI assistant.", label="System message", lines=3 # 시스템 메시지 입력창 높이 조절 ), gr.Slider( minimum=1, maximum=8192, value=2048, step=1, label="Max new tokens" ), gr.Slider( minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature" ), gr.Slider( minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)" ), ], # 테마 적용 (여러 테마 중 선택 가능: gr.themes.Soft(), gr.themes.Glass(), gr.themes.Default(), etc.) theme=gr.themes.Soft(), # 커스텀 CSS 적용 css=css_styles, ) # 애플리케이션 실행 if __name__ == "__main__": # launch(share=True)는 퍼블릭 URL 생성 (디버깅/공유 목적, 주의 필요) demo.launch() # demo.launch(debug=True) # 디버깅 모드 활성화