alakxender committed on
Commit e1f2ca6 · 1 Parent(s): e84ce77
Files changed (2)
  1. app.py +48 -0
  2. en_dv_latin.py +63 -0
app.py CHANGED
@@ -4,7 +4,12 @@ from typo_check import css, process_input,MODEL_OPTIONS_TYPO
 from title_gen import generate_title, MODEL_OPTIONS_TITLE
 from content_gen import generate_content, MODEL_OPTIONS_CONTENT, get_default_prompt
 from instruct_dv import generate_response, MODEL_OPTIONS_INSTRUCT
+from en_dv_latin import translate, MODEL_OPTIONS_TRANSLATE
 
+def update_textbox_direction(direction):
+    # Enable RTL only if the source language is Dhivehi (dv2*)
+    is_rtl = direction.startswith("dv2")
+    return gr.Textbox(rtl=is_rtl)
 
 # Create Gradio interface using the latest syntax
 with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
@@ -237,6 +242,49 @@ All outputs generated are synthetic, created using fine-tuned models for experim
 - The model is experimental and may not always follow instructions perfectly.
 """)
 
+    with gr.Tab("Translation Tasks"):
+        gr.Markdown("# <center>Dhivehi Translation</center>")
+        gr.Markdown("Select a translation direction and enter text to translate between Dhivehi, English and Latin script.")
+        with gr.Row():
+            instruction = gr.Dropdown(
+                choices=["en2dv:", "dv2en:", "dv2latin:", "latin2dv:"],
+                label="Translation Direction",
+                value="dv2latin:"
+            )
+        with gr.Row():
+            input_text = gr.Textbox(lines=2, label="Text to Translate", rtl=True, elem_classes="textbox1")
+        with gr.Row():
+            model_choice = gr.Dropdown(choices=list(MODEL_OPTIONS_TRANSLATE.keys()), value=list(MODEL_OPTIONS_TRANSLATE.keys())[0], label="Model")
+        with gr.Row():
+            generated_response = gr.Textbox(label="Translated Text", rtl=True, elem_classes="textbox1")
+        with gr.Row():
+            max_tokens_slider = gr.Slider(10, 128, value=128, label="Max New Tokens")
+            num_beams_slider = gr.Slider(1, 10, value=4, step=1, label="Beam Size (num_beams)")
+        with gr.Row():
+            rep_penalty_slider = gr.Slider(1.0, 1.9, value=1.2, step=0.1, label="Repetition Penalty")
+            ngram_slider = gr.Slider(0, 10, value=3, step=1, label="No Repeat Ngram Size")
+        generate_btn = gr.Button("Translate")
+
+        generate_btn.click(
+            fn=translate,
+            inputs=[instruction, input_text, model_choice, max_tokens_slider, num_beams_slider, rep_penalty_slider, ngram_slider],
+            outputs=generated_response
+        )
+        gr.Examples(
+            examples=[
+                ["dv2en:", "ދުނިޔޭގެ އެކި ކަންކޮޅުތަކުން 1.4 މިލިއަން މީހުން މައްކާއަށް ޖަމާވެފައި"],
+                ["en2dv:", "Concerns over prepayment of GST raised in parliament"],
+                ["dv2latin:", "ވައިބާރުވުމުން ކުޅުދުއްފުށީ އެއާޕޯޓަށް ނުޖެއްސިގެން މޯލްޑިވިއަންގެ ބޯޓެއް އެނބުރި މާލެއަށް"],
+                ["latin2dv:", "Paakisthaanuge skoolu bahakah dhin hamalaaehgai thin kuhjakaai bodu dhe meehaku maruvehje"],
+            ],
+            inputs=[instruction, input_text],
+        )
+        gr.Markdown("""\
+**Notes:**
+- Supports translation between Dhivehi, English and Latin script
+- Model trained on news articles and common phrases
+- Translation quality may vary based on the domain of the text
+""")
 # Launch the app
 if __name__ == "__main__":
     #demo.launch(server_name="0.0.0.0", server_port=7811)
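
Note that this hunk defines `update_textbox_direction` but never attaches it to the `instruction` dropdown, so both textboxes stay RTL whichever direction is selected. A minimal sketch of the missing wiring, assuming Gradio's standard `change` event and the component names above (this call is not part of the commit):

instruction.change(
    fn=update_textbox_direction,
    inputs=instruction,
    outputs=input_text,
)

The output box would need the inverse rule (e.g. `dv2en:` yields LTR English), so a second helper, or a variant returning updates for both textboxes, would be needed there.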
en_dv_latin.py ADDED
@@ -0,0 +1,63 @@
+import random
+import numpy as np
+import torch
+from transformers import T5Tokenizer, T5ForConditionalGeneration
+import spaces
+
+
+# Available models
+MODEL_OPTIONS_TRANSLATE = {
+    "T1DV Model": "alakxender/flan-t5-base-dhivehi-en-latin",
+}
+
+# Cache for loaded models/tokenizers
+MODEL_CACHE = {}
+
+def get_model_and_tokenizer(model_dir):
+    if model_dir not in MODEL_CACHE:
+        print(f"Loading model: {model_dir}")
+        tokenizer = T5Tokenizer.from_pretrained(model_dir)
+        model = T5ForConditionalGeneration.from_pretrained(model_dir)
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        print(f"Moving model to device: {device}")
+        model.to(device)
+        MODEL_CACHE[model_dir] = (tokenizer, model)
+    return MODEL_CACHE[model_dir]
+
+max_input_length = 128
+max_output_length = 128
+
+@spaces.GPU()
+def translate(instruction, input_text, model_choice, max_new_tokens=128, num_beams=4, repetition_penalty=1.2, no_repeat_ngram_size=3):
+    model_dir = MODEL_OPTIONS_TRANSLATE[model_choice]
+    tokenizer, model = get_model_and_tokenizer(model_dir)
+
+    combined_input = f"{instruction.strip()} {input_text.strip()}" if input_text else instruction.strip()  # prepend direction tag
+    inputs = tokenizer(
+        combined_input,
+        return_tensors="pt",
+        truncation=True,
+        max_length=max_input_length
+    )
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    inputs = {k: v.to(device) for k, v in inputs.items()}
+
+    gen_kwargs = {
+        **inputs,
+        "max_length": max_new_tokens,  # caps total decoder output length
+        "min_length": 10,
+        "num_beams": num_beams,
+        "early_stopping": True,
+        "no_repeat_ngram_size": no_repeat_ngram_size,
+        "repetition_penalty": repetition_penalty,
+        "do_sample": False,
+        "pad_token_id": tokenizer.pad_token_id,
+        "eos_token_id": tokenizer.eos_token_id
+    }
+
+
+    with torch.no_grad():
+        outputs = model.generate(**gen_kwargs)
+    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    return decoded_output
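
For reference, a quick local smoke test of the new module, assuming the checkpoint downloads successfully and that the `spaces` GPU decorator is a no-op outside a Space (the direction tag is prepended to the text exactly as the UI does):

# Hypothetical usage, not part of this commit
from en_dv_latin import translate

out = translate(
    "en2dv:",      # direction tag, matching the dropdown choices
    "Concerns over prepayment of GST raised in parliament",
    "T1DV Model",  # key into MODEL_OPTIONS_TRANSLATE
    max_new_tokens=128,
    num_beams=4,
)
print(out)

Because of MODEL_CACHE, only the first call pays the download/load cost; subsequent calls reuse the in-memory tokenizer and model.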