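# Header captured from the Hugging Face file view (not part of the program):
#   file: app.py, 8.77 kB
#   commit ee94686 (verified) by UnarineLeo: "Create app.py"
#   page links: raw / history / blame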
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import time
# Module-level handles shared by the loader and the translation functions.
# Both stay None until load_model() has run successfully.
model = None
tokenizer = None


def load_model():
    """Populate the global tokenizer and model from the pretrained checkpoint.

    Any failure while loading is printed instead of raised, so the caller
    only has to look at the return value.

    Returns:
        bool: True if both the tokenizer and the model were loaded,
        False if loading raised an exception.
    """
    global model, tokenizer

    checkpoint = "UnarineLeo/nllb_eng_ven_terms"
    try:
        print(f"Loading model: {checkpoint}")
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
        print("Model loaded successfully!")
    except Exception as exc:
        print(f"Error loading model: {exc}")
        return False
    return True
def translate_text(text, max_length=512, num_beams=5):
    """Translate English text to Venda using the globally loaded model.

    Args:
        text (str): Input English text. ``None`` or whitespace-only input is
            rejected with a prompt message instead of an error.
        max_length (int): Maximum length passed to ``model.generate``.
            Float values (as UI sliders may deliver) are truncated to int.
        num_beams (int): Number of beams for beam search; coerced to int the
            same way as ``max_length``.

    Returns:
        tuple: ``(translated_text, status_message)``. When translation is not
        possible, ``translated_text`` is ``""`` and ``status_message`` says
        why. Exceptions raised during translation are caught and reported in
        the status message rather than propagated.
    """
    if text is None or not text.strip():
        return "", "Please enter some text to translate."

    if model is None or tokenizer is None:
        return "", "Model not loaded. Please wait while the model loads."

    try:
        # generate() needs integer limits; slider widgets can hand over floats.
        max_length = int(max_length)
        num_beams = int(num_beams)

        # Set source language, then tokenize the input.
        tokenizer.src_lang = "eng_Latn"
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

        # Resolve the target-language token through the generic vocabulary
        # lookup: the older `lang_code_to_id` mapping is no longer available
        # on current NLLB tokenizers, while this call works on old and new.
        target_lang_id = tokenizer.convert_tokens_to_ids("ven_Latn")

        start_time = time.perf_counter()
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                forced_bos_token_id=target_lang_id,
                max_length=max_length,
                num_beams=num_beams,
                # Early stopping only has meaning for beam search.
                early_stopping=num_beams > 1,
                do_sample=False,
            )

        # Decode translation (single input, so take the first result).
        translation = tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )[0]
        processing_time = round(time.perf_counter() - start_time, 2)

        status = f"✅ Translation completed in {processing_time} seconds"
        return translation, status
    except Exception as e:
        # UI boundary: report the failure in the status box instead of crashing.
        error_msg = f"❌ Translation error: {str(e)}"
        return "", error_msg
def translate_batch(text_list):
    """Translate multiple lines of text, one line at a time.

    Blank lines are ignored. Each remaining line is passed to
    ``translate_text`` with its default settings. Lines that yield an empty
    translation are left out of the output, and the status message reports
    how many lines were actually translated.

    Args:
        text_list (str): Multi-line text input, one sentence per line.
            ``None`` is treated like empty input.

    Returns:
        tuple: ``(translated_text, status_message)``. ``translated_text``
        holds numbered ``EN`` / ``VE`` pairs separated by blank lines, or
        ``""`` when nothing could be translated.
    """
    if text_list is None:
        text_list = ""

    lines = [line.strip() for line in text_list.split("\n") if line.strip()]
    if not lines:
        return "", "Please enter some text to translate."

    try:
        translations = []
        translated_count = 0
        # Numbering follows the position among non-blank input lines, so a
        # skipped line leaves a visible gap in the numbering.
        for number, line in enumerate(lines, start=1):
            translation, _status = translate_text(line)
            if not translation:
                continue
            translated_count += 1
            translations.append(f"{number}. EN: {line}")
            translations.append(f" VE: {translation}")
            translations.append("")

        if not translations:
            return "", "❌ No translations generated"

        result = "\n".join(translations)
        if translated_count == len(lines):
            status_msg = f"✅ Successfully translated {len(lines)} lines"
        else:
            failed = len(lines) - translated_count
            status_msg = (
                f"⚠️ Translated {translated_count} of {len(lines)} lines "
                f"({failed} failed)"
            )
        return result, status_msg
    except Exception as e:
        # UI boundary: surface unexpected failures in the status box.
        return "", f"❌ Batch translation error: {str(e)}"
# Load model on startup (runs at import time, before the UI is built).
print("Initializing model...")
model_loaded = load_model()

# Markdown shown in the UI. Kept as column-0 module constants so the rendered
# text does not depend on how deeply the layout code below is nested.
INTRO_MARKDOWN = """
# 🌍 English to Venda Translator

This app translates English text to Venda (Tshivenda) using the NLLB model.
Venda is a Bantu language spoken primarily in South Africa and Zimbabwe.

**Model:** `UnarineLeo/nllb_eng_ven_terms`
"""

ABOUT_MARKDOWN = """
## About This Translator

This application uses a fine-tuned NLLB (No Language Left Behind) model specifically trained for English to Venda translation.

### Features:
- **Single Translation**: Translate individual sentences or paragraphs
- **Batch Translation**: Translate multiple sentences at once
- **Adjustable Parameters**: Control translation quality and length
- **Examples**: Try pre-loaded example sentences

### About Venda (Tshivenda):
- Spoken by approximately 1.2 million people
- Official language of South Africa
- Also spoken in Zimbabwe
- Part of the Bantu language family

### Usage Tips:
- Keep sentences reasonably short for best results
- The model works best with common, everyday language
- Higher beam numbers generally produce better quality but slower translations

### Technical Details:
- **Model**: UnarineLeo/nllb_eng_ven_terms
- **Architecture**: NLLB (No Language Left Behind)
- **Language Codes**: eng_Latn → ven_Latn
"""

# Create Gradio interface
with gr.Blocks(title="English to Venda Translator", theme=gr.themes.Soft()) as demo:
    gr.Markdown(INTRO_MARKDOWN)

    # Make a failed startup load visible immediately instead of only after
    # the first translation attempt.
    if not model_loaded:
        gr.Markdown(
            "⚠️ **The translation model failed to load.** "
            "Translations are unavailable; check the application logs."
        )

    with gr.Tab("Single Translation"):
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    label="English Text",
                    placeholder="Enter English text to translate...",
                    lines=4,
                    max_lines=10,
                )
                with gr.Row():
                    max_length_slider = gr.Slider(
                        minimum=50,
                        maximum=1000,
                        value=512,
                        step=50,
                        label="Max Translation Length",
                    )
                    num_beams_slider = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=5,
                        step=1,
                        label="Number of Beams (Quality vs Speed)",
                    )
                translate_btn = gr.Button("🔄 Translate", variant="primary")
            with gr.Column():
                output_text = gr.Textbox(
                    label="Venda Translation",
                    lines=4,
                    max_lines=10,
                    interactive=False,
                )
                status_text = gr.Textbox(
                    label="Status",
                    interactive=False,
                    lines=1,
                )

        # Examples: selecting one only fills the input box; the user still
        # has to trigger the translation.
        gr.Examples(
            examples=[
                ["Hello, how are you?"],
                ["Good morning, everyone."],
                ["Thank you for your help."],
                ["What is your name?"],
                ["I am learning Venda."],
                ["Welcome to our school."],
                ["The weather is beautiful today."],
                ["Can you help me please?"],
            ],
            inputs=[input_text],
            label="Try these examples:",
        )

    with gr.Tab("Batch Translation"):
        with gr.Row():
            with gr.Column():
                batch_input = gr.Textbox(
                    label="Multiple English Sentences",
                    placeholder="Enter multiple English sentences, one per line...",
                    lines=8,
                    max_lines=15,
                )
                batch_translate_btn = gr.Button("🔄 Translate All", variant="primary")
            with gr.Column():
                batch_output = gr.Textbox(
                    label="Batch Translations",
                    lines=8,
                    max_lines=15,
                    interactive=False,
                )
                batch_status = gr.Textbox(
                    label="Status",
                    interactive=False,
                    lines=1,
                )

    with gr.Tab("About"):
        gr.Markdown(ABOUT_MARKDOWN)

    # Event handlers
    translate_btn.click(
        fn=translate_text,
        inputs=[input_text, max_length_slider, num_beams_slider],
        outputs=[output_text, status_text],
    )
    batch_translate_btn.click(
        fn=translate_batch,
        inputs=[batch_input],
        outputs=[batch_output, batch_status],
    )
    # Submitting the input textbox from the keyboard runs the same
    # translation as the Translate button.
    input_text.submit(
        fn=translate_text,
        inputs=[input_text, max_length_slider, num_beams_slider],
        outputs=[output_text, status_text],
    )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    launch_options = {
        "share": True,
        "debug": True,
        "show_error": True,
    }
    demo.launch(**launch_options)