File size: 3,438 Bytes
5bd469b
b1b4c96
5bd469b
d886ca7
b1b4c96
96dadc6
d886ca7
5bd469b
 
d886ca7
b1b4c96
d886ca7
5bd469b
d886ca7
5bd469b
d886ca7
 
 
 
 
 
 
 
 
 
 
b1b4c96
5bd469b
 
d886ca7
 
 
 
 
5bd469b
b1b4c96
d886ca7
5bd469b
 
b1b4c96
 
 
 
d886ca7
 
 
b1b4c96
d886ca7
b1b4c96
 
 
d886ca7
 
 
b1b4c96
 
 
 
d886ca7
 
 
b1b4c96
5bd469b
b1b4c96
5bd469b
d886ca7
 
 
b1b4c96
 
d886ca7
6fd805d
d886ca7
6fd805d
d886ca7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# Hugging Face token from Space secrets (None when the secret is not set)
HF_TOKEN = os.environ.get("HF_TOKEN")

# Model IDs
BASE_MODEL = "google/gemma-3-1b-it"
LORA_ADAPTER = "markredito/gemma-pip-finetuned-v2"  # 🔁 Replace with your actual LoRA repo

# Check device; this decides whether 4-bit quantization can be used below
device = "cuda" if torch.cuda.is_available() else "cpu"
use_cuda = device == "cuda"

# Quantization config for 4-bit (recommended on T4 GPU).
# Only built when CUDA is available: bitsandbytes 4-bit loading generally
# expects a GPU, so requesting it on a CPU-only host would abort startup.
bnb_config = (
    BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    if use_cuda
    else None
)

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=HF_TOKEN)

model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map="auto",
    # bf16 on GPU as before; plain float32 on CPU, where the model is loaded
    # unquantized (NOTE(review): float32 chosen as the conservative CPU dtype).
    torch_dtype=torch.bfloat16 if use_cuda else torch.float32,
    quantization_config=bnb_config,  # None means "load without quantization"
    token=HF_TOKEN,
    attn_implementation="eager",  # Required for Gemma3 + quant
)

# Attach the fine-tuned LoRA adapter on top of the base model
model = PeftModel.from_pretrained(model, LORA_ADAPTER, token=HF_TOKEN)

# Pad token fallback: reuse EOS for padding when the tokenizer defines none
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

# Generation function
def generate_response(prompt, temperature, top_p, top_k):
    """Generate one sampled reply to ``prompt`` from the loaded model.

    The prompt is wrapped in ``<start_of_turn>``/``<end_of_turn>`` markers as
    a single user turn followed by an open model turn, passed to
    ``model.generate`` with sampling enabled, and only the newly generated
    tokens are decoded and returned.

    Args:
        prompt: User text; ``None`` is treated as an empty string.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        top_k: Top-k cutoff forwarded to ``model.generate``; coerced to
            ``int`` because UI sliders may deliver a float.

    Returns:
        The generated text with turn markers removed and whitespace stripped.
    """
    user_text = (prompt or "").strip()
    formatted = (
        "<start_of_turn>user\n"
        f"{user_text}\n"
        "<end_of_turn>\n"
        "<start_of_turn>model\n"
    )

    inputs = tokenizer(formatted, return_tensors="pt").to(model.device)

    # Stop on the tokenizer's EOS *and* on the end-of-turn marker. Passing only
    # the EOS id replaces the model's default stop tokens, so a reply that
    # ends its turn with <end_of_turn> would otherwise keep generating until
    # max_new_tokens is reached.
    stop_ids = []
    if isinstance(tokenizer.eos_token_id, int):
        stop_ids.append(tokenizer.eos_token_id)
    end_of_turn_id = tokenizer.convert_tokens_to_ids("<end_of_turn>")
    if (
        isinstance(end_of_turn_id, int)
        and end_of_turn_id != tokenizer.unk_token_id  # marker unknown to this vocab
        and end_of_turn_id not in stop_ids
    ):
        stop_ids.append(end_of_turn_id)

    outputs = model.generate(
        **inputs,
        max_new_tokens=200,
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        top_k=int(top_k),
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=stop_ids or tokenizer.eos_token_id,
    )

    # Decode only the continuation, i.e. everything after the prompt tokens.
    prompt_length = inputs["input_ids"].shape[1]
    decoded = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    # Fallback cleanup kept from the original: cut at a literal end-of-turn
    # marker if one survived decoding and drop stray "model" role lines.
    cleaned = decoded.split("<end_of_turn>")[0].replace("model\n", "").strip()
    return cleaned

# Gradio UI
with gr.Blocks() as demo:
    # Page header: title, a short usage hint, and a disclaimer.
    gr.Markdown(value="## ✨ Gemma Psychedelic Model Demo")
    gr.Markdown(
        value="Use your imagination or try one of the examples below to explore poetic and philosophical responses."
    )
    gr.Markdown(value="Note: this model intentionally hallucinates.")

    # Clickable sample prompts that fill the prompt textbox.
    examples = [
        "Describe a world where clouds are solid and people walk on them",
        "Contrast quantum realities phenomena from the perspective of a starship navigator, using a spiral into infinity.",
        "Dream up futuristic phenomena from the perspective of a timeless oracle, using a fractal blooming in chaos.",
    ]

    with gr.Row():
        # Left column: prompt, examples, sampling controls, and the trigger button.
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Enter your prompt",
                lines=4,
                placeholder="Try something like: What if gravity took a day off?",
            )

            gr.Examples(examples=examples, inputs=prompt_input, label="Example Prompts")

            temperature = gr.Slider(
                minimum=0.1,
                maximum=1.5,
                value=0.7,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                label="Top-p (nucleus sampling)",
            )
            top_k = gr.Slider(
                minimum=0,
                maximum=100,
                step=1,
                value=50,
                label="Top-k",
            )

            submit = gr.Button(value="Generate")

        # Right column: the generated text.
        with gr.Column():
            output = gr.Textbox(label="Model Response", lines=10)

    # Wire the button to the generation function.
    submit.click(
        fn=generate_response,
        inputs=[prompt_input, temperature, top_p, top_k],
        outputs=output,
    )

demo.launch()