File size: 5,974 Bytes
da1470a
 
662b714
da1470a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
662b714
 
da1470a
 
 
662b714
 
da1470a
662b714
da1470a
662b714
da1470a
 
 
 
662b714
da1470a
 
662b714
 
 
 
da1470a
662b714
da1470a
 
 
662b714
 
da1470a
662b714
da1470a
 
662b714
da1470a
662b714
 
da1470a
 
 
 
 
 
 
 
 
 
662b714
da1470a
 
 
 
 
 
 
662b714
da1470a
 
 
 
662b714
 
da1470a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, logging
from huggingface_hub import login
import torch
import os

# --- 1. Authentication ---
#
# Method 1 (RECOMMENDED, works on Hugging Face Spaces): set the
# HUGGING_FACE_HUB_TOKEN environment variable *before* running.
#   - Linux/macOS:  `export HUGGING_FACE_HUB_TOKEN=your_token`
#   - Windows (PowerShell):  `$env:HUGGING_FACE_HUB_TOKEN = "your_token"`
#   - Hugging Face Spaces:  add `HUGGING_FACE_HUB_TOKEN` as a Space secret.
#
# Method 2 (local testing ONLY — never commit a token!):
#   login(token="YOUR_HUGGING_FACE_TOKEN")
#
# Method 3 (one-time local setup): run `huggingface-cli login` in a terminal
#   and paste your token; it is stored on disk and no code change is needed.
#
# A bare `login()` prompts interactively, which hangs in headless
# environments (CI, Spaces). Prefer the env-var token when it is present.
_hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN")
if _hf_token:
    login(token=_hf_token)  # non-interactive, safe for headless runs
else:
    login()  # falls back to the interactive prompt for local use

# --- 2. Model and Tokenizer Setup (with comprehensive error handling) ---

def load_model_and_tokenizer(model_name="google/gemma-3-1b-it"):
    """Load the causal-LM model and its tokenizer, handling potential errors.

    Args:
        model_name: Hugging Face Hub model id to load.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Exits the process with status 1 (after printing troubleshooting steps)
    if loading fails entirely.
    """
    try:
        # Suppress unnecessary warning messages from transformers.
        logging.set_verbosity_error()

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        try:
            # Prefer Flash Attention 2 when the flash-attn package and a
            # supported GPU are available — faster and lower memory.
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="auto",  # GPU if available, else CPU
                torch_dtype=torch.bfloat16,  # bfloat16 for speed/memory if supported
                attn_implementation="flash_attention_2",
            )
        except (ImportError, ValueError) as attn_err:
            # flash-attn not installed or unsupported hardware/dtype:
            # fall back to the default attention implementation rather
            # than aborting the whole program.
            print(f"WARNING: Flash Attention 2 unavailable ({attn_err}); "
                  "falling back to the default attention implementation.")
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="auto",
                torch_dtype=torch.bfloat16,
            )
        return model, tokenizer

    except Exception as e:
        print(f"ERROR: Failed to load model or tokenizer: {e}")
        print("\nTroubleshooting Steps:")
        print("1. Ensure you have a Hugging Face account and have accepted the model's terms.")
        print("2. Verify your internet connection.")
        print("3. Double-check the model name: 'google/gemma-3-1b-it'")
        print("4. Ensure you are properly authenticated (see authentication section above).")
        print("5. If using a GPU, ensure your CUDA drivers and PyTorch are correctly installed.")
        exit(1)  # Exit with an error code

model, tokenizer = load_model_and_tokenizer()


# --- 3. Chat Template Function (CRITICAL for conversational models) ---

def apply_chat_template(messages, tokenizer):
    """Render a conversation history into a model-ready prompt string.

    Args:
        messages: List of dicts, each with 'role' (user/model) and
            'content' keys.
        tokenizer: The tokenizer object.

    Returns:
        The formatted prompt string (generation prompt appended).

    Exits the process with status 1 if template rendering fails.
    """
    try:
        # Prefer the template shipped with the tokenizer, if one exists.
        if getattr(tokenizer, "chat_template", None):
            return tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

        # No built-in template: fall back to a generic Gemma-style one.
        print("WARNING: Tokenizer does not have a defined chat_template. Using a fallback.")
        fallback_template = (
            "{% for message in messages %}"
            "{{ '<start_of_turn>' + message['role'] + '\n' + message['content'] + '<end_of_turn>\n' }}"
            "{% endfor %}"
            "{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}"
        )
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            chat_template=fallback_template,
        )

    except Exception as e:
        print(f"ERROR: Failed to apply chat template: {e}")
        exit(1)


# --- 4. Text Generation Function ---

def generate_response(messages, model, tokenizer, max_new_tokens=256, temperature=0.7, top_k=50, top_p=0.95, repetition_penalty=1.2):
    """Generate a model response for the given conversation history.

    Args:
        messages: Conversation history (list of role/content dicts).
        model: The loaded causal LM.
        tokenizer: Its tokenizer.
        max_new_tokens: Cap on newly generated tokens.
        temperature, top_k, top_p, repetition_penalty: Sampling controls.

    Returns:
        The generated reply text, or an apology string on failure.
    """
    prompt = apply_chat_template(messages, tokenizer)

    try:
        # Tokenize the prompt and place it on the same device as the model.
        # Calling model.generate directly avoids rebuilding a text-generation
        # pipeline on every user turn, which the previous implementation did.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_len = inputs["input_ids"].shape[-1]

        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            pad_token_id=tokenizer.eos_token_id,  # Important for proper padding
        )

        # Decode only the newly generated tokens. Slicing at the token level
        # is robust to tokenizer text normalization, unlike slicing the
        # decoded string by len(prompt).
        generated_text = tokenizer.decode(
            output_ids[0][prompt_len:], skip_special_tokens=True
        ).strip()
        return generated_text

    except Exception as e:
        print(f"ERROR: Failed to generate response: {e}")
        return "Sorry, I encountered an error while generating a response."


# --- 5. Main Interaction Loop (for command-line interaction) ---
def main():
    """Run the interactive command-line chat loop.

    Maintains the conversation history across turns; exits cleanly on
    'exit'/'quit'/'bye', Ctrl-C, or Ctrl-D.
    """
    messages = []  # Initialize the conversation history

    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C: leave without a traceback.
            print()
            break

        if not user_input:
            # Skip blank lines instead of sending an empty turn to the model.
            continue
        if user_input.lower() in ("exit", "quit", "bye"):
            break

        messages.append({"role": "user", "content": user_input})
        response = generate_response(messages, model, tokenizer)
        print(f"Model: {response}")
        messages.append({"role": "model", "content": response})

if __name__ == "__main__":
    main()