import os import threading import torch import torch._dynamo torch._dynamo.config.suppress_errors = True from transformers import ( AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, ) import gradio as gr import spaces # 필요한 경우 Bitnet 지원을 위한 transformers 설치 # Hugging Face Spaces에서는 Dockerfile 등을 통해 미리 설치하는 것이 더 일반적입니다. # 로컬에서 테스트 시에는 필요할 수 있습니다. # print("Installing required transformers branch...") # try: # os.system("pip install git+https://github.com/shumingma/transformers.git -q") # print("transformers branch installed.") # except Exception as e: # print(f"Error installing transformers branch: {e}") # print("Proceeding with potentially default transformers version.") # os.system("pip install accelerate bitsandbytes -q") # bitsandbytes, accelerate도 필요할 수 있습니다. model_id = "microsoft/bitnet-b1.58-2B-4T" # 모델 및 토크나이저 로드 print(f"Loading model: {model_id}") try: tokenizer = AutoTokenizer.from_pretrained(model_id) # device_map="auto"는 여러 GPU 또는 CPU로 모델을 자동으로 분산 로드합니다. # bfloat16은 모델 가중치에 사용되는 데이터 타입입니다. model = AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=torch.bfloat16, device_map="auto", # load_in_8bit=True # Bitnet은 1.58bit이므로 8bit 로딩이 의미 없을 수 있습니다. ) print(f"Model loaded successfully on device: {model.device}") except Exception as e: print(f"Error loading model: {e}") # 모델 로딩 실패 시 더미 모델 사용 또는 오류 처리 class DummyModel: def generate(self, **kwargs): # 더미 응답 생성 input_ids = kwargs.get('input_ids') streamer = kwargs.get('streamer') if streamer: # 간단한 더미 응답 스트리밍 dummy_response = "모델 로딩에 실패하여 더미 응답을 제공합니다. 설정/경로를 확인하세요." for char in dummy_response: streamer.put(char) streamer.end() model = DummyModel() tokenizer = AutoTokenizer.from_pretrained("gpt2") # 더미 토크나이저 print("Using dummy model due to loading failure.") @spaces.GPU # Hugging Face Spaces에서 GPU 사용을 지정합니다. def respond( message: str, history: list[tuple[str, str]], system_message: str, max_tokens: int, temperature: float, top_p: float, ): """ Generate a chat response using streaming with TextIteratorStreamer. Args: message: User's current message. history: List of (user, assistant) tuples from previous turns. system_message: Initial system prompt guiding the assistant. max_tokens: Maximum number of tokens to generate. temperature: Sampling temperature. top_p: Nucleus sampling probability. Yields: The growing response text as new tokens are generated. """ # 더미 모델 사용 시 스트리밍 오류 방지 if isinstance(model, DummyModel): yield "모델 로딩에 실패하여 응답을 생성할 수 없습니다." return messages = [{"role": "system", "content": system_message}] for user_msg, bot_msg in history: if user_msg: messages.append({"role": "user", "content": user_msg}) if bot_msg: messages.append({"role": "assistant", "content": bot_msg}) messages.append({"role": "user", "content": message}) try: prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) inputs = tokenizer(prompt, return_tensors="pt").to(model.device) streamer = TextIteratorStreamer( tokenizer, skip_prompt=True, skip_special_tokens=True ) generate_kwargs = dict( **inputs, streamer=streamer, max_new_tokens=max_tokens, temperature=temperature, top_p=top_p, do_sample=True, # Bitnet 모델에 필요한 추가 인자 설정 (모델 문서 확인 필요) # 예를 들어, quantize_config 등 ) # 쓰레드에서 모델 생성 실행 thread = threading.Thread(target=model.generate, kwargs=generate_kwargs) thread.start() # 스트리머로부터 텍스트를 읽어와 yield response = "" for new_text in streamer: # yield 하기 전에 불필요한 공백/토큰 제거 또는 처리 가능 response += new_text yield response except Exception as e: print(f"Error during response generation: {e}") yield f"응답 생성 중 오류가 발생했습니다: {e}" # --- 디자인 개선을 위한 CSS 코드 --- css_styles = """ /* 전체 페이지 배경 및 기본 폰트 설정 */ body { font-family: 'Segoe UI', 'Roboto', 'Arial', sans-serif; line-height: 1.6; margin: 0; padding: 20px; /* 앱 주변 여백 추가 */ background-color: #f4f7f6; /* 부드러운 배경색 */ } /* 메인 앱 컨테이너 스타일 */ .gradio-container { max-width: 900px; /* 중앙 정렬 및 최대 너비 제한 */ margin: 20px auto; border-radius: 12px; /* 둥근 모서리 */ overflow: hidden; /* 자식 요소들이 모서리를 넘지 않도록 */ box-shadow: 0 8px 16px rgba(0, 0, 0, 0.1); /* 그림자 효과 */ background-color: #ffffff; /* 앱 내용 영역 배경색 */ } /* 타이틀 및 설명 영역 (ChatInterface의 기본 타이틀/설명) */ /* 이 영역은 ChatInterface 구조에 따라 정확한 클래스 이름이 다를 수 있으나, .gradio-container 내부의 첫 블록이나 H1/P 태그를 타겟할 수 있습니다. 테마와 함께 사용하면 대부분 잘 처리됩니다. 여기서는 추가적인 패딩 등만 고려 */ .gradio-container > .gradio-block:first-child { padding: 20px 20px 10px 20px; /* 상단 패딩 조정 */ } /* 채팅 박스 영역 스타일 */ .gradio-chatbox { /* 테마에 의해 스타일링되지만, 추가적인 내부 패딩 등 조정 가능 */ padding: 15px; background-color: #fefefe; /* 채팅 영역 배경색 */ border-radius: 8px; /* 채팅 영역 내부 모서리 */ border: 1px solid #e0e0e0; /* 경계선 */ } /* 채팅 메시지 스타일 */ .gradio-chatmessage { margin-bottom: 12px; padding: 10px 15px; border-radius: 20px; /* 둥근 메시지 모서리 */ max-width: 75%; /* 메시지 너비 제한 */ word-wrap: break-word; /* 긴 단어 줄바꿈 */ white-space: pre-wrap; /* 공백 및 줄바꿈 유지 */ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05); /* 메시지에 약간의 그림자 */ } /* 사용자 메시지 스타일 */ .gradio-chatmessage.user { background-color: #007bff; /* 파란색 계열 */ color: white; margin-left: auto; /* 오른쪽 정렬 */ border-bottom-right-radius: 2px; /* 오른쪽 아래 모서리 각지게 */ } /* 봇 메시지 스타일 */ .gradio-chatmessage.bot { background-color: #e9ecef; /* 밝은 회색 */ color: #333; /* 어두운 텍스트 */ margin-right: auto; /* 왼쪽 정렬 */ border-bottom-left-radius: 2px; /* 왼쪽 아래 모서리 각지게 */ } /* 입력창 및 버튼 영역 스타일 */ .gradio-input-box { padding: 15px; border-top: 1px solid #eee; /* 위쪽 경계선 */ background-color: #f8f9fa; /* 입력 영역 배경색 */ } /* 입력 텍스트 에어리어 스타일 */ .gradio-input-box textarea { border-radius: 8px; padding: 10px; border: 1px solid #ccc; resize: none !important; /* 입력창 크기 조절 비활성화 (선택 사항) */ min-height: 50px; /* 최소 높이 */ max-height: 150px; /* 최대 높이 */ overflow-y: auto; /* 내용 넘칠 경우 스크롤 */ } /* 스크롤바 스타일 (선택 사항) */ .gradio-input-box textarea::-webkit-scrollbar { width: 8px; } .gradio-input-box textarea::-webkit-scrollbar-thumb { background-color: #ccc; border-radius: 4px; } .gradio-input-box textarea::-webkit-scrollbar-track { background-color: #f1f1f1; } /* 버튼 스타일 */ .gradio-button { border-radius: 8px; padding: 10px 20px; font-weight: bold; transition: background-color 0.2s ease, opacity 0.2s ease; /* 호버 애니메이션 */ border: none; /* 기본 테두리 제거 */ cursor: pointer; } .gradio-button:not(.clear-button) { /* Send 버튼 */ background-color: #28a745; /* 초록색 */ color: white; } .gradio-button:not(.clear-button):hover { background-color: #218838; } .gradio-button:disabled { /* 비활성화된 버튼 */ opacity: 0.6; cursor: not-allowed; } .gradio-button.clear-button { /* Clear 버튼 */ background-color: #dc3545; /* 빨간색 */ color: white; } .gradio-button.clear-button:hover { background-color: #c82333; } /* Additional inputs (추가 설정) 영역 스타일 */ /* 이 영역은 보통 아코디언 형태로 되어 있으며, .gradio-accordion 클래스를 가집니다. */ .gradio-accordion { border-radius: 12px; /* 외부 컨테이너와 동일한 모서리 */ margin-top: 15px; /* 채팅 영역과의 간격 */ border: 1px solid #ddd; /* 경계선 */ box-shadow: none; /* 내부 그림자 제거 */ } /* 아코디언 헤더 (라벨) 스타일 */ .gradio-accordion .label { font-weight: bold; color: #007bff; /* 파란색 계열 */ padding: 15px; /* 헤더 패딩 */ background-color: #e9ecef; /* 헤더 배경색 */ border-bottom: 1px solid #ddd; /* 헤더 아래 경계선 */ border-top-left-radius: 11px; /* 상단 모서리 */ border-top-right-radius: 11px; } /* 아코디언 내용 영역 스타일 */ .gradio-accordion .wrap { padding: 15px; /* 내용 패딩 */ background-color: #fefefe; /* 내용 배경색 */ border-bottom-left-radius: 11px; /* 하단 모서리 */ border-bottom-right-radius: 11px; } /* 추가 설정 내 개별 입력 컴포넌트 스타일 (슬라이더, 텍스트박스 등) */ .gradio-slider, .gradio-textbox, .gradio-number { margin-bottom: 10px; /* 각 입력 요소 아래 간격 */ padding: 8px; /* 내부 패딩 */ border: 1px solid #e0e0e0; /* 경계선 */ border-radius: 8px; /* 둥근 모서리 */ background-color: #fff; /* 배경색 */ } /* 입력 필드 라벨 스타일 */ .gradio-label { font-weight: normal; /* 라벨 폰트 굵기 */ margin-bottom: 5px; /* 라벨과 입력 필드 간 간격 */ color: #555; /* 라벨 색상 */ display: block; /* 라벨을 블록 요소로 만들어 위로 올림 */ } /* 슬라이더 트랙 및 핸들 스타일 (더 세밀한 조정 가능) */ /* 예: .gradio-slider input[type="range"]::-webkit-slider-thumb {} */ /* 마크다운/HTML 컴포넌트 내 스타일 */ .gradio-markdown, .gradio-html { padding: 10px 0; /* 상하 패딩 */ } """ # --- 디자인 개선을 위한 CSS 코드 끝 --- # Gradio 인터페이스 설정 demo = gr.ChatInterface( fn=respond, # 타이틀 및 설명에 HTML 태그 사용 예시 (
태그 사용) title="

Bitnet-b1.58-2B-4T Chatbot

", description="

This chat application is powered by Microsoft's SOTA Bitnet-b1.58-2B-4T and designed for natural and fast conversations.

", examples=[ [ "Hello! How are you?", "You are a helpful AI assistant for everyday tasks.", 512, 0.7, 0.95, ], [ "Can you code a snake game in Python?", "You are a helpful AI assistant for coding.", 2048, 0.7, 0.95, ], ], additional_inputs=[ gr.Textbox( value="You are a helpful AI assistant.", label="System message", lines=3 # 시스템 메시지 입력창 높이 조절 ), gr.Slider( minimum=1, maximum=8192, value=2048, step=1, label="Max new tokens" ), gr.Slider( minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature" ), gr.Slider( minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)" ), ], # 테마 적용 (여러 테마 중 선택 가능: gr.themes.Soft(), gr.themes.Glass(), gr.themes.Default(), etc.) theme=gr.themes.Soft(), # 커스텀 CSS 적용 css=css_styles, ) # 애플리케이션 실행 if __name__ == "__main__": # launch(share=True)는 퍼블릭 URL 생성 (디버깅/공유 목적, 주의 필요) demo.launch() # demo.launch(debug=True) # 디버깅 모드 활성화