24Sureshkumar commited on
Commit
4d86cc2
·
verified ·
1 Parent(s): 5ca4bcd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -76
app.py CHANGED
@@ -1,114 +1,74 @@
1
- import os
2
  import streamlit as st
3
- from dotenv import load_dotenv
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
5
  from diffusers import StableDiffusionPipeline
6
  import torch
7
 
8
-
9
- load_dotenv()
10
- # Set Streamlit page config
11
  st.set_page_config(
12
  page_title="Tamil Creative Studio",
13
  page_icon="🇮🇳",
14
  layout="centered",
15
- initial_sidebar_state="collapsed"
16
  )
17
 
18
- # Load custom CSS
19
- def load_css(file_name):
20
- with open(file_name, "r") as f:
21
- st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  @st.cache_resource(show_spinner=False)
24
  def load_all_models():
25
- HF_TOKEN = os.getenv("HF_TOKEN") # Hugging Face token from Hugging Face Spaces secret
26
-
27
- # Load translation model (private)
28
- trans_tokenizer = AutoTokenizer.from_pretrained(
29
- "ai4bharat/indictrans2-ta-en-dist-200M",
30
- token=HF_TOKEN
31
- )
32
- trans_model = AutoModelForSeq2SeqLM.from_pretrained(
33
- "ai4bharat/indictrans2-ta-en-dist-200M",
34
- token=HF_TOKEN
35
- )
36
-
37
- # Load text generation model
38
  text_gen = pipeline("text-generation", model="gpt2", device=-1)
39
-
40
- # Load image generation model
41
  img_pipe = StableDiffusionPipeline.from_pretrained(
42
  "stabilityai/stable-diffusion-2-base",
43
  torch_dtype=torch.float32,
44
  safety_checker=None
45
  ).to("cpu")
46
-
47
- return trans_tokenizer, trans_model, text_gen, img_pipe
48
-
49
 
50
  def translate_tamil(text, tokenizer, model):
51
- inputs = tokenizer(
52
- text,
53
- return_tensors="pt",
54
- padding=True,
55
- truncation=True,
56
- max_length=128
57
- )
58
-
59
- generated = model.generate(
60
- **inputs,
61
- max_length=150,
62
- num_beams=5,
63
- early_stopping=True
64
- )
65
-
66
- return tokenizer.batch_decode(
67
- generated,
68
- skip_special_tokens=True,
69
- clean_up_tokenization_spaces=True
70
- )[0]
71
 
72
  def main():
73
- load_css("style.css")
74
-
75
- # Header
76
- st.markdown("""
77
- <div class="header">
78
- <h1>🌐 தமிழ் → English → Creative Studio</h1>
79
- <p>Translate Tamil text and generate creative content</p>
80
- </div>
81
- """, unsafe_allow_html=True)
82
-
83
  tokenizer, model, text_gen, img_pipe = load_all_models()
84
-
85
  tamil_text = st.text_area("**தமிழ் உரை:**", height=150, placeholder="உங்கள் உரையை இங்கே உள்ளிடவும்...")
86
-
87
- if st.button("**உருவாக்கு**", type="primary", use_container_width=True):
88
  if not tamil_text.strip():
89
- st.warning("தயவு செய்து உரையை உள்ளிடவும்.")
90
- st.stop()
91
 
92
  with st.spinner("மொழிபெயர்க்கிறது..."):
93
  eng = translate_tamil(tamil_text, tokenizer, model)
94
-
95
- with st.expander("🔤 Translation", expanded=True):
96
- st.success(eng)
97
 
98
  with st.spinner("உரை உருவாக்குதல்..."):
99
- creative = text_gen(
100
- f"Create a creative description about: {eng}",
101
- max_length=80,
102
- num_return_sequences=1
103
- )[0]["generated_text"]
104
-
105
- st.info("📝 Creative Text:")
106
- st.write(creative)
107
 
108
  with st.spinner("படத்தை உருவாக்குதல்..."):
109
  img = img_pipe(eng, num_inference_steps=40, guidance_scale=8.5).images[0]
110
-
111
- st.image(img, caption="🎨 Generated Image", use_column_width=True)
112
 
113
  if __name__ == "__main__":
114
  main()
 
 
1
  import streamlit as st
 
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
3
  from diffusers import StableDiffusionPipeline
4
  import torch
5
 
 
 
 
6
  st.set_page_config(
7
  page_title="Tamil Creative Studio",
8
  page_icon="🇮🇳",
9
  layout="centered",
 
10
  )
11
 
12
+ def load_css():
13
+ st.markdown(
14
+ """<style>
15
+ .header {
16
+ text-align: center;
17
+ padding: 20px;
18
+ background: #f9f9f9;
19
+ border-radius: 10px;
20
+ margin-bottom: 20px;
21
+ }
22
+ .header h1 { color: #cc0000; }
23
+ .header p { color: #333; font-style: italic; }
24
+ </style>""",
25
+ unsafe_allow_html=True,
26
+ )
27
 
28
  @st.cache_resource(show_spinner=False)
29
  def load_all_models():
30
+ model_id = "ai4bharat/indictrans2-indic-en-dist-200M"
31
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
32
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
 
 
 
 
 
 
 
 
 
 
33
  text_gen = pipeline("text-generation", model="gpt2", device=-1)
 
 
34
  img_pipe = StableDiffusionPipeline.from_pretrained(
35
  "stabilityai/stable-diffusion-2-base",
36
  torch_dtype=torch.float32,
37
  safety_checker=None
38
  ).to("cpu")
39
+ return tokenizer, model, text_gen, img_pipe
 
 
40
 
41
  def translate_tamil(text, tokenizer, model):
42
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
43
+ outs = model.generate(**inputs, max_length=150, num_beams=5, early_stopping=True)
44
+ return tokenizer.decode(outs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  def main():
47
+ load_css()
48
+ st.markdown(
49
+ '<div class="header"><h1>🌐 தமிழ் → English → Creative Studio</h1>'
50
+ '<p>Translate Tamil text and generate creative content</p></div>',
51
+ unsafe_allow_html=True
52
+ )
 
 
 
 
53
  tokenizer, model, text_gen, img_pipe = load_all_models()
 
54
  tamil_text = st.text_area("**தமிழ் உரை:**", height=150, placeholder="உங்கள் உரையை இங்கே உள்ளிடவும்...")
55
+
56
+ if st.button("உருவாக்கு"):
57
  if not tamil_text.strip():
58
+ st.warning("உரையை உள்ளிடவும்.")
59
+ return
60
 
61
  with st.spinner("மொழிபெயர்க்கிறது..."):
62
  eng = translate_tamil(tamil_text, tokenizer, model)
63
+ st.success(eng)
 
 
64
 
65
  with st.spinner("உரை உருவாக்குதல்..."):
66
+ creative = text_gen(f"Create a creative description about: {eng}", max_length=80, num_return_sequences=1)[0]["generated_text"]
67
+ st.info(creative)
 
 
 
 
 
 
68
 
69
  with st.spinner("படத்தை உருவாக்குதல்..."):
70
  img = img_pipe(eng, num_inference_steps=40, guidance_scale=8.5).images[0]
71
+ st.image(img, caption="Generated Image", use_column_width=True)
 
72
 
73
  if __name__ == "__main__":
74
  main()