File size: 3,294 Bytes
69803e4
2eb4c0d
 
a5e67e1
da587af
2eb4c0d
da587af
2eb4c0d
a5e67e1
2eb4c0d
da587af
 
2eb4c0d
403b84c
a5e67e1
6118e79
764b0a1
 
 
 
 
2eb4c0d
403b84c
da587af
2eb4c0d
 
5573d95
764b0a1
5573d95
2eb4c0d
 
5573d95
764b0a1
5573d95
 
f222dd4
 
 
 
 
 
 
 
 
 
5573d95
da587af
2eb4c0d
5573d95
da587af
5573d95
f222dd4
5573d95
764b0a1
f222dd4
 
5573d95
 
764b0a1
2eb4c0d
 
6118e79
f222dd4
 
 
 
2eb4c0d
403b84c
764b0a1
403b84c
2eb4c0d
 
764b0a1
2eb4c0d
 
764b0a1
7164a5b
2eb4c0d
 
403b84c
2eb4c0d
f222dd4
2eb4c0d
f222dd4
403b84c
 
 
2eb4c0d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import gradio as gr
import spaces  # Important for ZeroGPU

# Hub repository ids for the base checkpoint and the LoRA adapter applied on
# top of it.
BASE_MODEL_ID = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
ADAPTER_ID = "rezaenayati/RezAi-Model"

# Base causal LM, loaded once at import time so every chat request reuses it.
# Device placement is delegated to `device_map="auto"`.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)

# Tokenizer comes from the same repository as the base model.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)

# Fall back to the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Wrap the base model with the fine-tuned LoRA adapter; `model` is what the
# chat handler generates with.
model = PeftModel.from_pretrained(base_model, ADAPTER_ID)

@spaces.GPU  # This decorator is CRITICAL for ZeroGPU
def chat_with_rezAi(messages, history):
    """Generate RezAI's reply to the latest user message.

    Builds a chat prompt from a fixed system persona, the earlier turns in
    ``history`` and the new user text, samples a completion from the
    module-level ``model`` and returns only the newly generated text.

    Args:
        messages: The current user message (a single string, despite the
            plural name).
        history: Iterable of ``(user_msg, assistant_msg)`` pairs for the
            earlier turns of the conversation.

    Returns:
        The assistant reply as a stripped string, without the prompt.
    """
    # Assemble the prompt piece by piece: system persona, prior turns, then
    # the new user message followed by an open assistant header for the model
    # to complete.
    parts = [
        "<|start_header_id|>system<|end_header_id|>\nYou are Reza Enayati, a Computer Science student and entrepreneur from Los Angeles, who is eager to work as a software engineer or machine learning engineer. Answer these questions as if you are in an interview.<|eot_id|>"
    ]
    for user_msg, assistant_msg in history:
        parts.append(f"<|start_header_id|>user<|end_header_id|>\n{user_msg}<|eot_id|>")
        parts.append(f"<|start_header_id|>assistant<|end_header_id|>\n{assistant_msg}<|eot_id|>")
    parts.append(
        f"<|start_header_id|>user<|end_header_id|>\n{messages}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
    )
    conversation = "".join(parts)

    # NOTE(review): with the tokenizer's default (right-side) truncation, a
    # conversation longer than max_length loses its END, i.e. the newest user
    # message and the assistant header. Consider trimming old turns instead.
    inputs = tokenizer(
        conversation,
        return_tensors="pt",
        truncation=True,
        max_length=2048,
    )

    # Move inputs to the same device as the model.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=128,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,
        )

    # The returned sequence starts with the prompt tokens. Decoding the whole
    # sequence with skip_special_tokens=True strips the header markers, so the
    # reply can no longer be located by splitting on them; slice off the
    # prompt by token count and decode only what was generated.
    generated_ids = outputs[0][prompt_len:]
    new_response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    # Defensive cleanup: drop anything from a stray/incomplete tag onward.
    if "<|" in new_response:
        new_response = new_response.split("<|")[0].strip()

    return new_response

# Starter prompts shown in the chat UI.
_EXAMPLE_PROMPTS = [
    "Tell me about your background",
    "What programming languages do you know?",
    "Walk me through RezAI",
    "What's your experience with machine learning?",
    "How did you get into computer science?",
]

# Options for the chat UI, collected in one place before construction.
_CHAT_UI_OPTIONS = {
    "fn": chat_with_rezAi,
    "title": "💬 Chat with RezAI",
    "description": "Hi! I'm RezAI, Reza's AI twin. Ask me about his technical background, projects, or experience!",
    "examples": _EXAMPLE_PROMPTS,
    # Button configuration: retry is passed as None, the other two get
    # custom labels.
    "retry_btn": None,
    "undo_btn": "Delete Previous",
    "clear_btn": "Clear Chat",
    "theme": gr.themes.Soft(),
}

# Chat front-end wired to the RezAI response function.
demo = gr.ChatInterface(**_CHAT_UI_OPTIONS)

# Only start the web server when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()