24Sureshkumar commited on
Commit
1437058
·
verified ·
1 Parent(s): 514d8bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -66
app.py CHANGED
@@ -1,13 +1,15 @@
1
- import streamlit as st
2
  import os
 
3
  from dotenv import load_dotenv
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
5
  from diffusers import StableDiffusionPipeline
6
  import torch
7
 
 
 
8
  HF_TOKEN = os.getenv("HF_TOKEN")
9
 
10
- # Set page config
11
  st.set_page_config(
12
  page_title="Tamil Creative Studio",
13
  page_icon="🇮🇳",
@@ -22,30 +24,27 @@ def load_css(file_name):
22
 
23
  @st.cache_resource(show_spinner=False)
24
  def load_all_models():
25
- # Load translation model
26
  trans_tokenizer = AutoTokenizer.from_pretrained(
27
- "ai4bharat/indictrans2-ta-en-dist-200M", token=HF_TOKEN
 
28
  )
29
  trans_model = AutoModelForSeq2SeqLM.from_pretrained(
30
- "ai4bharat/indictrans2-ta-en-dist-200M", token=HF_TOKEN
 
31
  )
32
-
33
  # Load text generation model
34
- text_gen = pipeline(
35
- "text-generation",
36
- model="gpt2",
37
- device=-1,
38
- token=HF_TOKEN
39
- )
40
-
41
  # Load image generation model
42
  img_pipe = StableDiffusionPipeline.from_pretrained(
43
  "stabilityai/stable-diffusion-2-base",
 
44
  torch_dtype=torch.float32,
45
- safety_checker=None,
46
- token=HF_TOKEN
47
  ).to("cpu")
48
-
49
  return trans_tokenizer, trans_model, text_gen, img_pipe
50
 
51
  def translate_tamil(text, tokenizer, model):
@@ -56,74 +55,60 @@ def translate_tamil(text, tokenizer, model):
56
  truncation=True,
57
  max_length=128
58
  )
59
-
60
  generated = model.generate(
61
  **inputs,
62
  max_length=150,
63
  num_beams=5,
64
  early_stopping=True
65
  )
66
-
67
  return tokenizer.batch_decode(
68
- generated,
69
  skip_special_tokens=True,
70
  clean_up_tokenization_spaces=True
71
  )[0]
72
 
73
  def main():
74
  load_css("style.css")
75
-
76
- st.markdown(
77
- """
78
  <div class="header">
79
  <h1>🌐 தமிழ் → English → Creative Studio</h1>
80
  <p>Translate Tamil text and generate creative content</p>
81
  </div>
82
- """,
83
- unsafe_allow_html=True
84
- )
85
-
86
  tokenizer, model, text_gen, img_pipe = load_all_models()
87
-
88
- with st.container():
89
- tamil_text = st.text_area(
90
- "**தமிழ் உரை:**",
91
- height=150,
92
- placeholder="உங்கள் உரையை இங்கே உள்ளிடவும்...",
93
- key="tamil_input"
94
- )
95
-
96
- col1, col2 = st.columns([1, 3])
97
- with col1:
98
- if st.button("**உருவாக்கு**", type="primary", use_container_width=True):
99
- if not tamil_text.strip():
100
- st.warning("தயவு செய்து உரையை உள்ளிடவும்.")
101
- st.stop()
102
-
103
- with st.spinner("மொழிபெயர்க்கிறது..."):
104
- eng = translate_tamil(tamil_text, tokenizer, model)
105
-
106
- with st.expander("**🔤 Translation**", expanded=True):
107
- st.success(eng)
108
-
109
- with st.spinner("உரை உருவாக்குதல்..."):
110
- creative = text_gen(
111
- f"Create a creative description about: {eng}",
112
- max_length=80,
113
- num_return_sequences=1
114
- )[0]["generated_text"]
115
-
116
- st.info("**📝 Creative Text:**")
117
- st.write(creative)
118
-
119
- with st.spinner("படத்தை உருவாக்குதல்..."):
120
- img = img_pipe(
121
- eng,
122
- num_inference_steps=40,
123
- guidance_scale=8.5
124
- ).images[0]
125
-
126
- st.image(img, caption="**🎨 Generated Image**", use_column_width=True)
127
 
128
  if __name__ == "__main__":
129
  main()
 
 
1
  import os
2
+ import streamlit as st
3
  from dotenv import load_dotenv
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
5
  from diffusers import StableDiffusionPipeline
6
  import torch
7
 
8
+ # Load environment variables
9
+ load_dotenv()
10
  HF_TOKEN = os.getenv("HF_TOKEN")
11
 
12
+ # Set Streamlit page config
13
  st.set_page_config(
14
  page_title="Tamil Creative Studio",
15
  page_icon="🇮🇳",
 
24
 
25
  @st.cache_resource(show_spinner=False)
26
  def load_all_models():
27
+ # Load translation model (private)
28
  trans_tokenizer = AutoTokenizer.from_pretrained(
29
+ "ai4bharat/indictrans2-ta-en-dist-200M",
30
+ use_auth_token=HF_TOKEN
31
  )
32
  trans_model = AutoModelForSeq2SeqLM.from_pretrained(
33
+ "ai4bharat/indictrans2-ta-en-dist-200M",
34
+ use_auth_token=HF_TOKEN
35
  )
36
+
37
  # Load text generation model
38
+ text_gen = pipeline("text-generation", model="gpt2", device=-1)
39
+
 
 
 
 
 
40
  # Load image generation model
41
  img_pipe = StableDiffusionPipeline.from_pretrained(
42
  "stabilityai/stable-diffusion-2-base",
43
+ use_auth_token=HF_TOKEN,
44
  torch_dtype=torch.float32,
45
+ safety_checker=None
 
46
  ).to("cpu")
47
+
48
  return trans_tokenizer, trans_model, text_gen, img_pipe
49
 
50
  def translate_tamil(text, tokenizer, model):
 
55
  truncation=True,
56
  max_length=128
57
  )
58
+
59
  generated = model.generate(
60
  **inputs,
61
  max_length=150,
62
  num_beams=5,
63
  early_stopping=True
64
  )
65
+
66
  return tokenizer.batch_decode(
67
+ generated,
68
  skip_special_tokens=True,
69
  clean_up_tokenization_spaces=True
70
  )[0]
71
 
72
  def main():
73
  load_css("style.css")
74
+
75
+ # Header
76
+ st.markdown("""
77
  <div class="header">
78
  <h1>🌐 தமிழ் → English → Creative Studio</h1>
79
  <p>Translate Tamil text and generate creative content</p>
80
  </div>
81
+ """, unsafe_allow_html=True)
82
+
 
 
83
  tokenizer, model, text_gen, img_pipe = load_all_models()
84
+
85
+ tamil_text = st.text_area("**தமிழ் உரை:**", height=150, placeholder="உங்கள் உரையை இங்கே உள்ளிடவும்...")
86
+
87
+ if st.button("**உருவாக்கு**", type="primary", use_container_width=True):
88
+ if not tamil_text.strip():
89
+ st.warning("தயவு செய்து உரையை உள்ளிடவும்.")
90
+ st.stop()
91
+
92
+ with st.spinner("மொழிபெயர்க்கிறது..."):
93
+ eng = translate_tamil(tamil_text, tokenizer, model)
94
+
95
+ with st.expander("🔤 Translation", expanded=True):
96
+ st.success(eng)
97
+
98
+ with st.spinner("உரை உருவ��க்குதல்..."):
99
+ creative = text_gen(
100
+ f"Create a creative description about: {eng}",
101
+ max_length=80,
102
+ num_return_sequences=1
103
+ )[0]["generated_text"]
104
+
105
+ st.info("📝 Creative Text:")
106
+ st.write(creative)
107
+
108
+ with st.spinner("படத்தை உருவாக்குதல்..."):
109
+ img = img_pipe(eng, num_inference_steps=40, guidance_scale=8.5).images[0]
110
+
111
+ st.image(img, caption="🎨 Generated Image", use_column_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  if __name__ == "__main__":
114
  main()