#!/usr/bin/env python
"""Gradio app for Dhivehi typo correction."""

import difflib
import html

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import gradio as gr
import spaces

# Available models: UI label -> Hugging Face model id.
MODEL_OPTIONS_TYPO = {
    "A3 Model": "alakxender/t5-dhivehi-typo-corrector-asr",
    "XS Model": "alakxender/dhivehi-quick-spell-check-t5",
}

# Model loaded eagerly at import time.
DEFAULT_MODEL_CHOICE = "A3 Model"

# Longest accepted input, in characters (after stripping whitespace).
MAX_INPUT_CHARS = 1024
# Token budget used for both the tokenized input and the generated output.
# Input longer than this is truncated by the tokenizer.
MAX_TOKENS = 128
# Beam width used for generation.
NUM_BEAMS = 4

# Inline styles for the diff markup. The attribute selectors in `css` below
# match on these exact substrings, so keep the two in sync.
_STYLE_ORIGINAL = "background-color: #fff3cd"
_STYLE_CORRECTED = "background-color: #d4edda"
_STYLE_REMOVED = "background-color: #f8d7da"


def load_model(model_choice):
    """Load the model and tokenizer registered under ``model_choice``.

    Args:
        model_choice: A key of ``MODEL_OPTIONS_TYPO``.

    Returns:
        A ``(model, tokenizer, device)`` tuple, where ``device`` is ``"cuda"``
        or ``"cpu"``. On any failure (including an unknown ``model_choice``)
        the error is printed and ``(None, None, None)`` is returned.
    """
    print("Loading model and tokenizer...")
    try:
        selected_model = MODEL_OPTIONS_TYPO[model_choice]
        tokenizer = AutoTokenizer.from_pretrained(selected_model)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForSeq2SeqLM.from_pretrained(selected_model)

        # Move model to GPU if available
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)

        print(f"Model loaded successfully on {device}")
        return model, tokenizer, device
    except Exception as e:
        # Best-effort: callers check for a ``None`` model instead of catching.
        print(f"Error loading model: {e}")
        return None, None, None


def correct_typo(text, model, tokenizer, device):
    """Run the typo-correction model on ``text`` and return the corrected text.

    Args:
        text: User input to correct.
        model: Seq2seq model returned by ``load_model``.
        tokenizer: Tokenizer returned by ``load_model``.
        device: Device string the model lives on.

    Returns:
        The decoded model output, or ``"Error: <message>"`` if tokenization
        or generation raised.

    Raises:
        gr.Error: If ``text`` is empty/whitespace-only or longer than
            ``MAX_INPUT_CHARS`` characters.
    """
    # `not text` also covers None, which would otherwise crash on .strip().
    if not text or not text.strip():
        raise gr.Error("Please enter some text💥!", duration=5)
    if len(text.strip()) > MAX_INPUT_CHARS:
        raise gr.Error("Shorter the better💥!", duration=5)

    try:
        # Prepare input with the task prefix used by this app.
        input_text = "fix: " + text

        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            max_length=MAX_TOKENS,
            truncation=True,
        )
        inputs = inputs.to(device)

        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs.get("attention_mask", None),
                max_length=MAX_TOKENS,
                num_beams=NUM_BEAMS,
                early_stopping=True,
            )

        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Deliberate best-effort: surface the failure as the output text.
        return f"Error: {str(e)}"


# Initialize model and tokenizer
model, tokenizer, device = load_model(DEFAULT_MODEL_CHOICE)
if model is None:
    print("Failed to load model. Please check your model and tokenizer paths.")

# Successfully loaded (model, tokenizer, device) triples keyed by UI label.
# NOTE(review): if the GPU-decorated handler runs in a separate worker
# process, entries added there may not persist between requests -- confirm.
_LOADED_MODELS = {}
if model is not None:
    _LOADED_MODELS[DEFAULT_MODEL_CHOICE] = (model, tokenizer, device)


def _get_model(model_choice):
    """Return the triple for ``model_choice``, loading and caching it on demand."""
    cached = _LOADED_MODELS.get(model_choice)
    if cached is not None:
        return cached

    loaded = load_model(model_choice)
    # Failed loads are not cached so a later request can retry.
    if loaded[0] is not None:
        _LOADED_MODELS[model_choice] = loaded
    return loaded


def _span(word, style=None):
    """Wrap an HTML-escaped ``word`` in a ``<span>``, optionally styled."""
    escaped = html.escape(word)
    if style is None:
        return f"<span>{escaped}</span>"
    return f'<span style="{style}">{escaped}</span>'


def highlight_differences(original, corrected):
    """Build an HTML word-level diff between ``original`` and ``corrected``.

    Both texts are split on whitespace and compared with ``difflib.Differ``.
    A removed word immediately followed by an added word is rendered as
    ``old→new``; other removals, additions and unchanged words are rendered
    individually.

    Args:
        original: Text as entered by the user.
        corrected: Text produced by the model.

    Returns:
        An HTML string wrapped in ``<div class="dhivehi-diff">``. All words
        are HTML-escaped.
    """
    # NOTE(review): the span/div markup is reconstructed from the selectors
    # in `css` (background-color #fff3cd / #d4edda / #f8d7da, .dhivehi-diff);
    # confirm it matches the deployed markup.
    differ = difflib.Differ()
    # Differ inserts '? ' hint lines between similar '-'/'+' entries. They
    # carry no words and would break the "changed word" pairing below.
    diff = [
        entry
        for entry in differ.compare(original.split(), corrected.split())
        if not entry.startswith("? ")
    ]

    html_parts = []
    i = 0
    while i < len(diff):
        entry = diff[i]
        word = entry[2:]
        if entry.startswith("  "):
            # Unchanged
            html_parts.append(_span(word))
        elif entry.startswith("- "):
            if i + 1 < len(diff) and diff[i + 1].startswith("+ "):
                # Changed word - show correction
                new_word = diff[i + 1][2:]
                html_parts.append(
                    f"{_span(word, _STYLE_ORIGINAL)}→{_span(new_word, _STYLE_CORRECTED)}"
                )
                i += 1  # the '+' entry has been consumed as well
            else:
                # Removed word
                html_parts.append(_span(word, _STYLE_REMOVED))
        elif entry.startswith("+ "):
            # Added word
            html_parts.append(_span(word, _STYLE_CORRECTED))
        i += 1

    return f'<div class="dhivehi-diff">{" ".join(html_parts)}</div>'


@spaces.GPU()
def process_input(text, model_choice):
    """Gradio handler: correct ``text`` with the model chosen in the UI.

    Args:
        text: User input.
        model_choice: A key of ``MODEL_OPTIONS_TYPO``.

    Returns:
        ``(corrected, highlighted)``: the corrected text and the HTML diff
        against the input.

    Raises:
        gr.Error: If the selected model cannot be loaded, or if ``text``
            fails validation in ``correct_typo``.
    """
    # Use the model the user selected (loading it on first use) instead of
    # always falling back to the model loaded at import time.
    selected_model, selected_tokenizer, selected_device = _get_model(model_choice)
    if selected_model is None:
        raise gr.Error(
            "Failed to load model. Please check your model and tokenizer paths.",
            duration=5,
        )

    corrected = correct_typo(text, selected_model, selected_tokenizer, selected_device)
    highlighted = highlight_differences(text, corrected)
    return corrected, highlighted


# Define CSS for Dhivehi font styling.
# NOTE: the final rule uses `:contains()`, which is a jQuery extension and not
# valid CSS; browsers drop that rule. It is kept here unchanged.
css = """
.textbox1 textarea {
    font-size: 18px !important;
    font-family: 'MV_Faseyha', 'Faruma', 'A_Faruma' !important;
    line-height: 1.8 !important;
    direction: rtl !important;
}
.dhivehi-text {
    font-size: 18px !important;
    font-family: 'MV_Faseyha', 'Faruma', 'A_Faruma' !important;
    line-height: 1.8 !important;
    direction: rtl !important;
    text-align: right !important;
    padding: 10px !important;
    background: transparent !important; /* Make background transparent */
    border-radius: 4px !important;
    color: #ffffff !important; /* White text for dark background */
}
/* Style for the highlighted differences */
.dhivehi-diff {
    font-size: 18px !important;
    font-family: 'MV_Faseyha', 'Faruma', 'A_Faruma' !important;
    line-height: 1.8 !important;
    direction: rtl !important;
    text-align: right !important;
    padding: 15px !important;
    background: transparent !important; /* Make background transparent */
    border: 1px solid rgba(255, 255, 255, 0.1) !important; /* Subtle border */
    border-radius: 4px !important;
    margin-top: 10px !important;
    color: #ffffff !important; /* White text for dark background */
}
/* Ensure the highlighted spans have good contrast */
.dhivehi-diff span {
    padding: 2px 5px !important;
    border-radius: 3px !important;
    margin: 0 2px !important;
}
/* Original text (yellow background) */
.dhivehi-diff span[style*="background-color: #fff3cd"] {
    background-color: rgba(255, 243, 205, 0.2) !important;
    color: #ffd700 !important; /* Golden yellow for visibility */
    border: 1px solid rgba(255, 243, 205, 0.3) !important;
}
/* Corrected text (green background) */
.dhivehi-diff span[style*="background-color: #d4edda"] {
    background-color: rgba(212, 237, 218, 0.2) !important;
    color: #98ff98 !important; /* Light green for visibility */
    border: 1px solid rgba(212, 237, 218, 0.3) !important;
}
/* Removed text (red background) */
.dhivehi-diff span[style*="background-color: #f8d7da"] {
    background-color: rgba(248, 215, 218, 0.2) !important;
    color: #ff6b6b !important; /* Light red for visibility */
    border: 1px solid rgba(248, 215, 218, 0.3) !important;
}
/* Arrow color */
.dhivehi-diff span:contains('→') {
    color: #ffffff !important;
}
"""