File size: 3,564 Bytes
d16c1e4
1437058
d16c1e4
a1cf7cb
d550533
eea6ac5
d16c1e4
1437058
 
d16c1e4
54da9ab
1437058
54da9ab
 
 
 
 
 
 
 
 
 
 
4deb4b5
d7164de
 
1437058
d16c1e4
1437058
 
d16c1e4
 
1437058
 
d16c1e4
1437058
54da9ab
1437058
 
54da9ab
d7164de
d16c1e4
1437058
54da9ab
1437058
d7164de
1437058
54da9ab
d7164de
 
a1cf7cb
 
 
 
 
 
 
1437058
d7164de
 
a1cf7cb
 
 
d550533
1437058
a1cf7cb
1437058
a1cf7cb
 
 
9607ff2
c1732d5
54da9ab
1437058
 
 
54da9ab
 
 
 
1437058
 
d7164de
1437058
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9607ff2
 
d16c1e4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import streamlit as st
from dotenv import load_dotenv
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from diffusers import StableDiffusionPipeline
import torch

# Load environment variables via python-dotenv's load_dotenv(); it is called
# with no arguments, so the library's default lookup applies.
load_dotenv()
# Hugging Face access token. os.getenv returns None when HF_TOKEN is unset, and
# that value is forwarded unchanged as `use_auth_token` in load_all_models().
HF_TOKEN = os.getenv("HF_TOKEN")

# Set Streamlit page config: browser-tab title and icon, centered layout, and
# a sidebar that starts collapsed.
# NOTE(review): Streamlit has historically required set_page_config to run
# before any other st.* command — keep this ahead of main(); confirm against
# the Streamlit version this app is pinned to.
st.set_page_config(
    page_title="Tamil Creative Studio",
    page_icon="🇮🇳",
    layout="centered",
    initial_sidebar_state="collapsed"
)

# Load custom CSS
def load_css(file_name):
    """Inject the contents of a CSS file into the Streamlit page.

    The file is read in full and emitted through ``st.markdown`` wrapped in a
    ``<style>`` element, with ``unsafe_allow_html=True`` so the tag is passed
    through instead of being escaped.

    Args:
        file_name: Path of the stylesheet to read.

    Raises:
        OSError: If the file cannot be opened (for example, it does not exist).
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Decode explicitly as UTF-8. Without an encoding, open() falls back to the
    # platform's locale encoding, which can garble or reject non-ASCII CSS
    # (e.g. `content: "→"`) on hosts whose locale is not UTF-8.
    with open(file_name, "r", encoding="utf-8") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

@st.cache_resource(show_spinner=False)
def load_all_models():
    """Build every model the app needs and return them as one tuple.

    Wrapped in ``st.cache_resource`` (with its spinner disabled) so that, per
    Streamlit's resource caching, the models are meant to be built once rather
    than on every script rerun.

    Returns:
        A 4-tuple ``(trans_tokenizer, trans_model, text_gen, img_pipe)``:
        the translation tokenizer and seq2seq model, the GPT-2
        text-generation pipeline, and the Stable Diffusion pipeline moved
        to the CPU.
    """
    translation_repo = "ai4bharat/indictrans2-ta-en-dist-200M"
    diffusion_repo = "stabilityai/stable-diffusion-2-base"

    # Translation checkpoint (marked as private by the original author):
    # tokenizer and weights are both fetched with the Hugging Face token.
    trans_tokenizer = AutoTokenizer.from_pretrained(
        translation_repo, use_auth_token=HF_TOKEN
    )
    trans_model = AutoModelForSeq2SeqLM.from_pretrained(
        translation_repo, use_auth_token=HF_TOKEN
    )

    # Text generation with GPT-2; device=-1 is the transformers pipeline
    # convention for running on the CPU.
    text_gen = pipeline("text-generation", model="gpt2", device=-1)

    # Image generation: full-precision weights, no safety checker, on the CPU.
    diffusion_options = {
        "use_auth_token": HF_TOKEN,
        "torch_dtype": torch.float32,
        "safety_checker": None,
    }
    img_pipe = StableDiffusionPipeline.from_pretrained(
        diffusion_repo, **diffusion_options
    )
    img_pipe = img_pipe.to("cpu")

    return trans_tokenizer, trans_model, text_gen, img_pipe

def translate_tamil(text, tokenizer, model):
    """Run ``text`` through the tokenizer/model pair and return one string.

    Args:
        text: Input handed to the tokenizer (the caller supplies Tamil text).
        tokenizer: Callable that encodes ``text`` into a mapping of model
            inputs; must also provide ``batch_decode``.
        model: Object exposing ``generate``; it receives the encoded mapping
            expanded as keyword arguments.

    Returns:
        The first entry of the decoded batch.
    """
    # Encode as PyTorch tensors, padding the batch and truncating at 128 tokens.
    encoded = tokenizer(
        text, return_tensors="pt", padding=True, truncation=True, max_length=128
    )

    # Beam search: 5 beams, early stopping, sequence length capped at 150.
    beam_options = {"max_length": 150, "num_beams": 5, "early_stopping": True}
    output_ids = model.generate(**encoded, **beam_options)

    decoded = tokenizer.batch_decode(
        output_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return decoded[0]

def main():
    """Render the page and run the translate → text → image flow.

    Applies ``style.css``, draws the header and a Tamil text area, and on
    button press:

    1. warns and stops the run if the input is empty or whitespace-only;
    2. translates the input with ``translate_tamil`` and shows the result;
    3. feeds a prompt built from the translation to the text-generation
       pipeline and shows the generated text;
    4. passes the translation to the image pipeline and shows the image.

    Each model step runs inside its own ``st.spinner``.
    """
    load_css("style.css")
    
    # Header
    st.markdown("""
        <div class="header">
            <h1>🌐 தமிழ் → English → Creative Studio</h1>
            <p>Translate Tamil text and generate creative content</p>
        </div>
    """, unsafe_allow_html=True)
    
    tokenizer, model, text_gen, img_pipe = load_all_models()
    
    tamil_text = st.text_area("**தமிழ் உரை:**", height=150, placeholder="உங்கள் உரையை இங்கே உள்ளிடவும்...")
    
    if st.button("**உருவாக்கு**", type="primary", use_container_width=True):
        # Guard clause: nothing to translate, so end this script run here.
        if not tamil_text.strip():
            st.warning("தயவு செய்து உரையை உள்ளிடவும்.")
            st.stop()

        with st.spinner("மொழிபெயர்க்கிறது..."):
            eng = translate_tamil(tamil_text, tokenizer, model)

        with st.expander("🔤 Translation", expanded=True):
            st.success(eng)

        with st.spinner("உரை உருவாக்குதல்..."):
            # Budget only the continuation. `max_length=80` counted the prompt
            # tokens as well, and the prompt embeds a translation that can be
            # up to 150 tokens long, so long inputs left no room to generate
            # (transformers warns or raises in that situation).
            # NOTE(review): "generated_text" still starts with the prompt
            # itself; confirm whether echoing it to the user is intended.
            creative = text_gen(
                f"Create a creative description about: {eng}",
                max_new_tokens=80,
                num_return_sequences=1
            )[0]["generated_text"]

        st.info("📝 Creative Text:")
        st.write(creative)

        with st.spinner("படத்தை உருவாக்குதல்..."):
            img = img_pipe(eng, num_inference_steps=40, guidance_scale=8.5).images[0]

        # NOTE(review): newer Streamlit releases deprecate `use_column_width`
        # in favour of `use_container_width`; left as-is because the pinned
        # Streamlit version is not visible from this file.
        st.image(img, caption="🎨 Generated Image", use_column_width=True)

# Entry point: build and run the page only when this file is executed as the
# main script, not when it is imported as a module.
if __name__ == "__main__":
    main()