Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
# Model configuration: Hugging Face Hub id of the checkpoint served by this app.
model_name = "ai4bharat/IndicBART"

# Load tokenizer and model on CPU.
print("Loading IndicBART tokenizer...")
# The slow tokenizer is requested explicitly (use_fast=False), and lower-casing
# and accent stripping are turned off so the input text reaches the model
# unchanged.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    do_lower_case=False,
    use_fast=False,
    keep_accents=True,
)

print("Loading IndicBART model on CPU...")
# No device_map is passed: from_pretrained places the weights on CPU by
# default, whereas device_map="cpu" additionally requires the optional
# `accelerate` package and fails at import time when it is not installed.
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,  # Use float32 for better CPU performance
)
# Language mapping: display name shown in the UI -> "<2xx>" language token.
# The tokens are derived from the two-letter codes so the "<2..>" wrapper is
# written only once; insertion order is the order shown in the dropdowns.
LANGUAGE_CODES = {
    display_name: f"<2{short_code}>"
    for display_name, short_code in (
        ("Assamese", "as"),
        ("Bengali", "bn"),
        ("English", "en"),
        ("Gujarati", "gu"),
        ("Hindi", "hi"),
        ("Kannada", "kn"),
        ("Malayalam", "ml"),
        ("Marathi", "mr"),
        ("Oriya", "or"),
        ("Punjabi", "pa"),
        ("Tamil", "ta"),
        ("Telugu", "te"),
    )
}
def generate_response(input_text, source_lang, target_lang, task_type, max_length):
    """Generate a response with IndicBART on CPU.

    Args:
        input_text: Text entered by the user.
        source_lang: Key of ``LANGUAGE_CODES`` naming the input language.
        target_lang: Key of ``LANGUAGE_CODES`` naming the output language.
        task_type: ``"Translation"``, ``"Text Completion"`` or
            ``"Text Generation"``. Only ``"Text Completion"`` changes the
            input format (it tags the input with the target language).
        max_length: Maximum length passed to ``model.generate``; coerced to
            ``int`` because UI sliders may deliver floats.

    Returns:
        The decoded output string, or ``""`` when ``input_text`` is empty or
        contains only whitespace.

    Raises:
        KeyError: If ``source_lang`` or ``target_lang`` is not a key of
            ``LANGUAGE_CODES``.
    """
    # Blank input: skip the slow CPU generation instead of decoding noise.
    if not input_text or not input_text.strip():
        return ""

    # Get language codes
    src_code = LANGUAGE_CODES[source_lang]
    tgt_code = LANGUAGE_CODES[target_lang]

    # Input format is "<text> </s> <2xx>". Completion tags the input with the
    # target language; translation and generation tag it with the source.
    input_code = tgt_code if task_type == "Text Completion" else src_code
    formatted_input = f"{input_text} </s> {input_code}"

    # NOTE(review): the IndicBART model card describes Indic-language input as
    # Devanagari-script text; other scripts may need transliteration before
    # this call -- confirm against the model card.
    # add_special_tokens=False: the "</s> <2xx>" suffix is already written by
    # hand, so the tokenizer must not wrap the text in its own special tokens.
    inputs = tokenizer(
        formatted_input,
        add_special_tokens=False,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )

    # The decoder is started with the target-language token. The public
    # convert_tokens_to_ids API is used rather than a private tokenizer method.
    decoder_start_token_id = tokenizer.convert_tokens_to_ids(tgt_code)

    # Generate on CPU. Only input_ids and attention_mask are forwarded: the
    # tokenizer may also return token_type_ids, which generate() rejects as an
    # unused model argument for this seq2seq model.
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            decoder_start_token_id=decoder_start_token_id,
            max_length=int(max_length),
            num_beams=2,  # Reduced for faster CPU inference
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
            do_sample=False,  # Deterministic for CPU
        )

    # Decode output
    return tokenizer.decode(
        outputs[0],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
# Create Gradio interface: a wide column with the input/output boxes and
# action buttons, a narrow column with the task/language/length controls,
# and a set of clickable examples underneath.
with gr.Blocks(title="IndicBART CPU Multilingual Assistant", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🇮🇳 IndicBART Multilingual Assistant (CPU Version)

    Experience IndicBART - trained on **11 Indian languages**! Perfect for translation, text completion, and multilingual generation.

    **Supported Languages**: Assamese, Bengali, Gujarati, Hindi, Kannada, Malayalam, Marathi, Oriya, Punjabi, Tamil, Telugu, English

    *Note: Running on CPU - responses may take longer than GPU version.*
    """)

    with gr.Row():
        with gr.Column(scale=3):
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Enter text in any supported language...",
                lines=3,
            )
            output_text = gr.Textbox(
                label="Generated Output",
                lines=5,
                interactive=False,
            )
            with gr.Row():
                generate_btn = gr.Button("Generate", variant="primary", size="lg")
                clear_btn = gr.Button("Clear", variant="secondary")

        with gr.Column(scale=1):
            task_type = gr.Dropdown(
                choices=["Translation", "Text Completion", "Text Generation"],
                value="Translation",
                label="Task Type",
            )
            source_lang = gr.Dropdown(
                choices=list(LANGUAGE_CODES.keys()),
                value="English",
                label="Source Language",
            )
            target_lang = gr.Dropdown(
                choices=list(LANGUAGE_CODES.keys()),
                value="Hindi",
                label="Target Language",
            )
            max_length = gr.Slider(
                minimum=20,
                maximum=200,  # Reduced for faster CPU processing
                value=80,
                step=10,
                label="Max Length",
            )

    # Examples. Each row follows the order of `inputs` below:
    # [input_text, source_lang, target_lang, task_type, max_length].
    gr.Markdown("### 💡 Try these examples:")
    examples = [
        ["Hello, how are you?", "English", "Hindi", "Translation", 80],
        ["मैं एक छात्र हूँ", "Hindi", "English", "Translation", 80],
        ["আমি ভাত খাই", "Bengali", "English", "Translation", 80],
        ["भारत एक", "Hindi", "Hindi", "Text Completion", 100],
        ["The capital of India", "English", "English", "Text Completion", 80],
    ]
    gr.Examples(
        examples=examples,
        inputs=[input_text, source_lang, target_lang, task_type, max_length],
        outputs=output_text,
        fn=generate_response,
        # With fn/outputs given, Gradio on Hugging Face Spaces caches examples
        # at startup by default, i.e. runs the model once per example before
        # the UI is up. Disabled so startup stays fast on CPU and a generation
        # failure cannot take the whole app down at import time.
        cache_examples=False,
    )

    # Event handlers
    def clear_fields():
        """Return empty strings for the input and output textboxes."""
        return "", ""

    # Connect buttons
    generate_btn.click(
        generate_response,
        inputs=[input_text, source_lang, target_lang, task_type, max_length],
        outputs=output_text,
    )
    clear_btn.click(
        clear_fields,
        outputs=[input_text, output_text],
    )
if __name__ == "__main__":
    # Launch settings gathered in one mapping so they can be read at a glance;
    # they are forwarded unchanged to demo.launch().
    launch_options = {
        "share": True,
        "ssr_mode": False,  # Disable SSR mode to fix the 500 error
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**launch_options)