Update app.py
Browse files
app.py
CHANGED
@@ -5,10 +5,8 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
|
5 |
from diffusers import StableDiffusionPipeline
|
6 |
import torch
|
7 |
|
8 |
-
# Load environment variables
|
9 |
-
load_dotenv()
|
10 |
-
HF_TOKEN = os.getenv("HF_TOKEN")
|
11 |
|
|
|
12 |
# Set Streamlit page config
|
13 |
st.set_page_config(
|
14 |
page_title="Tamil Creative Studio",
|
@@ -24,29 +22,31 @@ def load_css(file_name):
|
|
24 |
|
25 |
@st.cache_resource(show_spinner=False)
|
26 |
def load_all_models():
|
|
|
|
|
27 |
# Load translation model (private)
|
28 |
trans_tokenizer = AutoTokenizer.from_pretrained(
|
29 |
"ai4bharat/indictrans2-ta-en-dist-200M",
|
30 |
-
|
31 |
)
|
32 |
trans_model = AutoModelForSeq2SeqLM.from_pretrained(
|
33 |
"ai4bharat/indictrans2-ta-en-dist-200M",
|
34 |
-
|
35 |
)
|
36 |
-
|
37 |
# Load text generation model
|
38 |
text_gen = pipeline("text-generation", model="gpt2", device=-1)
|
39 |
-
|
40 |
# Load image generation model
|
41 |
img_pipe = StableDiffusionPipeline.from_pretrained(
|
42 |
"stabilityai/stable-diffusion-2-base",
|
43 |
-
use_auth_token=HF_TOKEN,
|
44 |
torch_dtype=torch.float32,
|
45 |
safety_checker=None
|
46 |
).to("cpu")
|
47 |
-
|
48 |
return trans_tokenizer, trans_model, text_gen, img_pipe
|
49 |
|
|
|
50 |
def translate_tamil(text, tokenizer, model):
|
51 |
inputs = tokenizer(
|
52 |
text,
|
|
|
5 |
from diffusers import StableDiffusionPipeline
|
6 |
import torch
|
7 |
|
|
|
|
|
|
|
8 |
|
9 |
+
load_dotenv()
|
10 |
# Set Streamlit page config
|
11 |
st.set_page_config(
|
12 |
page_title="Tamil Creative Studio",
|
|
|
22 |
|
23 |
@st.cache_resource(show_spinner=False)
|
24 |
def load_all_models():
|
25 |
+
HF_TOKEN = os.getenv("HF_TOKEN") # Hugging Face token from Hugging Face Spaces secret
|
26 |
+
|
27 |
# Load translation model (private)
|
28 |
trans_tokenizer = AutoTokenizer.from_pretrained(
|
29 |
"ai4bharat/indictrans2-ta-en-dist-200M",
|
30 |
+
token=HF_TOKEN
|
31 |
)
|
32 |
trans_model = AutoModelForSeq2SeqLM.from_pretrained(
|
33 |
"ai4bharat/indictrans2-ta-en-dist-200M",
|
34 |
+
token=HF_TOKEN
|
35 |
)
|
36 |
+
|
37 |
# Load text generation model
|
38 |
text_gen = pipeline("text-generation", model="gpt2", device=-1)
|
39 |
+
|
40 |
# Load image generation model
|
41 |
img_pipe = StableDiffusionPipeline.from_pretrained(
|
42 |
"stabilityai/stable-diffusion-2-base",
|
|
|
43 |
torch_dtype=torch.float32,
|
44 |
safety_checker=None
|
45 |
).to("cpu")
|
46 |
+
|
47 |
return trans_tokenizer, trans_model, text_gen, img_pipe
|
48 |
|
49 |
+
|
50 |
def translate_tamil(text, tokenizer, model):
|
51 |
inputs = tokenizer(
|
52 |
text,
|