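"""Gradio demo: instruction-driven text analysis with Fastweb/FastwebMIIA-7B.

The app logs in to the Hugging Face Hub with an HF_TOKEN secret, loads the model
into a transformers text-generation pipeline, and exposes a UI where the user
supplies an instruction plus a text to analyze (typed directly or uploaded as a
.txt file). Best run on a Hugging Face Space with GPU hardware; CPU runs are
slow and may run out of memory.
"""
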
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import huggingface_hub
import os
import torch

# --- Configuration ---
MODEL_ID = "Fastweb/FastwebMIIA-7B"
HF_TOKEN = os.getenv("HF_TOKEN")  # For Hugging Face Spaces, set this as a Secret

# Global variable to store the pipeline
text_generator_pipeline = None
model_load_error = None # To store any error message during model loading
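# These globals act as a simple per-process cache: load_model_and_pipeline()
# populates them once, and analyze_text() reads them on every request.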

# --- Hugging Face Login and Model Loading ---
def load_model_and_pipeline():
    global text_generator_pipeline, model_load_error
    if text_generator_pipeline is not None:
        print("Model already loaded.")
        return True # Already loaded

    if not HF_TOKEN:
        model_load_error = "Hugging Face token (HF_TOKEN) not found in Space secrets. Please add it and restart the Space."
        print(f"ERROR: {model_load_error}")
        return False

    try:
        print(f"Attempting to login to Hugging Face Hub with token...")
        huggingface_hub.login(token=HF_TOKEN)
        print("Login successful.")

        print(f"Loading tokenizer for {MODEL_ID}...")
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
            use_fast=False  # As recommended by the model card
        )
        # Llama models often don't have a pad token set by default
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        print("Tokenizer loaded.")

        print(f"Loading model {MODEL_ID}...")
        # For large models, specify dtype and device_map
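        # In bfloat16 a 7B model needs roughly 14 GB for the weights alone;
        # device_map="auto" lets accelerate place layers on the available GPU(s)
        # and offload the rest to CPU RAM if necessary.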
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16, # Use bfloat16 for better performance and memory if supported
            device_map="auto"           # Automatically distribute model across available GPUs/CPU
        )
        print("Model loaded.")

        text_generator_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            # device_map="auto" handles device placement, so no need for device=0 here
        )
        print("Text generation pipeline created successfully.")
        model_load_error = None
        return True
    except Exception as e:
        model_load_error = f"Error loading model/pipeline: {str(e)}. Check model name, token, and Space resources (RAM/GPU)."
        print(f"ERROR: {model_load_error}")
        text_generator_pipeline = None # Ensure it's None on error
        return False

# --- Text Analysis Function ---
def analyze_text(text_input, file_upload, custom_instruction, max_new_tokens, temperature, top_p):
    global text_generator_pipeline, model_load_error

    if text_generator_pipeline is None:
        if model_load_error:
            return f"Model not loaded. Error: {model_load_error}"
        else:
            return "Model is not loaded or still loading. Please check Space logs for errors (especially OOM) and ensure HF_TOKEN is set and you've accepted model terms. If on CPU, it may take a very long time or fail due to memory."

    content_to_analyze = ""
    if file_upload is not None:
        try:
            with open(file_upload.name, 'r', encoding='utf-8') as f:
                content_to_analyze = f.read()
            if not content_to_analyze.strip() and not text_input.strip():
                 return "Uploaded file is empty and no direct text input provided. Please provide some text."
            elif not content_to_analyze.strip() and text_input.strip():
                content_to_analyze = text_input
        except Exception as e:
            return f"Error reading uploaded file: {str(e)}"
    elif text_input:
        content_to_analyze = text_input
    else:
        return "Please provide text directly or upload a document."

    if not content_to_analyze.strip():
        return "Input text is empty."

    # Build the prompt in the Llama 2 chat format:
    # <s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{user_prompt} [/INST]
    # For text analysis, the custom instruction and the text to analyze together form the user message.
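    # With the default instruction, the rendered prompt looks roughly like:
    #   <s>[INST] <<SYS>>
    #   You are a helpful AI assistant specialized in text analysis. [...]
    #   <</SYS>>
    #
    #   Riassumi questo testo in 3 frasi concise.
    #
    #   Here is the text:
    #   [...the provided text, fenced in backticks...] [/INST]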

    system_prompt = "You are a helpful AI assistant specialized in text analysis. Perform the requested task on the provided text."
    user_message = f"{custom_instruction}\n\nHere is the text:\n```\n{content_to_analyze}\n```"

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message}
    ]

    try:
        # Use tokenizer.apply_chat_template if available (transformers >= 4.34.0)
        prompt = text_generator_pipeline.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
    except Exception as e:
        print(f"Warning: Could not use apply_chat_template ({e}). Falling back to manual formatting.")
        # Manual Llama 2 chat format
        prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{user_message} [/INST]"


    print(f"\n--- Sending to Model ---")
    print(f"Full Prompt:\n{prompt}")
    print(f"Max New Tokens: {max_new_tokens}, Temperature: {temperature}, Top P: {top_p}")
    print("------------------------\n")

    try:
        generated_outputs = text_generator_pipeline(
            prompt,
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=float(temperature) if float(temperature) > 0.01 else 0.01, # Temperature 0 can be problematic
            top_p=float(top_p),
            num_return_sequences=1,
            eos_token_id=text_generator_pipeline.tokenizer.eos_token_id,
            pad_token_id=text_generator_pipeline.tokenizer.pad_token_id # Use the set pad_token
        )
        response_full = generated_outputs[0]['generated_text']
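        # The text-generation pipeline returns the prompt followed by the
        # completion by default (return_full_text=True), so the assistant's
        # answer has to be split back out of the full string.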

        # Extract only the assistant's response part
        # The model's actual answer starts after the [/INST] token.
        answer_marker = "[/INST]"
        if answer_marker in response_full:
            response_text = response_full.split(answer_marker, 1)[1].strip()
        else:
            # Fallback if the full prompt wasn't returned, might happen with some pipeline configs
            # or if the model didn't fully adhere to the template in its output.
            # This is less ideal, but better than nothing.
            response_text = response_full.replace(prompt, "").strip() # Try to remove the input prompt

        return response_text

    except Exception as e:
        return f"Error during text generation: {str(e)}"

# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"""
    # 📝 Text Analysis with {MODEL_ID}
    Test the capabilities of the `{MODEL_ID}` model for text analysis tasks on Italian or English texts.
    Provide an instruction and your text (directly or via upload).
    **Important:** Model loading can take a few minutes, especially on the first run or on CPU.
    This app is best run on a Hugging Face Space with GPU resources (e.g., T4-small or A10G-small) for this 7B model.
    """)

    with gr.Row():
        status_textbox = gr.Textbox(label="Model Status", value="Initializing...", interactive=False, scale=3)
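        # SPACE_HARDWARE is set by Hugging Face Spaces at runtime (e.g. "cpu-basic",
        # "t4-small"); when run locally it falls back to the default below.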
        current_hardware = os.getenv("SPACE_HARDWARE", "Unknown (likely local or unspecified)")
        gr.Markdown(f"Running on: **{current_hardware}**")


    with gr.Tab("Text Input & Analysis"):
        with gr.Row():
            with gr.Column(scale=2):
                instruction_prompt = gr.Textbox(
                    label="Instruction for the Model (Cosa vuoi fare con il testo?)",
                    value="Riassumi questo testo in 3 frasi concise.",
                    lines=3,
                    placeholder="Example: Riassumi questo testo. / Summarize this text. / Estrai le entità nominate. / Identify named entities."
                )
                text_area_input = gr.Textbox(label="Enter Text Directly / Inserisci il testo direttamente", lines=10, placeholder="Paste your text here or upload a file below...")
                file_input = gr.File(label="Or Upload a Document (.txt) / O carica un documento (.txt)", file_types=['.txt'])
            with gr.Column(scale=3):
                output_text = gr.Textbox(label="Model Output / Risultato del Modello", lines=20, interactive=False)

        with gr.Accordion("Advanced Generation Parameters", open=False):
            max_new_tokens_slider = gr.Slider(minimum=10, maximum=2048, value=256, step=10, label="Max New Tokens")
            temperature_slider = gr.Slider(minimum=0.01, maximum=2.0, value=0.7, step=0.01, label="Temperature (higher = more creative, lower = more deterministic)")
            top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top P (nucleus sampling)")

        analyze_button = gr.Button("🧠 Analyze Text / Analizza Testo", variant="primary")

    analyze_button.click(
        fn=analyze_text,
        inputs=[text_area_input, file_input, instruction_prompt, max_new_tokens_slider, temperature_slider, top_p_slider],
        outputs=output_text
    )

    # Attempt to load the model when the Gradio page first loads; the pipeline is
    # cached globally, so the heavy loading runs at most once per process.
    # The status_textbox is updated with the outcome.
    def startup_load_model():
        print("Gradio app starting, attempting to load model...")
        if load_model_and_pipeline():
            return "Model loaded successfully and ready."
        else:
            return f"Failed to load model. Error: {model_load_error or 'Unknown error during startup. Check Space logs.'}"

    demo.load(startup_load_model, outputs=status_textbox)


if __name__ == "__main__":
    # For local testing (ensure HF_TOKEN is set as an environment variable or you're logged in via CLI)
    # You would run: HF_TOKEN="your_hf_token_here" python app.py
    if not HF_TOKEN:
        print("WARNING: HF_TOKEN environment variable not set.")
        print("For local execution, either set HF_TOKEN or ensure you are logged in via 'huggingface-cli login'.")
        try:
            # get_token() returns the token saved by 'huggingface-cli login' (or from the environment).
            token = huggingface_hub.get_token()
            if token:
                os.environ["HF_TOKEN"] = token  # Set it for the current process
                HF_TOKEN = token  # Also update the global used by load_model_and_pipeline()
                print("Using token from huggingface-cli login.")
            else:
                print("Could not retrieve token from CLI login. Model access might fail.")
        except Exception as e:
            print(f"Could not check CLI login status: {e}. Model access might fail.")

    print("Launching Gradio interface...")
    demo.queue().launch(debug=True, share=False) # share=True for public link if local