import os
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer, AutoModel, set_seed
import torch
from typing import Optional
import asyncio
import time
import gc
import random # Ditambahkan untuk fallback
# Inisialisasi FastAPI
app = FastAPI(title="LyonPoy AI Chat - CPU Optimized (Prompt Mode)")
# Set seed untuk konsistensi
set_seed(42)
# CPU-Optimized 11 models configuration
# Menyesuaikan max_tokens untuk memberi ruang lebih bagi generasi setelah prompt
MODELS = {
"distil-gpt-2": {
"name": "DistilGPT-2 ⚡",
"model_path": "Lyon28/Distil_GPT-2",
"task": "text-generation",
"max_tokens": 60, # Ditingkatkan
"priority": 1
},
"gpt-2-tinny": {
"name": "GPT-2 Tinny ⚡",
"model_path": "Lyon28/GPT-2-Tinny",
"task": "text-generation",
"max_tokens": 50, # Ditingkatkan
"priority": 1
},
"bert-tinny": {
"name": "BERT Tinny 📊",
"model_path": "Lyon28/Bert-Tinny",
"task": "text-classification",
"max_tokens": 0, # Tidak relevan untuk klasifikasi
"priority": 1
},
"distilbert-base-uncased": {
"name": "DistilBERT 📊",
"model_path": "Lyon28/Distilbert-Base-Uncased",
"task": "text-classification",
"max_tokens": 0, # Tidak relevan untuk klasifikasi
"priority": 1
},
"albert-base-v2": {
"name": "ALBERT Base 📊",
"model_path": "Lyon28/Albert-Base-V2",
"task": "text-classification",
"max_tokens": 0,
"priority": 2
},
"electra-small": {
"name": "ELECTRA Small 📊",
"model_path": "Lyon28/Electra-Small",
"task": "text-classification",
"max_tokens": 0,
"priority": 2
},
"t5-small": {
"name": "T5 Small 🔄",
"model_path": "Lyon28/T5-Small",
"task": "text2text-generation",
"max_tokens": 70, # Ditingkatkan
"priority": 2
},
"gpt-2": {
"name": "GPT-2 Standard",
"model_path": "Lyon28/GPT-2",
"task": "text-generation",
"max_tokens": 70, # Ditingkatkan
"priority": 2
},
"tinny-llama": {
"name": "Tinny Llama",
"model_path": "Lyon28/Tinny-Llama",
"task": "text-generation",
"max_tokens": 80, # Ditingkatkan
"priority": 3
},
"pythia": {
"name": "Pythia",
"model_path": "Lyon28/Pythia",
"task": "text-generation",
"max_tokens": 80, # Ditingkatkan
"priority": 3
},
"gpt-neo": {
"name": "GPT-Neo",
"model_path": "Lyon28/GPT-Neo",
"task": "text-generation",
"max_tokens": 90, # Ditingkatkan
"priority": 3
}
}
class ChatRequest(BaseModel):
message: str # Akan berisi prompt lengkap
model: Optional[str] = "distil-gpt-2"
# Tambahan field untuk prompt terstruktur jika diperlukan di Pydantic,
# tapi untuk saat ini kita akan parse dari 'message'
situasi: Optional[str] = ""
latar: Optional[str] = ""
user_message: str # Pesan pengguna aktual
# CPU-Optimized startup
@app.on_event("startup")
async def load_models_on_startup(): # Mengganti nama fungsi agar unik
app.state.pipelines = {}
app.state.tokenizers = {} # Meskipun tidak secara eksplisit digunakan, baik untuk dimiliki jika diperlukan
# Set CPU optimizations
torch.set_num_threads(2)
os.environ['OMP_NUM_THREADS'] = '2'
os.environ['MKL_NUM_THREADS'] = '2'
os.environ['NUMEXPR_NUM_THREADS'] = '2'
# Set cache
os.environ['HF_HOME'] = '/tmp/.cache/huggingface'
os.environ['TRANSFORMERS_CACHE'] = '/tmp/.cache/huggingface'
os.makedirs(os.environ['HF_HOME'], exist_ok=True)
print("🚀 LyonPoy AI Chat - CPU Optimized (Prompt Mode) Ready!")
# Lightweight frontend
@app.get("/", response_class=HTMLResponse)
async def get_frontend():
# Mengambil inspirasi styling dari styles.css dan layout dari chat.html
# Ini adalah versi yang SANGAT disederhanakan dan disematkan
html_content = '''
LyonPoy AI Chat - Prompt Mode
Hello! Atur Situasi, Latar, dan pesanmu di bawah. Lalu kirim!
${new Date().toLocaleTimeString('id-ID', { hour: '2-digit', minute: '2-digit' })}
AI sedang berpikir...
'''
return HTMLResponse(content=html_content)
# CPU-Optimized Chat API
@app.post("/chat")
async def chat(request: ChatRequest):
start_time = time.time()
try:
model_id = request.model.lower()
if model_id not in MODELS:
model_id = "distil-gpt-2"
model_config = MODELS[model_id]
# Pesan dari request sekarang adalah prompt yang sudah terstruktur
# contoh: "Situasi: Santai\nLatar:Tepi sungai\n{{User}}:sayang,danau nya indah ya, (memeluk {{char}} dari samping)\n{{Char}}:"
structured_prompt = request.message
if model_id not in app.state.pipelines:
print(f"⚡ CPU Loading {model_config['name']}...")
pipeline_kwargs = {
"task": model_config["task"],
"model": model_config["model_path"],
"device": -1,
"torch_dtype": torch.float32,
"model_kwargs": {
"torchscript": False,
"low_cpu_mem_usage": True
}
}
if model_config["task"] != "text-classification": # Tokenizer hanya untuk generator
app.state.tokenizers[model_id] = AutoTokenizer.from_pretrained(model_config["model_path"])
app.state.pipelines[model_id] = pipeline(**pipeline_kwargs)
gc.collect()
pipe = app.state.pipelines[model_id]
generated_text = "Output tidak didukung untuk task ini."
if model_config["task"] == "text-generation":
# Hitung panjang prompt dalam token
current_tokenizer = app.state.tokenizers.get(model_id)
if not current_tokenizer: # Fallback jika tokenizer tidak ada di state (seharusnya ada)
current_tokenizer = AutoTokenizer.from_pretrained(model_config["model_path"])
prompt_tokens = current_tokenizer.encode(structured_prompt, return_tensors="pt")
prompt_length_tokens = prompt_tokens.shape[1]
# max_length adalah total (prompt + generated). max_tokens adalah untuk generated.
# Pastikan max_length tidak melebihi kapasitas model (umumnya 512 atau 1024 untuk model kecil)
# dan juga tidak terlalu pendek.
# Beberapa model mungkin memiliki max_position_embeddings yang lebih kecil.
# Kita cap max_length ke sesuatu yang aman seperti 256 atau 512 jika terlalu besar.
# Model_config["max_tokens"] adalah max *new* tokens yang kita inginkan.
# Kita gunakan max_new_tokens langsung jika didukung oleh pipeline, atau atur max_length
# Untuk pipeline generik, max_length adalah yang utama.
# Max length harus lebih besar dari prompt.
# Max new tokens dari config model.
max_new_generated_tokens = model_config["max_tokens"]
max_len_for_generation = prompt_length_tokens + max_new_generated_tokens
# Batasi max_length total agar tidak terlalu besar untuk model kecil.
# Misalnya, GPT-2 memiliki konteks 1024. DistilGPT-2 juga.
# Model yang lebih kecil mungkin memiliki batas yang lebih rendah.
# Mari kita set batas atas yang aman, misal 512 untuk demo ini.
# Sesuaikan jika model spesifik Anda memiliki batas yang berbeda.
absolute_max_len = 512
if hasattr(pipe.model.config, 'max_position_embeddings'):
absolute_max_len = pipe.model.config.max_position_embeddings
max_len_for_generation = min(max_len_for_generation, absolute_max_len)
# Pastikan max_length setidaknya prompt + beberapa token baru
if max_len_for_generation <= prompt_length_tokens + 5 : # +5 token baru minimal
max_len_for_generation = prompt_length_tokens + 5
# Pastikan kita tidak meminta lebih banyak token baru daripada yang diizinkan oleh absolute_max_len
actual_max_new_tokens = max_len_for_generation - prompt_length_tokens
if actual_max_new_tokens <= 0: # Jika prompt sudah terlalu panjang
return {
"response": "Hmm, prompt terlalu panjang untuk model ini. Coba perpendek situasi/latar/pesan.",
"model": model_config["name"],
"status": "error_prompt_too_long",
"processing_time": f"{round((time.time() - start_time) * 1000)}ms"
}
outputs = pipe(
structured_prompt,
max_length=max_len_for_generation, # Total panjang
# max_new_tokens=actual_max_new_tokens, # Lebih disukai jika pipeline mendukungnya secara eksplisit
temperature=0.75, # Sedikit lebih kreatif
do_sample=True,
top_p=0.9, # Memperluas sampling sedikit
pad_token_id=pipe.tokenizer.eos_token_id if hasattr(pipe.tokenizer, 'eos_token_id') else 50256, # 50256 untuk GPT2
num_return_sequences=1,
early_stopping=True,
truncation=True # Penting jika prompt terlalu panjang untuk model
)
generated_text = outputs[0]['generated_text']
# Cleanup: ekstrak hanya teks setelah prompt "{{Char}}:"
char_marker = "{{Char}}:"
if char_marker in generated_text:
generated_text = generated_text.split(char_marker, 1)[-1].strip()
elif generated_text.startswith(structured_prompt): # fallback jika marker tidak ada
generated_text = generated_text[len(structured_prompt):].strip()
# Hapus jika model mengulang bagian prompt user
if request.user_message and generated_text.startswith(request.user_message):
generated_text = generated_text[len(request.user_message):].strip()
# Batasi ke beberapa kalimat atau panjang tertentu untuk kecepatan & relevansi
# Ini bisa lebih fleksibel
sentences = generated_text.split('.')
if len(sentences) > 2: # Ambil 2 kalimat pertama jika ada
generated_text = sentences[0].strip() + ('.' if sentences[0] else '') + \
(sentences[1].strip() + '.' if len(sentences) > 1 and sentences[1] else '')
elif len(generated_text) > 150: # Batas karakter kasar
generated_text = generated_text[:147] + '...'
elif model_config["task"] == "text-classification":
# Untuk klasifikasi, kita gunakan pesan pengguna aktual, bukan prompt terstruktur
user_msg_for_classification = request.user_message if request.user_message else structured_prompt
output = pipe(user_msg_for_classification[:256], truncation=True, max_length=256)[0] # Batasi input
confidence = f"{output['score']:.2f}"
generated_text = f"📊 Klasifikasi pesan '{user_msg_for_classification[:30]}...': {output['label']} (Skor: {confidence})"
elif model_config["task"] == "text2text-generation":
# T5 dan model serupa mungkin memerlukan format input yang sedikit berbeda,
# tapi untuk demo ini kita coba kirim prompt apa adanya.
# Anda mungkin perlu menambahkan prefix task seperti "translate English to German: " untuk T5
# Untuk chat, kita bisa biarkan apa adanya atau gunakan user_message.
user_msg_for_t2t = request.user_message if request.user_message else structured_prompt
outputs = pipe(
user_msg_for_t2t[:256], # Batasi input untuk T5
max_length=model_config["max_tokens"], # Ini adalah max_length untuk output T5
temperature=0.65,
early_stopping=True,
truncation=True
)
generated_text = outputs[0]['generated_text']
if not generated_text or len(generated_text.strip()) < 1:
generated_text = "🤔 Hmm, saya tidak yakin bagaimana merespon. Coba lagi dengan prompt berbeda?"
elif len(generated_text) > 250: # Batas akhir output
generated_text = generated_text[:247] + "..."
processing_time_ms = round((time.time() - start_time) * 1000)
return {
"response": generated_text,
"model": model_config["name"],
"status": "success",
"processing_time": f"{processing_time_ms}ms"
}
except Exception as e:
print(f"❌ CPU Error: {e}")
import traceback
traceback.print_exc() # Print full traceback for debugging
processing_time_ms = round((time.time() - start_time) * 1000)
fallback_responses = [
"🔄 Maaf, ada sedikit gangguan. Coba lagi dengan kata yang lebih simpel?",
"💭 Hmm, sepertinya saya butuh istirahat sejenak. Mungkin pertanyaan lain?",
"⚡ Model sedang dioptimalkan, tunggu sebentar dan coba lagi...",
"🚀 Mungkin coba model lain yang lebih cepat atau prompt yang berbeda?"
]
fallback = random.choice(fallback_responses)
return {
"response": f"{fallback} (Error: {str(e)[:100]})", # Beri sedikit info error
"status": "error",
"model": MODELS.get(model_id, {"name": "Unknown"})["name"] if 'model_id' in locals() else "Unknown",
"processing_time": f"{processing_time_ms}ms"
}
# Optimized inference endpoint (TIDAK DIPERBARUI SECARA RINCI untuk prompt mode baru,
# karena fokus utama adalah pada /chat dan frontendnya. Jika /inference juga perlu prompt mode,
# ia harus mengkonstruksi ChatRequest serupa.)
@app.post("/inference")
async def inference(request: dict):
"""CPU-Optimized inference endpoint - MUNGKIN PERLU PENYESUAIAN UNTUK PROMPT MODE"""
try:
# Untuk prompt mode, 'message' harus menjadi prompt terstruktur lengkap
# Atau endpoint ini harus diubah untuk menerima 'situasi', 'latar', 'user_message'
message = request.get("message", "")
model_id_from_request = request.get("model", "distil-gpt-2") # Harusnya model_id internal
# Jika yang diberikan adalah model path, coba map ke model_id internal
if "/" in model_id_from_request:
model_key_from_path = model_id_from_request.split("/")[-1].lower()
model_mapping = { "distil_gpt-2": "distil-gpt-2", "gpt-2-tinny": "gpt-2-tinny", /* ... (tambahkan semua mapping) ... */ }
internal_model = model_mapping.get(model_key_from_path, "distil-gpt-2")
else: # Asumsikan sudah model_id internal
internal_model = model_id_from_request
# Jika /inference perlu mendukung prompt mode, data yang dikirim ke ChatRequest harus disesuaikan
# Untuk contoh ini, kita asumsikan 'message' adalah user_message saja untuk /inference
# dan situasi/latar default atau tidak digunakan.
# Ini adalah penyederhanaan dan mungkin perlu diubah sesuai kebutuhan.
chat_req_data = {
"message": f"{{User}}: {message}\n{{Char}}:", # Bentuk prompt paling sederhana
"model": internal_model,
"user_message": message # Simpan pesan user asli
}
chat_request_obj = ChatRequest(**chat_req_data)
result = await chat(chat_request_obj)
return {
"result": result.get("response"),
"status": result.get("status"),
"model_used": result.get("model"),
"processing_time": result.get("processing_time", "0ms")
}
except Exception as e:
print(f"❌ Inference Error: {e}")
return {
"result": "🔄 Terjadi kesalahan pada endpoint inference. Coba lagi...",
"status": "error"
}
# Lightweight health check
@app.get("/health")
async def health():
loaded_models_count = len(app.state.pipelines) if hasattr(app.state, 'pipelines') else 0
return {
"status": "healthy",
"platform": "CPU",
"loaded_models": loaded_models_count,
"total_models": len(MODELS),
"optimization": "CPU-Tuned (Prompt Mode)"
}
# Model info endpoint
@app.get("/models")
async def get_models_info(): # Mengganti nama fungsi
return {
"models": [
{
"id": k, "name": v["name"], "task": v["task"],
"max_tokens_generate": v["max_tokens"], "priority": v["priority"],
"cpu_optimized": True
}
for k, v in MODELS.items()
],
"platform": "CPU",
"recommended_for_prompting": ["distil-gpt-2", "gpt-2-tinny", "tinny-llama", "gpt-neo", "pythia", "gpt-2"]
}
# Run with CPU optimizations
if __name__ == "__main__":
port = int(os.environ.get("PORT", 7860))
# Gunakan reload=True untuk pengembangan agar perubahan kode langsung terlihat
# Matikan reload untuk produksi
# uvicorn.run("app:app", host="0.0.0.0", port=port, workers=1, reload=True)
uvicorn.run(
app,
host="0.0.0.0",
port=port,
workers=1,
timeout_keep_alive=30, # Default FastAPI 5 detik, mungkin terlalu pendek untuk loading model
access_log=False
)