import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import huggingface_hub
import os
import torch
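
# NOTE (assumption, not part of the original file): the Space is expected to provide
# roughly these packages via requirements.txt (versions unpinned here):
#   gradio, torch, huggingface_hub,
#   transformers >= 4.34 (for tokenizer.apply_chat_template),
#   accelerate (required by device_map="auto")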

# --- Configuration ---
MODEL_ID = "Fastweb/FastwebMIIA-7B"
HF_TOKEN = os.getenv("HF_TOKEN")  # For Hugging Face Spaces, set this as a Secret

# Global variables to store the pipeline and any model-loading error
text_generator_pipeline = None
model_load_error = None  # To store any error message during model loading

# --- Hugging Face Login and Model Loading ---
def load_model_and_pipeline():
    global text_generator_pipeline, model_load_error

    if text_generator_pipeline is not None:
        print("Model already loaded.")
        return True  # Already loaded

    if not HF_TOKEN:
        model_load_error = "Hugging Face token (HF_TOKEN) not found in Space secrets. Please add it and restart the Space."
        print(f"ERROR: {model_load_error}")
        return False

    try:
        print("Attempting to login to Hugging Face Hub with token...")
        huggingface_hub.login(token=HF_TOKEN)
        print("Login successful.")

        print(f"Loading tokenizer for {MODEL_ID}...")
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
            use_fast=False  # As recommended by the model card
        )
        # Llama models often don't have a pad token set by default
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        print("Tokenizer loaded.")

        print(f"Loading model {MODEL_ID}...")
        # For large models, specify dtype and device_map
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,  # bfloat16 for better performance and memory, where supported
            device_map="auto"  # Automatically distribute the model across available GPUs/CPU
        )
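        # Optional alternative (a sketch, not part of the original app): if the model
        # does not fit in bf16 on the available GPU, a common workaround is 4-bit
        # quantization via bitsandbytes, roughly:
        #   from transformers import BitsAndBytesConfig
        #   quant_config = BitsAndBytesConfig(
        #       load_in_4bit=True,
        #       bnb_4bit_compute_dtype=torch.bfloat16,
        #   )
        #   model = AutoModelForCausalLM.from_pretrained(
        #       MODEL_ID, trust_remote_code=True,
        #       quantization_config=quant_config, device_map="auto",
        #   )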
print("Model loaded.") | |
text_generator_pipeline = pipeline( | |
"text-generation", | |
model=model, | |
tokenizer=tokenizer, | |
# device_map="auto" handles device placement, so no need for device=0 here | |
) | |
print("Text generation pipeline created successfully.") | |
model_load_error = None | |
return True | |
except Exception as e: | |
model_load_error = f"Error loading model/pipeline: {str(e)}. Check model name, token, and Space resources (RAM/GPU)." | |
print(f"ERROR: {model_load_error}") | |
text_generator_pipeline = None # Ensure it's None on error | |
return False | |

# --- Text Analysis Function ---
def analyze_text(text_input, file_upload, custom_instruction, max_new_tokens, temperature, top_p):
    global text_generator_pipeline, model_load_error

    if text_generator_pipeline is None:
        if model_load_error:
            return f"Model not loaded. Error: {model_load_error}"
        else:
            return "Model is not loaded or is still loading. Please check the Space logs for errors (especially out-of-memory), make sure HF_TOKEN is set, and confirm you have accepted the model's terms. On CPU, loading may take a very long time or fail due to memory limits."

    content_to_analyze = ""
    if file_upload is not None:
        try:
            # Depending on the Gradio version, gr.File yields a str path or a file-like object with .name
            file_path = file_upload if isinstance(file_upload, str) else file_upload.name
            with open(file_path, 'r', encoding='utf-8') as f:
                content_to_analyze = f.read()
            if not content_to_analyze.strip() and not text_input.strip():
                return "Uploaded file is empty and no direct text input provided. Please provide some text."
            elif not content_to_analyze.strip() and text_input.strip():
                content_to_analyze = text_input
        except Exception as e:
            return f"Error reading uploaded file: {str(e)}"
    elif text_input:
        content_to_analyze = text_input
    else:
        return "Please provide text directly or upload a document."

    if not content_to_analyze.strip():
        return "Input text is empty."

    # Llama 2 chat format:
    # <s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{user_prompt} [/INST]
    # For text analysis, the instruction is the user prompt and the input text is embedded in it.
    system_prompt = "You are a helpful AI assistant specialized in text analysis. Perform the requested task on the provided text."
    user_message = f"{custom_instruction}\n\nHere is the text:\n```\n{content_to_analyze}\n```"
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message}
    ]
    try:
        # Use tokenizer.apply_chat_template if available (transformers >= 4.34.0)
        prompt = text_generator_pipeline.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
    except Exception as e:
        print(f"Warning: Could not use apply_chat_template ({e}). Falling back to manual formatting.")
        # Manual Llama 2 chat format
        prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{user_message} [/INST]"

    print("\n--- Sending to Model ---")
    print(f"Full Prompt:\n{prompt}")
    print(f"Max New Tokens: {max_new_tokens}, Temperature: {temperature}, Top P: {top_p}")
    print("------------------------\n")
    try:
        generated_outputs = text_generator_pipeline(
            prompt,
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=float(temperature) if float(temperature) > 0.01 else 0.01,  # Temperature 0 can be problematic
            top_p=float(top_p),
            num_return_sequences=1,
            eos_token_id=text_generator_pipeline.tokenizer.eos_token_id,
            pad_token_id=text_generator_pipeline.tokenizer.pad_token_id  # Use the pad token set at load time
        )
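        # NOTE (assumption, not in the original code): the "text-generation" pipeline
        # also accepts return_full_text=False, which returns only the newly generated
        # text and would make the [/INST] stripping below unnecessary. The original
        # manual extraction is kept unchanged.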
        response_full = generated_outputs[0]['generated_text']

        # Extract only the assistant's response: the answer starts after the [/INST] marker.
        answer_marker = "[/INST]"
        if answer_marker in response_full:
            response_text = response_full.split(answer_marker, 1)[1].strip()
        else:
            # Fallback if the full prompt wasn't returned (can happen with some pipeline configs)
            # or if the model didn't fully adhere to the template in its output.
            # Less ideal, but better than nothing: try to strip the input prompt.
            response_text = response_full.replace(prompt, "").strip()
        return response_text
    except Exception as e:
        return f"Error during text generation: {str(e)}"

# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"""
    # 📝 Text Analysis with {MODEL_ID}
    Test the capabilities of the `{MODEL_ID}` model on text analysis tasks for Italian or English texts.
    Provide an instruction and your text (directly or via upload).

    **Important:** Model loading can take a few minutes, especially on the first run or on CPU.
    Since this is a 7B model, the app is best run on a Hugging Face Space with GPU resources (e.g., T4-small or A10G-small).
    """)

    with gr.Row():
        status_textbox = gr.Textbox(label="Model Status", value="Initializing...", interactive=False, scale=3)
        current_hardware = os.getenv("SPACE_HARDWARE", "Unknown (likely local or unspecified)")
        gr.Markdown(f"Running on: **{current_hardware}**")

    with gr.Tab("Text Input & Analysis"):
        with gr.Row():
            with gr.Column(scale=2):
                instruction_prompt = gr.Textbox(
                    label="Instruction for the Model (Cosa vuoi fare con il testo?)",
                    value="Riassumi questo testo in 3 frasi concise.",
                    lines=3,
                    placeholder="Example: Riassumi questo testo. / Summarize this text. / Estrai le entità nominate. / Identify named entities."
                )
                text_area_input = gr.Textbox(label="Enter Text Directly / Inserisci il testo direttamente", lines=10, placeholder="Paste your text here or upload a file below...")
                file_input = gr.File(label="Or Upload a Document (.txt) / O carica un documento (.txt)", file_types=['.txt'])
            with gr.Column(scale=3):
                output_text = gr.Textbox(label="Model Output / Risultato del Modello", lines=20, interactive=False)

        with gr.Accordion("Advanced Generation Parameters", open=False):
            max_new_tokens_slider = gr.Slider(minimum=10, maximum=2048, value=256, step=10, label="Max New Tokens")
            temperature_slider = gr.Slider(minimum=0.01, maximum=2.0, value=0.7, step=0.01, label="Temperature (higher = more creative, lower = more deterministic)")
            top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top P (nucleus sampling)")

        analyze_button = gr.Button("🧠 Analyze Text / Analizza Testo", variant="primary")
        analyze_button.click(
            fn=analyze_text,
            inputs=[text_area_input, file_input, instruction_prompt, max_new_tokens_slider, temperature_slider, top_p_slider],
            outputs=output_text
        )

    # Load the model at startup: demo.load fires when a client opens the page, and
    # load_model_and_pipeline returns early if the model is already loaded.
    # The status_textbox is updated after the load attempt.
    def startup_load_model():
        print("Gradio app starting, attempting to load model...")
        if load_model_and_pipeline():
            return "Model loaded successfully and ready."
        else:
            return f"Failed to load model. Error: {model_load_error or 'Unknown error during startup. Check Space logs.'}"

    demo.load(startup_load_model, outputs=status_textbox)

if __name__ == "__main__":
    # For local testing, either set HF_TOKEN as an environment variable
    # (e.g. HF_TOKEN="your_hf_token_here" python app.py) or log in via 'huggingface-cli login'.
    if not HF_TOKEN:
        print("WARNING: HF_TOKEN environment variable not set.")
        print("For local execution, either set HF_TOKEN or ensure you are logged in via 'huggingface-cli login'.")
        try:
            # get_token() reads the token from the environment or the local token cache
            from huggingface_hub import get_token
            token = get_token()
            if token:
                os.environ['HF_TOKEN'] = token  # Set it for the current process
                HF_TOKEN = token  # Also update the module-level variable used by the script
                print("Using token from huggingface-cli login.")
            else:
                print("Could not retrieve a token from the CLI login. Model access might fail.")
        except Exception as e:
            print(f"Could not check CLI login status: {e}. Model access might fail.")

    print("Launching Gradio interface...")
    demo.queue().launch(debug=True, share=False)  # Set share=True for a public link when running locally