# RAGOndevice / app.py
# cutechicken's picture
# Update app.py
# ba10abd verified
# raw
# history blame
# 14.3 kB
import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
from threading import Thread
import random
from datasets import load_dataset
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
import pandas as pd
from typing import List, Tuple
import json
from datetime import datetime
# GPU memory management: release cached, currently-unused CUDA memory.
torch.cuda.empty_cache()
# Environment variables / configuration.
HF_TOKEN = os.environ.get("HF_TOKEN", None)  # NOTE(review): not referenced again in the code visible here
MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
MODELS = os.environ.get("MODELS")  # NOTE(review): not referenced again in the code visible here
MODEL_NAME = MODEL_ID.split("/")[-1]  # last path component of MODEL_ID
# Load the model and the tokenizer (runs at import time).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Load the Wikipedia Q&A dataset.
wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
print("Wikipedia dataset loaded:", wiki_dataset)
# Initialise and fit the TF-IDF vectorizer on the question texts.
print("TF-IDF 벑터화 μ‹œμž‘...")
questions = wiki_dataset['train']['question'][:10000]  # use only the first 10000 questions
vectorizer = TfidfVectorizer(max_features=1000)
# Sparse TF-IDF matrix with one row per entry of `questions`.
question_vectors = vectorizer.fit_transform(questions)
print("TF-IDF 벑터화 μ™„λ£Œ")
class ChatHistory:
    """In-memory log of user/assistant exchanges mirrored to a JSON file.

    Each entry of ``self.history`` is a dict with an ISO-format
    ``"timestamp"`` and a two-element ``"messages"`` list (user turn first,
    assistant turn second).  Every mutation rewrites ``self.history_file``.
    """

    def __init__(self):
        self.history = []
        self.history_file = "/tmp/chat_history.json"
        # Pick up whatever an earlier run left on disk, if anything.
        self.load_history()

    def add_conversation(self, user_msg: str, assistant_msg: str):
        """Append one user/assistant exchange and persist the whole log."""
        user_turn = {"role": "user", "content": user_msg}
        assistant_turn = {"role": "assistant", "content": assistant_msg}
        self.history.append({
            "timestamp": datetime.now().isoformat(),
            "messages": [user_turn, assistant_turn],
        })
        self.save_history()

    def format_for_display(self):
        """Return the log as a list of ``[user_text, assistant_text]`` pairs."""
        return [
            [entry["messages"][0]["content"], entry["messages"][1]["content"]]
            for entry in self.history
        ]

    def get_messages_for_api(self):
        """Return the log flattened into alternating role/content dicts."""
        flattened = []
        for entry in self.history:
            user_turn, assistant_turn = entry["messages"][0], entry["messages"][1]
            flattened.append({"role": "user", "content": user_turn["content"]})
            flattened.append({"role": "assistant", "content": assistant_turn["content"]})
        return flattened

    def clear_history(self):
        """Drop every stored exchange and persist the now-empty log."""
        self.history = []
        self.save_history()

    def save_history(self):
        """Write the log to ``self.history_file``; failures are only printed."""
        try:
            with open(self.history_file, 'w', encoding='utf-8') as sink:
                json.dump(self.history, sink, ensure_ascii=False, indent=2)
        except Exception as err:
            print(f"νžˆμŠ€ν† λ¦¬ μ €μž₯ μ‹€νŒ¨: {err}")

    def load_history(self):
        """Replace the log with the file's contents when the file exists.

        Any failure is printed and leaves an empty log behind.
        """
        try:
            if not os.path.exists(self.history_file):
                return
            with open(self.history_file, 'r', encoding='utf-8') as source:
                self.history = json.load(source)
        except Exception as err:
            print(f"νžˆμŠ€ν† λ¦¬ λ‘œλ“œ μ‹€νŒ¨: {err}")
            self.history = []
# Create the module-level (global) ChatHistory instance; constructing it
# immediately tries to load any previously saved history file.
# NOTE(review): this instance is not referenced by the other code visible here.
chat_history = ChatHistory()
def find_relevant_context(query, top_k=3):
    """Return up to *top_k* indexed Q&A pairs most similar to *query*.

    The query is projected with the module-level TF-IDF ``vectorizer`` and
    scored against ``question_vectors`` by dot product (equal to cosine
    similarity as long as the vectorizer L2-normalises its rows, which is
    presumably the case with its default settings).

    Args:
        query: Free-text user question.
        top_k: Maximum number of matches to return; values <= 0 yield ``[]``.

    Returns:
        A list of dicts with keys ``'question'``, ``'answer'`` and
        ``'similarity'``, ordered from most to least similar.  Candidates
        whose similarity is not strictly positive are dropped, so the list
        may be shorter than *top_k* (or empty).
    """
    # A non-positive top_k would turn the slice below into "everything"
    # (-0) or an arbitrary tail (negative values).
    if top_k <= 0:
        return []

    query_vector = vectorizer.transform([query])
    # `@` is matrix multiplication for both scipy sparse matrices and sparse
    # arrays, whereas `*` is element-wise for the latter.
    similarities = (query_vector @ question_vectors.T).toarray()[0]
    # Indices of the top_k highest scores, best first.
    top_indices = np.argsort(similarities)[-top_k:][::-1]

    train_split = wiki_dataset['train']
    relevant_contexts = []
    for idx in top_indices:
        score = similarities[idx]
        if score > 0:
            row = int(idx)
            relevant_contexts.append({
                'question': questions[row],
                # Row-first access fetches one record instead of
                # materialising the whole 'answer' column for every hit.
                'answer': train_split[row]['answer'],
                'similarity': score,
            })
    return relevant_contexts
def analyze_file_content(content, file_type):
    """Analyze file content and return a one-line structural summary.

    Args:
        content: Text to analyze.  For ``'parquet'``/``'csv'`` it is expected
            to contain a pipe-delimited (markdown) table, optionally
            surrounded by other lines such as a title or trailing statistics.
        file_type: ``'parquet'`` or ``'csv'`` selects the table summary; any
            other value selects the code/document heuristics.

    Returns:
        A summary string: column/row counts of the table shown in *content*,
        line/function/class/import counts for code-like text, or
        line/paragraph/word counts for plain text.  For dataset types a
        failure message is returned when no table can be found.
    """
    if file_type in ('parquet', 'csv'):
        if not isinstance(content, str):
            return "❌ 데이터셋 ꡬ쑰 뢄석 μ‹€νŒ¨"
        # Only lines that belong to the table count; a preview title above it
        # or statistics below it must not be mistaken for the header/rows.
        table_lines = [
            line for line in content.split('\n')
            if line.lstrip().startswith('|')
        ]
        if not table_lines:
            return "❌ 데이터셋 ꡬ쑰 뢄석 μ‹€νŒ¨"
        # "| a | b |" has one more pipe than it has columns.
        columns = table_lines[0].count('|') - 1
        # Everything after the header and the separator row is data.
        rows = max(len(table_lines) - 2, 0)
        return f"πŸ“Š 데이터셋 ꡬ쑰: {columns}개 컬럼, {rows}개 데이터"

    lines = content.split('\n')
    total_lines = len(lines)

    lowered = content.lower()
    if any(keyword in lowered for keyword in ('def ', 'class ', 'import ', 'function')):
        # Substring heuristics, not a parser: every line containing the
        # keyword is counted.
        functions = sum('def ' in line for line in lines)
        classes = sum('class ' in line for line in lines)
        imports = sum('import ' in line or 'from ' in line for line in lines)
        return f"πŸ’» μ½”λ“œ ꡬ쑰: {total_lines}쀄 (ν•¨μˆ˜: {functions}, 클래슀: {classes}, μž„ν¬νŠΈ: {imports})"

    paragraphs = content.count('\n\n') + 1
    words = len(content.split())
    return f"πŸ“ λ¬Έμ„œ ꡬ쑰: {total_lines}쀄, {paragraphs}단락, μ•½ {words}단어"
def _summarize_csv(df):
    """Render a 10-row preview plus shape, dtype and missing-value details."""
    parts = [
        f"πŸ“Š 데이터 미리보기:\n{df.head(10).to_markdown(index=False)}\n\n",
        "\nπŸ“ˆ 데이터 정보:\n",
        f"- 전체 ν–‰ 수: {len(df)}\n",
        f"- 전체 μ—΄ 수: {len(df.columns)}\n",
        # Column labels are not guaranteed to be strings.
        f"- 컬럼 λͺ©λ‘: {', '.join(map(str, df.columns))}\n",
        "\nπŸ“‹ 컬럼 데이터 νƒ€μž…:\n",
    ]
    parts.extend(f"- {col}: {dtype}\n" for col, dtype in df.dtypes.items())
    null_counts = df.isnull().sum()
    if null_counts.any():
        parts.append("\n⚠️ 결츑치:\n")
        parts.extend(
            f"- {col}: {null_count}개 λˆ„λ½\n"
            for col, null_count in null_counts[null_counts > 0].items()
        )
    return "".join(parts)


def read_uploaded_file(file):
    """Read an uploaded file and return ``(content, file_type)``.

    Args:
        file: ``None``, a filesystem path (``str`` / ``os.PathLike``, which is
            what ``gr.File(type="filepath")`` delivers), or an object that
            exposes the path as ``.name``.

    Returns:
        ``("", "")`` when *file* is ``None``; otherwise a tuple of text and a
        type tag:

        * ``"parquet"`` - markdown table of the first 10 rows,
        * ``"csv"`` - preview table followed by shape, dtypes and null counts,
        * ``"text"`` - the decoded file contents,
        * ``"error"`` - the text is an error message; nothing is raised.
    """
    if file is None:
        return "", ""
    try:
        # Accept both a plain path and a tempfile-like wrapper around one.
        path = file if isinstance(file, (str, os.PathLike)) else file.name
        file_ext = os.path.splitext(path)[1].lower()
        # Tried in order; 'latin1' accepts any byte sequence, so the
        # "no encoding worked" branches below are a last-resort safety net.
        encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1']

        if file_ext == '.parquet':
            df = pd.read_parquet(path, engine='pyarrow')
            return df.head(10).to_markdown(index=False), "parquet"

        if file_ext == '.csv':
            for encoding in encodings:
                try:
                    df = pd.read_csv(path, encoding=encoding)
                except UnicodeDecodeError:
                    continue
                return _summarize_csv(df), "csv"
            # UnicodeDecodeError cannot be built from a single message
            # (its constructor needs five arguments), so use ValueError.
            raise ValueError(f"❌ μ§€μ›λ˜λŠ” μΈμ½”λ”©μœΌλ‘œ νŒŒμΌμ„ 읽을 수 μ—†μŠ΅λ‹ˆλ‹€ ({', '.join(encodings)})")

        for encoding in encodings:
            try:
                with open(path, 'r', encoding=encoding) as f:
                    return f.read(), "text"
            except UnicodeDecodeError:
                continue
        raise ValueError(f"❌ μ§€μ›λ˜λŠ” μΈμ½”λ”©μœΌλ‘œ νŒŒμΌμ„ 읽을 수 μ—†μŠ΅λ‹ˆλ‹€ ({', '.join(encodings)})")
    except Exception as e:
        # Best-effort boundary: callers get the failure as text, not a raise.
        return f"❌ 파일 읽기 였λ₯˜: {str(e)}", "error"
# Custom stylesheet handed to gr.Blocks(css=CSS) below.  Everything inside the
# triple-quoted string (including its /* ... */ CSS comments) is runtime data
# and is kept exactly as found.
# NOTE(review): the rules reference `animation: messageIn`, but no
# `@keyframes messageIn` is defined in this string, and the last CSS comment
# is a section header with no rules after it.
CSS = """
/* 전체 νŽ˜μ΄μ§€ μŠ€νƒ€μΌλ§ */
body {
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
min-height: 100vh;
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
/* 메인 μ»¨ν…Œμ΄λ„ˆ */
.container {
max-width: 1200px;
margin: 0 auto;
padding: 2rem;
background: rgba(255, 255, 255, 0.95);
border-radius: 20px;
box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1);
backdrop-filter: blur(10px);
transform: perspective(1000px) translateZ(0);
transition: all 0.3s ease;
}
/* 제λͺ© μŠ€νƒ€μΌλ§ */
h1 {
color: #2d3436;
font-size: 2.5rem;
text-align: center;
margin-bottom: 2rem;
text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.1);
transform: perspective(1000px) translateZ(20px);
}
h3 {
text-align: center;
color: #2d3436;
font-size: 1.5rem;
margin: 1rem 0;
}
/* μ±„νŒ…λ°•μŠ€ μŠ€νƒ€μΌλ§ */
.chatbox {
background: white;
border-radius: 15px;
box-shadow: 0 8px 32px rgba(31, 38, 135, 0.15);
backdrop-filter: blur(4px);
border: 1px solid rgba(255, 255, 255, 0.18);
padding: 1rem;
margin: 1rem 0;
transform: translateZ(0);
transition: all 0.3s ease;
}
/* λ©”μ‹œμ§€ μŠ€νƒ€μΌλ§ */
.chatbox .messages .message.user {
background: linear-gradient(145deg, #e1f5fe, #bbdefb);
border-radius: 15px;
padding: 1rem;
margin: 0.5rem;
box-shadow: 5px 5px 15px rgba(0, 0, 0, 0.05);
transform: translateZ(10px);
animation: messageIn 0.3s ease-out;
}
.chatbox .messages .message.bot {
background: linear-gradient(145deg, #f5f5f5, #eeeeee);
border-radius: 15px;
padding: 1rem;
margin: 0.5rem;
box-shadow: 5px 5px 15px rgba(0, 0, 0, 0.05);
transform: translateZ(10px);
animation: messageIn 0.3s ease-out;
}
/* λ²„νŠΌ μŠ€νƒ€μΌλ§ */
.duplicate-button {
background: linear-gradient(145deg, #24292e, #1a1e22) !important;
color: white !important;
border-radius: 100vh !important;
padding: 0.8rem 1.5rem !important;
box-shadow: 3px 3px 10px rgba(0, 0, 0, 0.2) !important;
transition: all 0.3s ease !important;
border: none !important;
cursor: pointer !important;
}
.duplicate-button:hover {
transform: translateY(-2px) !important;
box-shadow: 0 5px 15px rgba(0, 0, 0, 0.3) !important;
}
/* μž…λ ₯ ν•„λ“œ μŠ€νƒ€μΌλ§ */
"""
@spaces.GPU
def stream_chat(message: str, history: list, uploaded_file, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
    """Stream a model reply to *message*, augmented with file and wiki context.

    Args:
        message: The user's current question.
        history: Previous turns as ``[user_text, assistant_text]`` pairs.
        uploaded_file: Optional upload forwarded to ``read_uploaded_file``.
        temperature: Sampling temperature; ``0`` (or less) switches to greedy
            decoding because sampling requires a strictly positive value.
        max_new_tokens: Upper bound on generated tokens.
        top_p: Nucleus-sampling threshold (used only when sampling).
        top_k: Top-k cutoff (used only when sampling).
        penalty: Repetition penalty; non-positive values fall back to ``1.0``
            (no penalty) because the penalty must be strictly positive.

    Yields:
        ``("", updated_history)`` tuples, where ``updated_history`` is
        *history* plus ``[message, partial_reply]``.  On any exception a
        single tuple carrying an error message as the reply is yielded.
    """
    try:
        print(f'message is - {message}')
        print(f'history is - {history}')

        # Uploaded-file handling.
        file_context = ""
        if uploaded_file:
            content, file_type = read_uploaded_file(uploaded_file)
            if content:
                file_context = f"\n\nμ—…λ‘œλ“œλœ 파일 λ‚΄μš©:\n```\n{content}\n```"

        # Retrieve related Q&A pairs; emit the section header only when
        # there is something to put under it.
        relevant_contexts = find_relevant_context(message)
        wiki_context = ""
        if relevant_contexts:
            wiki_context = "\n\nκ΄€λ ¨ μœ„ν‚€ν”Όλ””μ•„ 정보:\n" + "".join(
                f"Q: {ctx['question']}\nA: {ctx['answer']}\nμœ μ‚¬λ„: {ctx['similarity']:.3f}\n\n"
                for ctx in relevant_contexts
            )

        # Rebuild the conversation from earlier turns.
        conversation = []
        for prompt, answer in history:
            conversation.extend([
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": answer}
            ])

        # Final user turn: file context + wiki context + current question.
        final_message = file_context + wiki_context + "\nν˜„μž¬ 질문: " + message
        conversation.append({"role": "user", "content": final_message})

        prompt_text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        # The rendered template already carries its special tokens, so do not
        # add them a second time; place tensors wherever the model lives
        # instead of assuming device index 0.
        inputs = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False).to(model.device)

        streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = dict(
            inputs,
            streamer=streamer,
            repetition_penalty=penalty if penalty > 0 else 1.0,
            # Sliders may deliver floats; generate() expects integers here.
            max_new_tokens=int(max_new_tokens),
            # NOTE(review): 255001 is presumably this model's end-of-turn
            # token id - confirm against the tokenizer.
            eos_token_id=[255001],
        )
        if temperature > 0:
            generate_kwargs.update(
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                top_k=int(top_k),
            )
        else:
            generate_kwargs["do_sample"] = False

        # generate() blocks, so it runs in a worker thread while this
        # generator drains the streamer.
        thread = Thread(target=model.generate, kwargs=generate_kwargs)
        thread.start()

        buffer = ""
        for new_text in streamer:
            buffer += new_text
            yield "", history + [[message, buffer]]
        thread.join()
    except Exception as e:
        error_message = f"였λ₯˜κ°€ λ°œμƒν–ˆμŠ΅λ‹ˆλ‹€: {str(e)}"
        yield "", history + [[message, error_message]]
# UI layout: chat column on the left, file upload and generation settings on
# the right.
with gr.Blocks(css=CSS) as demo:
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(value=[], height=500, label="λŒ€ν™”μ°½", show_label=True)
            msg = gr.Textbox(
                label="λ©”μ‹œμ§€ μž…λ ₯",
                show_label=False,
                placeholder="무엇이든 λ¬Όμ–΄λ³΄μ„Έμš”... πŸ’­",
                container=False,
            )
            with gr.Row():
                clear = gr.ClearButton([msg, chatbot], value="λŒ€ν™”λ‚΄μš© μ§€μš°κΈ°")
                send = gr.Button("보내기 πŸ“€")
        with gr.Column(scale=1):
            gr.Markdown("### 파일 μ—…λ‘œλ“œ πŸ“")
            file_upload = gr.File(
                label="파일 선택",
                file_types=["text", ".csv", ".parquet"],
                type="filepath",
            )
            with gr.Accordion("κ³ κΈ‰ μ„€μ • βš™οΈ", open=False):
                temperature = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.8, label="μ˜¨λ„")
                max_new_tokens = gr.Slider(minimum=128, maximum=8000, step=1, value=4000, label="μ΅œλŒ€ 토큰 수")
                top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.8, label="μƒμœ„ ν™•λ₯ ")
                top_k = gr.Slider(minimum=1, maximum=20, step=1, value=20, label="μƒμœ„ K")
                penalty = gr.Slider(minimum=0.0, maximum=2.0, step=0.1, value=1.0, label="반볡 νŒ¨λ„ν‹°")

    # Event bindings: every trigger feeds the same components to stream_chat
    # and writes back to the textbox and the chat window.
    chat_inputs = [msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty]
    chat_outputs = [msg, chatbot]

    msg.submit(stream_chat, inputs=chat_inputs, outputs=chat_outputs)
    send.click(stream_chat, inputs=chat_inputs, outputs=chat_outputs)

    def init_msg():
        """Return the placeholder message shown when a file analysis starts."""
        return "파일 뢄석을 μ‹œμž‘ν•©λ‹ˆλ‹€..."

    # Automatic analysis when the upload changes: fill the textbox first,
    # then run the chat handler.
    upload_event = file_upload.change(init_msg, outputs=msg)
    upload_event.then(stream_chat, inputs=chat_inputs, outputs=chat_outputs)

if __name__ == "__main__":
    demo.launch()