24Sureshkumar commited on
Commit
4569162
·
verified ·
1 Parent(s): 2d8ff37

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -41
app.py CHANGED
@@ -1,57 +1,54 @@
1
  import streamlit as st
2
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
3
  from diffusers import StableDiffusionPipeline
4
  import torch
5
 
 
6
  @st.cache_resource
7
  def load_all_models():
8
- # Load IndicTrans2 Tamil-to-English model
9
  trans_model_id = "ai4bharat/indictrans2-indic-en-dist-200M"
10
- trans_tokenizer = AutoTokenizer.from_pretrained(trans_model_id, trust_remote_code=True)
11
- trans_model = AutoModelForSeq2SeqLM.from_pretrained(trans_model_id, trust_remote_code=True)
 
12
 
13
- # Load English text generation model (you can use GPT2 or Falcon, etc.)
14
- text_gen = pipeline("text-generation", model="gpt2")
15
-
16
- # Load Stable Diffusion for image generation
17
  img_pipe = StableDiffusionPipeline.from_pretrained(
18
- "runwayml/stable-diffusion-v1-5",
19
- torch_dtype=torch.float16,
20
- revision="fp16"
21
- ).to("cuda" if torch.cuda.is_available() else "cpu")
22
-
23
- return trans_tokenizer, trans_model, text_gen, img_pipe
24
 
25
- def translate_text(text, tokenizer, model):
26
- input_text = f"translate Tamil to English: {text}"
27
- inputs = tokenizer(input_text, return_tensors="pt", padding=True)
28
- outputs = model.generate(**inputs, max_length=128)
29
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
30
 
 
31
  def main():
32
- st.title("Multimodal Tamil to Image Generator 🚀")
33
- st.markdown("Enter Tamil text, we translate it to English, continue the sentence, and generate an image!")
34
-
35
- user_input = st.text_area("Enter Tamil text:", "")
36
-
37
- if st.button("Generate"):
38
- with st.spinner("Loading models..."):
39
- tokenizer, model, text_gen, img_pipe = load_all_models()
40
-
 
 
 
 
 
 
 
 
41
  with st.spinner("Translating to English..."):
42
- english_text = translate_text(user_input, tokenizer, model)
43
- st.subheader("Translated English:")
44
- st.write(english_text)
45
-
46
- with st.spinner("Generating continuation..."):
47
- continuation = text_gen(english_text, max_length=50, do_sample=True)[0]['generated_text']
48
- st.subheader("Generated Text:")
49
- st.write(continuation)
50
-
51
- with st.spinner("Generating Image..."):
52
- image = img_pipe(continuation).images[0]
53
- st.subheader("Generated Image:")
54
- st.image(image)
55
 
56
  if __name__ == "__main__":
57
  main()
 
1
  import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
3
  from diffusers import StableDiffusionPipeline
4
  import torch
5
 
6
+ # Load models only once
7
  @st.cache_resource
8
  def load_all_models():
9
+ # Load translation model
10
  trans_model_id = "ai4bharat/indictrans2-indic-en-dist-200M"
11
+ tokenizer = AutoTokenizer.from_pretrained(trans_model_id, trust_remote_code=True)
12
+ model = AutoModelForSeq2SeqLM.from_pretrained(trans_model_id, trust_remote_code=True)
13
+ translation_pipeline = pipeline("translation", model=model, tokenizer=tokenizer)
14
 
15
+ # Load image generation model (Stable Diffusion 2.1)
 
 
 
16
  img_pipe = StableDiffusionPipeline.from_pretrained(
17
+ "stabilityai/stable-diffusion-2-1",
18
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
19
+ revision="fp16" if torch.cuda.is_available() else None,
20
+ )
21
+ img_pipe = img_pipe.to("cuda" if torch.cuda.is_available() else "cpu")
 
22
 
23
+ return tokenizer, model, translation_pipeline, img_pipe
 
 
 
 
24
 
25
+ # Streamlit UI
26
  def main():
27
+ st.set_page_config(page_title="Tamil to English to Image Generator", layout="centered")
28
+ st.title("📸 Tamil → English → AI Image Generator")
29
+ st.markdown("Translate Tamil text to English and generate an image from it!")
30
+
31
+ # Load models
32
+ with st.spinner("Loading models..."):
33
+ tokenizer, model, translation_pipeline, img_pipe = load_all_models()
34
+
35
+ # Input
36
+ tamil_text = st.text_area("Enter Tamil text here:", height=150)
37
+
38
+ if st.button("Generate Image"):
39
+ if tamil_text.strip() == "":
40
+ st.warning("Please enter some Tamil text.")
41
+ return
42
+
43
+ # Step 1: Translate Tamil to English
44
  with st.spinner("Translating to English..."):
45
+ translated = translation_pipeline(tamil_text, src_lang="ta", tgt_lang="en")[0]["translation_text"]
46
+ st.success(f"🔤 English Translation: `{translated}`")
47
+
48
+ # Step 2: Generate image
49
+ with st.spinner("Generating image..."):
50
+ image = img_pipe(prompt=translated).images[0]
51
+ st.image(image, caption="Generated Image", use_column_width=True)
 
 
 
 
 
 
52
 
53
  if __name__ == "__main__":
54
  main()