|
import os |
|
import torch |
|
from transformers import MBartForConditionalGeneration, MBart50Tokenizer, AutoTokenizer, AutoModelForCausalLM |
|
from diffusers import StableDiffusionPipeline |
|
from PIL import Image |
|
import tempfile |
|
import time |
|
import streamlit as st |
|
|
|
|
|
# Select GPU when available; all three models below are moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"


# Tamil -> English translator: MBart-50 many-to-many checkpoint.
translator_tokenizer = MBart50Tokenizer.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
translator_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt").to(device)
# Fix the source language to Tamil for every subsequent translation call.
translator_tokenizer.src_lang = "ta_IN"


# GPT-2 used for free-form "creative" continuation of the translated text.
gen_tokenizer = AutoTokenizer.from_pretrained("gpt2")
gen_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)


# Stable Diffusion 2.1 (base) text-to-image pipeline.
# float32 keeps the pipeline CPU-compatible; safety_checker=None disables the
# built-in NSFW filter.
pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base",
    torch_dtype=torch.float32,
    safety_checker=None
).to(device)
|
|
|
def translate_tamil_to_english(text):
    """Translate Tamil *text* to English using the preloaded MBart-50 model.

    The tokenizer's src_lang is set to ``ta_IN`` at module load; decoding is
    forced to begin with the English (``en_XX``) language token.
    """
    encoded = translator_tokenizer(text, return_tensors="pt").to(device)
    generated_ids = translator_model.generate(
        **encoded,
        forced_bos_token_id=translator_tokenizer.lang_code_to_id["en_XX"],
    )
    # Single input sequence in, single decoded sentence out.
    return translator_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
|
|
def generate_creative_text(prompt, max_length=100):
    """Sample a creative continuation of *prompt* with GPT-2.

    Parameters
    ----------
    prompt : str
        Seed text; it is included verbatim at the start of the return value.
    max_length : int
        Total token budget (prompt tokens + generated tokens).

    Returns
    -------
    str
        The decoded sample (prompt plus continuation), special tokens removed.
    """
    input_ids = gen_tokenizer.encode(prompt, return_tensors="pt").to(device)
    output = gen_model.generate(
        input_ids,
        max_length=max_length,
        do_sample=True,
        top_k=50,
        temperature=0.9,
        # GPT-2 defines no pad token; without this, generate() emits a warning
        # and silently falls back to EOS — make that choice explicit.
        pad_token_id=gen_tokenizer.eos_token_id,
    )
    return gen_tokenizer.decode(output[0], skip_special_tokens=True)
|
|
|
def generate_image(prompt):
    """Render *prompt* with Stable Diffusion and save the result as a PNG.

    Returns
    -------
    str
        Path of the saved temp file. ``delete=False`` keeps it on disk after
        the handle closes; the caller is responsible for cleaning it up.
    """
    image = pipe(prompt).images[0]
    # Close the OS handle before PIL writes to the path: the original kept the
    # file open while saving, which leaks a descriptor and fails on Windows
    # (the same file cannot be opened twice there).
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        path = tmp.name
    image.save(path)
    return path
|
|
|
|
|
# --- Streamlit UI ----------------------------------------------------------
# NOTE(review): the non-ASCII runes in the labels below look mojibake-garbled
# (UTF-8 emoji decoded with the wrong codec) — confirm and restore the
# intended glyphs.
st.set_page_config(page_title="Tamil β English + AI", layout="centered")
st.title("π Tamil to English + AI Image Generator")

# Free-form Tamil input; checked for non-blank content before any model runs.
tamil_input = st.text_area("βοΈ Enter Tamil Text", height=150)
|
|
|
# Main pipeline: translate -> creative text -> image. The st.success() string
# literals were encoding-garbled in the original (split mid-emoji, which is a
# syntax error as extracted); they are reconstructed here as "✅ ...".
if st.button("🚀 Generate"):
    if not tamil_input.strip():
        st.warning("Please enter Tamil text.")
    else:
        with st.spinner("Translating..."):
            translated = translate_tamil_to_english(tamil_input)
            st.success("✅ Translated!")
            st.markdown(f"**English:** `{translated}`")

        with st.spinner("Generating creative text..."):
            creative_text = generate_creative_text(translated)
            st.success("✅ Creative text generated!")
            st.markdown(f"**Creative Prompt:** `{creative_text}`")

        with st.spinner("Generating image..."):
            # NOTE(review): the image is driven by the plain translation, not
            # creative_text — confirm whether the creative prompt was intended
            # as the image prompt.
            image_path = generate_image(translated)
            st.success("✅ Image generated!")
            st.image(Image.open(image_path), caption="🖼️ AI Generated Image", use_column_width=True)
|
|
|
st.markdown("---")
# NOTE(review): footer text appears mojibake-garbled (emoji bytes decoded with
# the wrong codec) — confirm and restore the intended glyphs.
st.markdown("π§ Powered by MBart, GPT2 & Stable Diffusion - Deployed on Hugging Face π€")
|
|