#!/usr/bin/env python
# Gradio app for Dhivehi typo correction

import difflib

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import gradio as gr
import spaces

# Available models: display name -> Hugging Face Hub repository id
MODEL_OPTIONS_TYPO = {
    "A3 Model": "alakxender/t5-dhivehi-typo-corrector-asr",
    "XS Model": "alakxender/dhivehi-quick-spell-check-t5",
}

# Longest input (in characters, after stripping) that correct_typo accepts.
MAX_INPUT_CHARS = 1024
# Token budget applied to both the encoded input and the generated output.
MAX_SEQ_TOKENS = 128


def load_model(model_choice):
    """Load the tokenizer and seq2seq model registered under *model_choice*.

    Args:
        model_choice: A key of ``MODEL_OPTIONS_TYPO``.

    Returns:
        A ``(model, tokenizer, device)`` tuple, where ``device`` is ``"cuda"``
        when CUDA is available and ``"cpu"`` otherwise. If anything goes
        wrong (unknown key, download failure, ...) the error is printed and
        ``(None, None, None)`` is returned instead of raising.
    """
    print("Loading model and tokenizer...")
    try:
        selected_model = MODEL_OPTIONS_TYPO[model_choice]
        tokenizer = AutoTokenizer.from_pretrained(selected_model)
        # Fall back to EOS when the tokenizer defines no pad token.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForSeq2SeqLM.from_pretrained(selected_model)

        # Move model to GPU if available
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)
        print(f"Model loaded successfully on {device}")
        return model, tokenizer, device
    except Exception as e:
        # Best-effort loader: callers detect failure through the None triple.
        print(f"Error loading model: {e}")
        return None, None, None


def correct_typo(text, model, tokenizer, device):
    """Return the model's corrected version of *text*.

    Args:
        text: User input. ``None``, empty and whitespace-only values are
            rejected.
        model: Seq2seq model as returned by ``load_model``.
        tokenizer: Tokenizer matching *model*.
        device: Device the encoded inputs are moved to before generation.

    Returns:
        The decoded first generated sequence, or a string of the form
        ``"Error: <message>"`` if tokenization or generation fails.

    Raises:
        gr.Error: If *text* is missing/blank, if its stripped length exceeds
            ``MAX_INPUT_CHARS``, or if the model or tokenizer is not loaded.
    """
    # Gradio may hand over None for an untouched field; treat it as empty
    # instead of crashing on ``None.strip()``.
    stripped = text.strip() if text else ""
    if not stripped:
        raise gr.Error("Please enter some text💥!", duration=5)
    if len(stripped) > MAX_INPUT_CHARS:
        raise gr.Error("Shorter the better💥!", duration=5)
    # Without this guard a failed load would surface as the cryptic
    # "'NoneType' object is not callable" returned as if it were a correction.
    if model is None or tokenizer is None:
        raise gr.Error("Model is not loaded. Please try again later.", duration=5)

    try:
        # The prompt is the "fix: " prefix followed by the raw text.
        input_text = "fix: " + text

        # Inputs longer than the token budget are truncated by the tokenizer.
        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            max_length=MAX_SEQ_TOKENS,
            truncation=True,
        )
        inputs = inputs.to(device)

        # Inference only: no gradients needed.
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),
                max_length=MAX_SEQ_TOKENS,
                num_beams=4,
                early_stopping=True,
            )

        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Kept as a returned string so existing callers that display the
        # result verbatim keep working.
        return f"Error: {e}"


# Initialize model and tokenizer
model, tokenizer, device = load_model("A3 Model")
if model is None:
    print("Failed to load model. Please check your model and tokenizer paths.")

# NOTE(review): the definition below is still collapsed onto one physical line and is left untouched here.
# Function to highlight differences between original and corrected text def highlight_differences(original, corrected): d = difflib.Differ() orig_words = original.split() corr_words = corrected.split() diff = list(d.compare(orig_words, corr_words)) html_parts = [] i = 0 while i < len(diff): if diff[i].startswith(' '): # Unchanged html_parts.append(f'{diff[i][2:]}') elif diff[i].startswith('- '): # Removed if i + 1 < len(diff) and diff[i + 1].startswith('+ '): # Changed word - show correction old_word = diff[i][2:] new_word = diff[i + 1][2:] html_parts.append(f'{old_word}→{new_word}') i += 1 else: # Removed word html_parts.append(f'{diff[i][2:]}') elif diff[i].startswith('+ '): # Added html_parts.append(f'{diff[i][2:]}') i += 1 return f'