Thanush1 commited on
Commit
4320235
·
verified ·
1 Parent(s): 18ec4a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -63
app.py CHANGED
@@ -12,7 +12,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False, use_f
12
  print("Loading IndicBART model on CPU...")
13
  model = AutoModelForSeq2SeqLM.from_pretrained(
14
  model_name,
15
- torch_dtype=torch.float32, # Use float32 for better CPU performance
16
  device_map="cpu"
17
  )
18
 
@@ -35,50 +35,65 @@ LANGUAGE_CODES = {
35
  def generate_response(input_text, source_lang, target_lang, task_type, max_length):
36
  """Generate response using IndicBART on CPU"""
37
 
38
- # Get language codes
39
- src_code = LANGUAGE_CODES[source_lang]
40
- tgt_code = LANGUAGE_CODES[target_lang]
41
-
42
- # Format input based on task type
43
- if task_type == "Translation":
44
- formatted_input = f"{input_text} </s> {src_code}"
45
- decoder_start_token = tgt_code
46
- elif task_type == "Text Completion":
47
- # For completion, use target language
48
- formatted_input = f"{input_text} </s> {tgt_code}"
49
- decoder_start_token = tgt_code
50
- else: # Text Generation
51
- formatted_input = f"{input_text} </s> {src_code}"
52
- decoder_start_token = tgt_code
53
-
54
- # Tokenize input (keep on CPU)
55
- inputs = tokenizer(formatted_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
56
-
57
- # Get decoder start token id
58
  try:
59
- decoder_start_token_id = tokenizer._convert_token_to_id_with_added_voc(decoder_start_token)
60
- except:
61
- # Fallback if the method doesn't exist
62
- decoder_start_token_id = tokenizer.convert_tokens_to_ids(decoder_start_token)
63
-
64
- # Generate on CPU
65
- with torch.no_grad():
66
- outputs = model.generate(
67
- **inputs,
68
- decoder_start_token_id=decoder_start_token_id,
69
- max_length=max_length,
70
- num_beams=2, # Reduced for faster CPU inference
71
- early_stopping=True,
72
- pad_token_id=tokenizer.pad_token_id,
73
- eos_token_id=tokenizer.eos_token_id,
74
- use_cache=True,
75
- do_sample=False # Deterministic for CPU
 
 
 
 
 
 
76
  )
77
-
78
- # Decode output
79
- generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
80
-
81
- return generated_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  # Create Gradio interface
84
  with gr.Blocks(title="IndicBART CPU Multilingual Assistant", theme=gr.themes.Soft()) as demo:
@@ -88,8 +103,6 @@ with gr.Blocks(title="IndicBART CPU Multilingual Assistant", theme=gr.themes.Sof
88
  Experience IndicBART - trained on **11 Indian languages**! Perfect for translation, text completion, and multilingual generation.
89
 
90
  **Supported Languages**: Assamese, Bengali, Gujarati, Hindi, Kannada, Malayalam, Marathi, Oriya, Punjabi, Tamil, Telugu, English
91
-
92
- *Note: Running on CPU - responses may take longer than GPU version.*
93
  """)
94
 
95
  with gr.Row():
@@ -131,34 +144,39 @@ with gr.Blocks(title="IndicBART CPU Multilingual Assistant", theme=gr.themes.Sof
131
 
132
  max_length = gr.Slider(
133
  minimum=20,
134
- maximum=200, # Reduced for faster CPU processing
135
  value=80,
136
  step=10,
137
  label="Max Length"
138
  )
139
 
140
- # Examples
141
  gr.Markdown("### 💡 Try these examples:")
142
 
143
- examples = [
144
- ["Hello, how are you?", "English", "Hindi", "Translation", 80],
145
- ["मैं एक छात्र हूं", "Hindi", "English", "Translation", 80],
146
- ["আমি ভাত খাই", "Bengali", "English", "Translation", 80],
147
- ["भारत एक", "Hindi", "Hindi", "Text Completion", 100],
148
- ["The capital of India", "English", "English", "Text Completion", 80]
149
- ]
150
-
151
- gr.Examples(
152
- examples=examples,
153
- inputs=[input_text, source_lang, target_lang, task_type, max_length],
154
- outputs=output_text,
155
- fn=generate_response
156
- )
157
 
158
  # Event handlers
159
  def clear_fields():
160
  return "", ""
161
 
 
 
 
 
 
 
 
 
 
162
  # Connect buttons
163
  generate_btn.click(
164
  generate_response,
@@ -170,9 +188,28 @@ with gr.Blocks(title="IndicBART CPU Multilingual Assistant", theme=gr.themes.Sof
170
  clear_fields,
171
  outputs=[input_text, output_text]
172
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  if __name__ == "__main__":
174
  demo.launch(
175
  share=True,
176
- server_port=7860,
177
- show_error=True
 
 
178
  )
 
12
  print("Loading IndicBART model on CPU...")
13
  model = AutoModelForSeq2SeqLM.from_pretrained(
14
  model_name,
15
+ torch_dtype=torch.float32,
16
  device_map="cpu"
17
  )
18
 
 
35
  def generate_response(input_text, source_lang, target_lang, task_type, max_length):
36
  """Generate response using IndicBART on CPU"""
37
 
38
+ if not input_text.strip():
39
+ return "Please enter some text to process."
40
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  try:
42
+ # Get language codes
43
+ src_code = LANGUAGE_CODES[source_lang]
44
+ tgt_code = LANGUAGE_CODES[target_lang]
45
+
46
+ # Format input based on task type
47
+ if task_type == "Translation":
48
+ formatted_input = f"{input_text} </s> {src_code}"
49
+ decoder_start_token = tgt_code
50
+ elif task_type == "Text Completion":
51
+ formatted_input = f"{input_text} </s> {tgt_code}"
52
+ decoder_start_token = tgt_code
53
+ else: # Text Generation
54
+ formatted_input = f"{input_text} </s> {src_code}"
55
+ decoder_start_token = tgt_code
56
+
57
+ # Tokenize input - KEY FIX: Explicitly set return_token_type_ids=False
58
+ inputs = tokenizer(
59
+ formatted_input,
60
+ return_tensors="pt",
61
+ padding=True,
62
+ truncation=True,
63
+ max_length=512,
64
+ return_token_type_ids=False # This prevents the error
65
  )
66
+
67
+ # Alternative fix: Remove token_type_ids if present
68
+ if 'token_type_ids' in inputs:
69
+ del inputs['token_type_ids']
70
+
71
+ # Get decoder start token id
72
+ try:
73
+ decoder_start_token_id = tokenizer._convert_token_to_id_with_added_voc(decoder_start_token)
74
+ except:
75
+ decoder_start_token_id = tokenizer.convert_tokens_to_ids(decoder_start_token)
76
+
77
+ # Generate on CPU
78
+ with torch.no_grad():
79
+ outputs = model.generate(
80
+ **inputs,
81
+ decoder_start_token_id=decoder_start_token_id,
82
+ max_length=max_length,
83
+ num_beams=2,
84
+ early_stopping=True,
85
+ pad_token_id=tokenizer.pad_token_id,
86
+ eos_token_id=tokenizer.eos_token_id,
87
+ use_cache=True,
88
+ do_sample=False
89
+ )
90
+
91
+ # Decode output
92
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
93
+ return generated_text
94
+
95
+ except Exception as e:
96
+ return f"Error generating response: {str(e)}"
97
 
98
  # Create Gradio interface
99
  with gr.Blocks(title="IndicBART CPU Multilingual Assistant", theme=gr.themes.Soft()) as demo:
 
103
  Experience IndicBART - trained on **11 Indian languages**! Perfect for translation, text completion, and multilingual generation.
104
 
105
  **Supported Languages**: Assamese, Bengali, Gujarati, Hindi, Kannada, Malayalam, Marathi, Oriya, Punjabi, Tamil, Telugu, English
 
 
106
  """)
107
 
108
  with gr.Row():
 
144
 
145
  max_length = gr.Slider(
146
  minimum=20,
147
+ maximum=200,
148
  value=80,
149
  step=10,
150
  label="Max Length"
151
  )
152
 
153
+ # Simplified examples to avoid caching issues
154
  gr.Markdown("### 💡 Try these examples:")
155
 
156
+ with gr.Row():
157
+ with gr.Column():
158
+ gr.Markdown("**English to Hindi**")
159
+ example1_btn = gr.Button("Hello, how are you?")
160
+ with gr.Column():
161
+ gr.Markdown("**Hindi to English**")
162
+ example2_btn = gr.Button("मैं एक छात्र हूं")
163
+ with gr.Column():
164
+ gr.Markdown("**Bengali to English**")
165
+ example3_btn = gr.Button("আমি ভাত খাই")
 
 
 
 
166
 
167
  # Event handlers
168
  def clear_fields():
169
  return "", ""
170
 
171
+ def set_example1():
172
+ return "Hello, how are you?", "English", "Hindi", "Translation"
173
+
174
+ def set_example2():
175
+ return "मैं एक छात्र हूं", "Hindi", "English", "Translation"
176
+
177
+ def set_example3():
178
+ return "আমি ভাত খাই", "Bengali", "English", "Translation"
179
+
180
  # Connect buttons
181
  generate_btn.click(
182
  generate_response,
 
188
  clear_fields,
189
  outputs=[input_text, output_text]
190
  )
191
+
192
+ example1_btn.click(
193
+ set_example1,
194
+ outputs=[input_text, source_lang, target_lang, task_type]
195
+ )
196
+
197
+ example2_btn.click(
198
+ set_example2,
199
+ outputs=[input_text, source_lang, target_lang, task_type]
200
+ )
201
+
202
+ example3_btn.click(
203
+ set_example3,
204
+ outputs=[input_text, source_lang, target_lang, task_type]
205
+ )
206
+
207
+ # Launch with all fixes applied
208
  if __name__ == "__main__":
209
  demo.launch(
210
  share=True,
211
+ ssr_mode=False, # Disable SSR
212
+ cache_examples=False, # Disable example caching - KEY FIX
213
+ show_error=True,
214
+ enable_queue=False # Disable queue to avoid startup issues
215
  )