Update app.py
app.py CHANGED
@@ -3,42 +3,44 @@ import torch
 import openai
 import os
 import time
+import requests
 from PIL import Image
 import tempfile
-import clip
+import clip
 import torch.nn.functional as F
-import requests
 from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2LMHeadModel
 from rouge_score import rouge_scorer
-from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
 
 # Set device
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# OpenAI Key
 openai.api_key = os.getenv("OPENAI_API_KEY")
 
-# Load MBart
+# ---- Load MBart (Translation) ----
 translator_model = MBartForConditionalGeneration.from_pretrained(
-    "facebook/mbart-large-50-many-to-many-mmt"
-
-
+    "facebook/mbart-large-50-many-to-many-mmt"
+)
+translator_tokenizer = MBart50TokenizerFast.from_pretrained(
+    "facebook/mbart-large-50-many-to-many-mmt"
 )
-
+translator_model.to(device)
 translator_tokenizer.src_lang = "ta_IN"
 
-#
-gen_model =
-gen_model.eval()
+# ---- GPT-2 ----
+gen_model = GPT2LMHeadModel.from_pretrained("gpt2")
 gen_tokenizer = AutoTokenizer.from_pretrained("gpt2")
+gen_model.to(device)
+gen_model.eval()
 
-#
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+# ---- CLIP ----
 clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
 
 # ---- Translation ----
 def translate_tamil_to_english(text, reference=None):
     start = time.time()
-    inputs = translator_tokenizer(text, return_tensors="pt")
-    inputs = {k: v.to(translator_model.device) for k, v in inputs.items()}
+    inputs = translator_tokenizer(text, return_tensors="pt").to(device)
     outputs = translator_model.generate(
         **inputs,
         forced_bos_token_id=translator_tokenizer.lang_code_to_id["en_XX"]
@@ -57,9 +59,7 @@ def translate_tamil_to_english(text, reference=None):
 # ---- Creative Text ----
 def generate_creative_text(prompt, max_length=100):
     start = time.time()
-    input_ids = gen_tokenizer.encode(prompt, return_tensors="pt")
-    input_ids = input_ids.to(gen_model.device)
-
+    input_ids = gen_tokenizer.encode(prompt, return_tensors="pt").to(device)
     output = gen_model.generate(
         input_ids,
         max_length=max_length,
@@ -74,7 +74,7 @@ def generate_creative_text(prompt, max_length=100):
     rep_rate = sum(t1 == t2 for t1, t2 in zip(tokens, tokens[1:])) / len(tokens) if len(tokens) > 1 else 0
 
     with torch.no_grad():
-        input_ids = gen_tokenizer.encode(text, return_tensors="pt").to(
+        input_ids = gen_tokenizer.encode(text, return_tensors="pt").to(device)
         outputs = gen_model(input_ids, labels=input_ids)
         loss = outputs.loss
         perplexity = torch.exp(loss).item()
@@ -99,7 +99,6 @@ def generate_image(prompt):
         image_data.save(tmp_file.name)
         duration = round(time.time() - start, 2)
 
-        # CLIP similarity
         image_input = clip_preprocess(image_data).unsqueeze(0).to(device)
         text_input = clip.tokenize([prompt]).to(device)
         with torch.no_grad():
@@ -111,7 +110,7 @@ def generate_image(prompt):
     except Exception as e:
         return None, None, f"Image generation failed: {str(e)}"
 
-# ----
+# ---- UI ----
 st.set_page_config(page_title="Tamil → English + AI Art", layout="centered")
 st.title("🧠 Tamil → English + 🎨 Creative Text + 🖼️ AI Image")
 
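The common thread in this commit is a single `device` chosen once at startup, with every model (`translator_model`, `gen_model`, CLIP) and every tokenized batch moved onto it, replacing the earlier mix of `translator_model.device` / `gen_model.device` lookups. Below is a minimal, self-contained sketch of that pattern for the translation path only; it uses the same model ID that appears in app.py, but the `translate` helper name and the sample sentence are illustrative and not part of the commit.

```python
import torch
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

# Pick one device up front and reuse it everywhere (the pattern app.py moves to).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MBartForConditionalGeneration.from_pretrained(
    "facebook/mbart-large-50-many-to-many-mmt"
).to(device)
tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50-many-to-many-mmt"
)
tokenizer.src_lang = "ta_IN"  # source language: Tamil

def translate(text: str) -> str:
    # Tokenize, then move the whole BatchEncoding onto the shared device.
    inputs = tokenizer(text, return_tensors="pt").to(device)
    generated = model.generate(
        **inputs,
        forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"],  # target: English
    )
    return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]

if __name__ == "__main__":
    # Sample Tamil input ("Hello, how are you?"); any Tamil sentence works.
    print(translate("வணக்கம், எப்படி இருக்கிறீர்கள்?"))
```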