24Sureshkumar commited on
Commit
67241c5
·
verified ·
1 Parent(s): b231e51

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -24
app.py CHANGED
@@ -1,18 +1,16 @@
1
  import streamlit as st
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
3
  from diffusers import StableDiffusionPipeline
4
  import torch
5
 
6
- # Load models only once
7
  @st.cache_resource
8
  def load_all_models():
9
  # Load translation model
10
- trans_model_id = "ai4bharat/indictrans2-indic-en-dist-200M"
11
- tokenizer = AutoTokenizer.from_pretrained(trans_model_id, trust_remote_code=True)
12
- model = AutoModelForSeq2SeqLM.from_pretrained(trans_model_id, trust_remote_code=True)
13
- translation_pipeline = pipeline("translation", model=model, tokenizer=tokenizer)
14
 
15
- # Load image generation model (Stable Diffusion 2.1)
16
  img_pipe = StableDiffusionPipeline.from_pretrained(
17
  "stabilityai/stable-diffusion-2-1",
18
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
@@ -20,35 +18,33 @@ def load_all_models():
20
  )
21
  img_pipe = img_pipe.to("cuda" if torch.cuda.is_available() else "cpu")
22
 
23
- return tokenizer, model, translation_pipeline, img_pipe
24
 
25
- # Streamlit UI
26
  def main():
27
- st.set_page_config(page_title="Tamil to English to Image Generator", layout="centered")
28
- st.title("📸 Tamil English AI Image Generator")
29
- st.markdown("Translate Tamil text to English and generate an image from it!")
30
 
31
- # Load models
32
- with st.spinner("Loading models..."):
33
- tokenizer, model, translation_pipeline, img_pipe = load_all_models()
34
 
35
- # Input
36
- tamil_text = st.text_area("Enter Tamil text here:", height=150)
37
-
38
  if st.button("Generate Image"):
39
- if tamil_text.strip() == "":
40
  st.warning("Please enter some Tamil text.")
41
  return
42
 
43
- # Step 1: Translate Tamil to English
44
- with st.spinner("Translating to English..."):
45
- translated = translation_pipeline(tamil_text, src_lang="ta", tgt_lang="en")[0]["translation_text"]
 
 
 
 
 
 
46
  st.success(f"🔤 English Translation: `{translated}`")
47
 
48
- # Step 2: Generate image
49
  with st.spinner("Generating image..."):
50
  image = img_pipe(prompt=translated).images[0]
51
- st.image(image, caption="Generated Image", use_column_width=True)
52
 
53
  if __name__ == "__main__":
54
  main()
 
1
  import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  from diffusers import StableDiffusionPipeline
4
  import torch
5
 
 
6
  @st.cache_resource
7
  def load_all_models():
8
  # Load translation model
9
+ model_id = "ai4bharat/indictrans2-indic-en-dist-200M"
10
+ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
11
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_id, trust_remote_code=True)
 
12
 
13
+ # Load Stable Diffusion image generator
14
  img_pipe = StableDiffusionPipeline.from_pretrained(
15
  "stabilityai/stable-diffusion-2-1",
16
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
 
18
  )
19
  img_pipe = img_pipe.to("cuda" if torch.cuda.is_available() else "cpu")
20
 
21
+ return tokenizer, model, img_pipe
22
 
 
23
  def main():
24
+ st.set_page_config(page_title="Tamil to English to Image", layout="centered")
25
+ st.title("📸 Tamil English AI Image Generator")
 
26
 
27
+ tamil_text = st.text_area("Enter Tamil text:", height=150)
 
 
28
 
 
 
 
29
  if st.button("Generate Image"):
30
+ if not tamil_text.strip():
31
  st.warning("Please enter some Tamil text.")
32
  return
33
 
34
+ with st.spinner("Loading models..."):
35
+ tokenizer, model, img_pipe = load_all_models()
36
+
37
+ with st.spinner("Translating Tamil to English..."):
38
+ # Prepare special format: "<2en> <tamil sentence>"
39
+ formatted_input = f"<2en> {tamil_text.strip()}"
40
+ inputs = tokenizer(formatted_input, return_tensors="pt")
41
+ output_ids = model.generate(**inputs)
42
+ translated = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
43
  st.success(f"🔤 English Translation: `{translated}`")
44
 
 
45
  with st.spinner("Generating image..."):
46
  image = img_pipe(prompt=translated).images[0]
47
+ st.image(image, caption="🖼️ AI-generated Image", use_column_width=True)
48
 
49
  if __name__ == "__main__":
50
  main()