Thanush1 committed on
Commit
a31db9e
·
verified ·
1 Parent(s): 28f6524

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -30
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import spaces
2
  import gradio as gr
3
  import torch
4
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
@@ -11,7 +10,11 @@ print("Loading IndicBART tokenizer...")
11
  tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False, use_fast=False, keep_accents=True)
12
 
13
  print("Loading IndicBART model on CPU...")
14
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="cpu")
 
 
 
 
15
 
16
  # Language mapping
17
  LANGUAGE_CODES = {
@@ -29,11 +32,8 @@ LANGUAGE_CODES = {
29
  "Telugu": "<2te>"
30
  }
31
 
32
- @spaces.GPU(duration=60)
33
  def generate_response(input_text, source_lang, target_lang, task_type, max_length):
34
- """Generate response using IndicBART"""
35
- device = "cuda" if torch.cuda.is_available() else "cpu"
36
- model_gpu = model.to(device)
37
 
38
  # Get language codes
39
  src_code = LANGUAGE_CODES[source_lang]
@@ -51,43 +51,45 @@ def generate_response(input_text, source_lang, target_lang, task_type, max_lengt
51
  formatted_input = f"{input_text} </s> {src_code}"
52
  decoder_start_token = tgt_code
53
 
54
- # Tokenize input
55
  inputs = tokenizer(formatted_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
56
- inputs = {k: v.to(device) for k, v in inputs.items()}
57
 
58
  # Get decoder start token id
59
- decoder_start_token_id = tokenizer._convert_token_to_id_with_added_voc(decoder_start_token)
 
 
 
 
60
 
61
- # Generate
62
  with torch.no_grad():
63
- outputs = model_gpu.generate(
64
  **inputs,
65
  decoder_start_token_id=decoder_start_token_id,
66
  max_length=max_length,
67
- num_beams=4,
68
  early_stopping=True,
69
  pad_token_id=tokenizer.pad_token_id,
70
  eos_token_id=tokenizer.eos_token_id,
71
- use_cache=True
 
72
  )
73
 
74
  # Decode output
75
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
76
 
77
- # Move model back to CPU
78
- model_gpu.cpu()
79
- torch.cuda.empty_cache()
80
-
81
  return generated_text
82
 
83
  # Create Gradio interface
84
- with gr.Blocks(title="IndicBART Multilingual Assistant", theme=gr.themes.Soft()) as demo:
85
  gr.Markdown("""
86
- # 🇮🇳 IndicBART Multilingual Assistant
87
 
88
  Experience IndicBART - trained on **11 Indian languages**! Perfect for translation, text completion, and multilingual generation.
89
 
90
  **Supported Languages**: Assamese, Bengali, Gujarati, Hindi, Kannada, Malayalam, Marathi, Oriya, Punjabi, Tamil, Telugu, English
 
 
91
  """)
92
 
93
  with gr.Row():
@@ -104,7 +106,9 @@ with gr.Blocks(title="IndicBART Multilingual Assistant", theme=gr.themes.Soft())
104
  interactive=False
105
  )
106
 
107
- generate_btn = gr.Button("Generate", variant="primary", size="lg")
 
 
108
 
109
  with gr.Column(scale=1):
110
  task_type = gr.Dropdown(
@@ -126,9 +130,9 @@ with gr.Blocks(title="IndicBART Multilingual Assistant", theme=gr.themes.Soft())
126
  )
127
 
128
  max_length = gr.Slider(
129
- minimum=50,
130
- maximum=300,
131
- value=100,
132
  step=10,
133
  label="Max Length"
134
  )
@@ -137,11 +141,11 @@ with gr.Blocks(title="IndicBART Multilingual Assistant", theme=gr.themes.Soft())
137
  gr.Markdown("### 💡 Try these examples:")
138
 
139
  examples = [
140
- ["Hello, how are you?", "English", "Hindi", "Translation", 100],
141
- ["मैं एक छात्र हूं", "Hindi", "English", "Translation", 100],
142
- ["আমি ভাত খাই", "Bengali", "English", "Translation", 100],
143
- ["भारत एक", "Hindi", "Hindi", "Text Completion", 150],
144
- ["The capital of India", "English", "English", "Text Completion", 100]
145
  ]
146
 
147
  gr.Examples(
@@ -151,12 +155,21 @@ with gr.Blocks(title="IndicBART Multilingual Assistant", theme=gr.themes.Soft())
151
  fn=generate_response
152
  )
153
 
154
- # Connect generate button
 
 
 
 
155
  generate_btn.click(
156
  generate_response,
157
  inputs=[input_text, source_lang, target_lang, task_type, max_length],
158
  outputs=output_text
159
  )
 
 
 
 
 
160
 
161
  if __name__ == "__main__":
162
- demo.launch()
 
 
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
10
  tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False, use_fast=False, keep_accents=True)
11
 
12
  print("Loading IndicBART model on CPU...")
13
+ model = AutoModelForSeq2SeqLM.from_pretrained(
14
+ model_name,
15
+ torch_dtype=torch.float32, # Use float32 for better CPU performance
16
+ device_map="cpu"
17
+ )
18
 
19
  # Language mapping
20
  LANGUAGE_CODES = {
 
32
  "Telugu": "<2te>"
33
  }
34
 
 
35
  def generate_response(input_text, source_lang, target_lang, task_type, max_length):
36
+ """Generate response using IndicBART on CPU"""
 
 
37
 
38
  # Get language codes
39
  src_code = LANGUAGE_CODES[source_lang]
 
51
  formatted_input = f"{input_text} </s> {src_code}"
52
  decoder_start_token = tgt_code
53
 
54
+ # Tokenize input (keep on CPU)
55
  inputs = tokenizer(formatted_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
 
56
 
57
  # Get decoder start token id
58
+ try:
59
+ decoder_start_token_id = tokenizer._convert_token_to_id_with_added_voc(decoder_start_token)
60
+ except:
61
+ # Fallback if the method doesn't exist
62
+ decoder_start_token_id = tokenizer.convert_tokens_to_ids(decoder_start_token)
63
 
64
+ # Generate on CPU
65
  with torch.no_grad():
66
+ outputs = model.generate(
67
  **inputs,
68
  decoder_start_token_id=decoder_start_token_id,
69
  max_length=max_length,
70
+ num_beams=2, # Reduced for faster CPU inference
71
  early_stopping=True,
72
  pad_token_id=tokenizer.pad_token_id,
73
  eos_token_id=tokenizer.eos_token_id,
74
+ use_cache=True,
75
+ do_sample=False # Deterministic for CPU
76
  )
77
 
78
  # Decode output
79
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
80
 
 
 
 
 
81
  return generated_text
82
 
83
  # Create Gradio interface
84
+ with gr.Blocks(title="IndicBART CPU Multilingual Assistant", theme=gr.themes.Soft()) as demo:
85
  gr.Markdown("""
86
+ # 🇮🇳 IndicBART Multilingual Assistant (CPU Version)
87
 
88
  Experience IndicBART - trained on **11 Indian languages**! Perfect for translation, text completion, and multilingual generation.
89
 
90
  **Supported Languages**: Assamese, Bengali, Gujarati, Hindi, Kannada, Malayalam, Marathi, Oriya, Punjabi, Tamil, Telugu, English
91
+
92
+ *Note: Running on CPU - responses may take longer than GPU version.*
93
  """)
94
 
95
  with gr.Row():
 
106
  interactive=False
107
  )
108
 
109
+ with gr.Row():
110
+ generate_btn = gr.Button("Generate", variant="primary", size="lg")
111
+ clear_btn = gr.Button("Clear", variant="secondary")
112
 
113
  with gr.Column(scale=1):
114
  task_type = gr.Dropdown(
 
130
  )
131
 
132
  max_length = gr.Slider(
133
+ minimum=20,
134
+ maximum=200, # Reduced for faster CPU processing
135
+ value=80,
136
  step=10,
137
  label="Max Length"
138
  )
 
141
  gr.Markdown("### 💡 Try these examples:")
142
 
143
  examples = [
144
+ ["Hello, how are you?", "English", "Hindi", "Translation", 80],
145
+ ["मैं एक छात्र हूं", "Hindi", "English", "Translation", 80],
146
+ ["আমি ভাত খাই", "Bengali", "English", "Translation", 80],
147
+ ["भारत एक", "Hindi", "Hindi", "Text Completion", 100],
148
+ ["The capital of India", "English", "English", "Text Completion", 80]
149
  ]
150
 
151
  gr.Examples(
 
155
  fn=generate_response
156
  )
157
 
158
+ # Event handlers
159
+ def clear_fields():
160
+ return "", ""
161
+
162
+ # Connect buttons
163
  generate_btn.click(
164
  generate_response,
165
  inputs=[input_text, source_lang, target_lang, task_type, max_length],
166
  outputs=output_text
167
  )
168
+
169
+ clear_btn.click(
170
+ clear_fields,
171
+ outputs=[input_text, output_text]
172
+ )
173
 
174
  if __name__ == "__main__":
175
+ demo.launch(share=True) # Added share=True for easier access