Commit · e1f2ca6
1 Parent(s): e84ce77
- app.py +48 -0
- en_dv_latin.py +63 -0
app.py
CHANGED
@@ -4,7 +4,12 @@ from typo_check import css, process_input,MODEL_OPTIONS_TYPO
 from title_gen import generate_title, MODEL_OPTIONS_TITLE
 from content_gen import generate_content, MODEL_OPTIONS_CONTENT, get_default_prompt
 from instruct_dv import generate_response, MODEL_OPTIONS_INSTRUCT
+from en_dv_latin import translate, MODEL_OPTIONS_TRANSLATE
 
+def update_textbox_direction(direction):
+    # Enable RTL only if the source language is Dhivehi (dv2*)
+    is_rtl = direction.startswith("dv2")
+    return gr.Textbox(rtl=is_rtl)
 
 # Create Gradio interface using the latest syntax
 with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
@@ -237,6 +242,49 @@ All outputs generated are synthetic, created using fine-tuned models for experim
 - The model is experimental and may not always follow instructions perfectly.
 """)
 
+    with gr.Tab("Translation Tasks"):
+        gr.Markdown("# <center>Dhivehi Translation</center>")
+        gr.Markdown("Select a translation direction and enter text to translate between Dhivehi, English and Latin script.")
+        with gr.Row():
+            instruction = gr.Dropdown(
+                choices=["en2dv:", "dv2en:", "dv2latin:", "latin2dv:"],
+                label="Translation Direction",
+                value="dv2latin:"
+            )
+        with gr.Row():
+            input_text = gr.Textbox(lines=2, label="Text to Translate", rtl=True, elem_classes="textbox1")
+        with gr.Row():
+            model_choice = gr.Dropdown(choices=list(MODEL_OPTIONS_TRANSLATE.keys()), value=list(MODEL_OPTIONS_TRANSLATE.keys())[0], label="Model")
+        with gr.Row():
+            generated_response = gr.Textbox(label="Translated Text", rtl=True, elem_classes="textbox1")
+        with gr.Row():
+            max_tokens_slider = gr.Slider(10, 128, value=128, label="Max New Tokens")
+            num_beams_slider = gr.Slider(1, 10, value=4, step=1, label="Beam Size (num_beams)")
+        with gr.Row():
+            rep_penalty_slider = gr.Slider(1.0, 1.9, value=1.2, step=0.1, label="Repetition Penalty")
+            ngram_slider = gr.Slider(0, 10, value=3, step=1, label="No Repeat Ngram Size")
+        generate_btn = gr.Button("Translate")
+
+        generate_btn.click(
+            fn=translate,
+            inputs=[instruction, input_text, model_choice, max_tokens_slider, num_beams_slider, rep_penalty_slider, ngram_slider],
+            outputs=generated_response
+        )
+        gr.Examples(
+            examples=[
+                ["dv2en:", "ދުނިޔޭގެ އެކި ކަންކޮޅުތަކުން 1.4 މިލިއަން މީހުން މައްކާއަށް ޖަމާވެފައި"],
+                ["en2dv:", "Concerns over prepayment of GST raised in parliament"],
+                ["dv2latin:", "ވައިބާރުވުމުން ކުޅުދުއްފުށީ އެއާޕޯޓަށް ނުޖެއްސިގެން މޯލްޑިވިއަންގެ ބޯޓެއް އެނބުރި މާލެއަށް"],
+                ["latin2dv:", "Paakisthaanuge skoolu bahakah dhin hamalaaehgai thin kuhjakaai bodu dhe meehaku maruvehje"],
+            ],
+            inputs=[instruction, input_text],
+        )
+        gr.Markdown("""\
+**Notes:**
+- Supports translation between Dhivehi, English and Latin script
+- Model trained on news articles and common phrases
+- Translation quality may vary based on the domain of the text
+""")
 # Launch the app
 if __name__ == "__main__":
     #demo.launch(server_name="0.0.0.0", server_port=7811)
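Note: the two hunks above account for all 48 lines added to app.py, and while they define update_textbox_direction, none of them attach it to an event, so the helper is not yet wired up in this commit. A minimal sketch of how it could be hooked to the direction dropdown (the .change wiring below is an assumption for illustration, not part of the commit):

    # Hypothetical wiring (not in this commit): flip the input textbox
    # to RTL whenever a dv2* direction is selected in the dropdown.
    instruction.change(
        fn=update_textbox_direction,
        inputs=instruction,
        outputs=input_text
    )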
en_dv_latin.py
ADDED
@@ -0,0 +1,63 @@
+import random
+import numpy as np
+import torch
+from transformers import T5Tokenizer, T5ForConditionalGeneration
+import spaces
+
+
+# Available models
+MODEL_OPTIONS_TRANSLATE = {
+    "T1DV Model": "alakxender/flan-t5-base-dhivehi-en-latin",
+}
+
+# Cache for loaded models/tokenizers
+MODEL_CACHE = {}
+
+def get_model_and_tokenizer(model_dir):
+    if model_dir not in MODEL_CACHE:
+        print(f"Loading model: {model_dir}")
+        tokenizer = T5Tokenizer.from_pretrained(model_dir)
+        model = T5ForConditionalGeneration.from_pretrained(model_dir)
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        print(f"Moving model to device: {device}")
+        model.to(device)
+        MODEL_CACHE[model_dir] = (tokenizer, model)
+    return MODEL_CACHE[model_dir]
+
+max_input_length = 128
+max_output_length = 128
+
+@spaces.GPU()
+def translate(instruction, input_text, model_choice, max_new_tokens=128, num_beams=4, repetition_penalty=1.2, no_repeat_ngram_size=3):
+    model_dir = MODEL_OPTIONS_TRANSLATE[model_choice]
+    tokenizer, model = get_model_and_tokenizer(model_dir)
+
+    combined_input = f"{instruction.strip()} {input_text.strip()}" if input_text else instruction.strip()
+    inputs = tokenizer(
+        combined_input,
+        return_tensors="pt",
+        truncation=True,
+        max_length=max_input_length
+    )
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    inputs = {k: v.to(device) for k, v in inputs.items()}
+
+    gen_kwargs = {
+        **inputs,
+        "max_length": max_new_tokens,
+        "min_length": 10,
+        "num_beams": num_beams,
+        "early_stopping": True,
+        "no_repeat_ngram_size": no_repeat_ngram_size,
+        "repetition_penalty": repetition_penalty,
+        "do_sample": False,
+        "pad_token_id": tokenizer.pad_token_id,
+        "eos_token_id": tokenizer.eos_token_id
+    }
+
+
+    with torch.no_grad():
+        outputs = model.generate(**gen_kwargs)
+    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    return decoded_output