import os

import streamlit as st
import torch
from dotenv import load_dotenv
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from diffusers import StableDiffusionPipeline

# Load environment variables
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")

# Set Streamlit page config
st.set_page_config(
    page_title="Tamil Creative Studio",
    page_icon="🇮🇳",
    layout="centered",
    initial_sidebar_state="collapsed",
)


# Load custom CSS and inject it into the page
def load_css(file_name):
    with open(file_name, "r") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)


@st.cache_resource(show_spinner=False)
def load_all_models():
    # Load translation model (private/gated repo, so pass the HF token;
    # newer transformers releases prefer `token=` over `use_auth_token=`)
    trans_tokenizer = AutoTokenizer.from_pretrained(
        "ai4bharat/indictrans2-ta-en-dist-200M",
        use_auth_token=HF_TOKEN,
    )
    trans_model = AutoModelForSeq2SeqLM.from_pretrained(
        "ai4bharat/indictrans2-ta-en-dist-200M",
        use_auth_token=HF_TOKEN,
    )

    # Load text generation model on CPU (device=-1)
    text_gen = pipeline("text-generation", model="gpt2", device=-1)

    # Load image generation model on CPU in full precision, safety checker disabled
    img_pipe = StableDiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-base",
        use_auth_token=HF_TOKEN,
        torch_dtype=torch.float32,
        safety_checker=None,
    ).to("cpu")

    return trans_tokenizer, trans_model, text_gen, img_pipe


def translate_tamil(text, tokenizer, model):
    # Tokenize the Tamil input, translate with beam search, and decode to English
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    )
    generated = model.generate(
        **inputs,
        max_length=150,
        num_beams=5,
        early_stopping=True,
    )
    return tokenizer.batch_decode(
        generated,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )[0]


def main():
    load_css("style.css")

    # Header
    st.markdown("""
Translate Tamil text and generate creative content