Thanush1 committed
Commit aa84e52 · verified · 1 Parent(s): a7deefb

Update app.py

Files changed (1):
  1. app.py +162 -0
app.py CHANGED
@@ -0,0 +1,162 @@
+ import spaces
+ import gradio as gr
+ import torch
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
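+ # Assumed runtime dependencies for this Space (not pinned in this commit):
+ # gradio, torch, transformers, sentencepiece (the slow IndicBART tokenizer
+ # loaded with use_fast=False below relies on it), and the `spaces` package
+ # that provides the ZeroGPU decorator used further down.
+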
+ # Model configuration
+ model_name = "ai4bharat/IndicBART"
+
+ # Load tokenizer and model on CPU
+ print("Loading IndicBART tokenizer...")
+ tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False, use_fast=False, keep_accents=True)
+
+ print("Loading IndicBART model on CPU...")
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="cpu")
+
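+ # Note: the model stays on CPU in float16 and is moved to the GPU only inside
+ # the @spaces.GPU-decorated handler below, which is the usual ZeroGPU pattern.
+ # This assumes requests actually run on a CUDA device; without one, float16
+ # inference on CPU can be slow or unsupported for some ops, and loading in
+ # float32 would be the safer fallback.
+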
+ # Language mapping
+ LANGUAGE_CODES = {
+     "Assamese": "<2as>",
+     "Bengali": "<2bn>",
+     "English": "<2en>",
+     "Gujarati": "<2gu>",
+     "Hindi": "<2hi>",
+     "Kannada": "<2kn>",
+     "Malayalam": "<2ml>",
+     "Marathi": "<2mr>",
+     "Oriya": "<2or>",
+     "Punjabi": "<2pa>",
+     "Tamil": "<2ta>",
+     "Telugu": "<2te>"
+ }
+
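+ # Input convention used below, following the IndicBART model card: the source
+ # text is formatted as "<sentence> </s> <2xx>" with the source-language token
+ # at the end, and decoding starts from the target-language token "<2yy>" via
+ # decoder_start_token_id. The model card additionally tokenizes such strings
+ # with add_special_tokens=False; this app passes them to the tokenizer as
+ # plain text instead.
+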
+ @spaces.GPU(duration=60)
+ def generate_response(input_text, source_lang, target_lang, task_type, max_length):
+     """Generate response using IndicBART"""
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model_gpu = model.to(device)
+
+     # Get language codes
+     src_code = LANGUAGE_CODES[source_lang]
+     tgt_code = LANGUAGE_CODES[target_lang]
+
+     # Format input based on task type
+     if task_type == "Translation":
+         formatted_input = f"{input_text} </s> {src_code}"
+         decoder_start_token = tgt_code
+     elif task_type == "Text Completion":
+         # For completion, use target language
+         formatted_input = f"{input_text} </s> {tgt_code}"
+         decoder_start_token = tgt_code
+     else:  # Text Generation
+         formatted_input = f"{input_text} </s> {src_code}"
+         decoder_start_token = tgt_code
+
+     # Tokenize input
+     inputs = tokenizer(formatted_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
+     inputs = {k: v.to(device) for k, v in inputs.items()}
+
+     # Get decoder start token id
+     decoder_start_token_id = tokenizer._convert_token_to_id_with_added_voc(decoder_start_token)
+
+     # Generate
+     with torch.no_grad():
+         outputs = model_gpu.generate(
+             **inputs,
+             decoder_start_token_id=decoder_start_token_id,
+             max_length=max_length,
+             num_beams=4,
+             early_stopping=True,
+             pad_token_id=tokenizer.pad_token_id,
+             eos_token_id=tokenizer.eos_token_id,
+             use_cache=True
+         )
+
+     # Decode output
+     generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
+
+     # Move model back to CPU
+     model_gpu.cpu()
+     torch.cuda.empty_cache()
+
+     return generated_text
+
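+ # Illustrative call (mirrors the first example wired into the UI below):
+ # generate_response("Hello, how are you?", "English", "Hindi", "Translation", 100)
+ # builds the input "Hello, how are you? </s> <2en>" and starts decoding from
+ # "<2hi>".
+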
+ # Create Gradio interface
+ with gr.Blocks(title="IndicBART Multilingual Assistant", theme=gr.themes.Soft()) as demo:
+     gr.Markdown("""
+     # 🇮🇳 IndicBART Multilingual Assistant
+
+     IndicBART is a sequence-to-sequence model pre-trained on **11 Indian languages** and English. Use it for translation, text completion, and multilingual generation.
+
+     **Supported Languages**: Assamese, Bengali, Gujarati, Hindi, Kannada, Malayalam, Marathi, Oriya, Punjabi, Tamil, Telugu, English
+     """)
+
+     with gr.Row():
+         with gr.Column(scale=3):
+             input_text = gr.Textbox(
+                 label="Input Text",
+                 placeholder="Enter text in any supported language...",
+                 lines=3
+             )
+
+             output_text = gr.Textbox(
+                 label="Generated Output",
+                 lines=5,
+                 interactive=False
+             )
+
+             generate_btn = gr.Button("Generate", variant="primary", size="lg")
+
+         with gr.Column(scale=1):
+             task_type = gr.Dropdown(
+                 choices=["Translation", "Text Completion", "Text Generation"],
+                 value="Translation",
+                 label="Task Type"
+             )
+
+             source_lang = gr.Dropdown(
+                 choices=list(LANGUAGE_CODES.keys()),
+                 value="English",
+                 label="Source Language"
+             )
+
+             target_lang = gr.Dropdown(
+                 choices=list(LANGUAGE_CODES.keys()),
+                 value="Hindi",
+                 label="Target Language"
+             )
+
+             max_length = gr.Slider(
+                 minimum=50,
+                 maximum=300,
+                 value=100,
+                 step=10,
+                 label="Max Length"
+             )
+
+     # Examples
+     gr.Markdown("### 💡 Try these examples:")
+
+     examples = [
+         ["Hello, how are you?", "English", "Hindi", "Translation", 100],
+         ["मैं एक छात्र हूं", "Hindi", "English", "Translation", 100],  # "I am a student"
+         ["আমি ভাত খাই", "Bengali", "English", "Translation", 100],  # "I eat rice"
+         ["भारत एक", "Hindi", "Hindi", "Text Completion", 150],  # "India is a ..."
+         ["The capital of India", "English", "English", "Text Completion", 100]
+     ]
+
+     gr.Examples(
+         examples=examples,
+         inputs=[input_text, source_lang, target_lang, task_type, max_length],
+         outputs=output_text,
+         fn=generate_response
+     )
+
+     # Connect generate button
+     generate_btn.click(
+         generate_response,
+         inputs=[input_text, source_lang, target_lang, task_type, max_length],
+         outputs=output_text
+     )
+
+ if __name__ == "__main__":
+     demo.launch()