24Sureshkumar committed on
Commit
396e877
·
verified ·
1 Parent(s): 39c8921

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -22
app.py CHANGED
@@ -3,42 +3,44 @@ import torch
3
  import openai
4
  import os
5
  import time
 
6
  from PIL import Image
7
  import tempfile
8
- import clip # from OpenAI CLIP repo
9
  import torch.nn.functional as F
10
- import requests
11
  from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
12
- from transformers import AutoTokenizer, AutoModelForCausalLM
13
  from rouge_score import rouge_scorer
14
- from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
15
 
16
  # Set device
 
 
 
17
  openai.api_key = os.getenv("OPENAI_API_KEY")
18
 
19
- # Load MBart
20
  translator_model = MBartForConditionalGeneration.from_pretrained(
21
- "facebook/mbart-large-50-many-to-many-mmt",
22
- device_map="auto",
23
- low_cpu_mem_usage=True
 
24
  )
25
- translator_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
26
  translator_tokenizer.src_lang = "ta_IN"
27
 
28
- # Load GPT-2
29
- gen_model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto", low_cpu_mem_usage=True)
30
- gen_model.eval()
31
  gen_tokenizer = AutoTokenizer.from_pretrained("gpt2")
 
 
32
 
33
- # Load CLIP
34
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
  clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
36
 
37
  # ---- Translation ----
38
  def translate_tamil_to_english(text, reference=None):
39
  start = time.time()
40
- inputs = translator_tokenizer(text, return_tensors="pt")
41
- inputs = {k: v.to(translator_model.device) for k, v in inputs.items()}
42
  outputs = translator_model.generate(
43
  **inputs,
44
  forced_bos_token_id=translator_tokenizer.lang_code_to_id["en_XX"]
@@ -57,9 +59,7 @@ def translate_tamil_to_english(text, reference=None):
57
  # ---- Creative Text ----
58
  def generate_creative_text(prompt, max_length=100):
59
  start = time.time()
60
- input_ids = gen_tokenizer.encode(prompt, return_tensors="pt")
61
- input_ids = input_ids.to(gen_model.device)
62
-
63
  output = gen_model.generate(
64
  input_ids,
65
  max_length=max_length,
@@ -74,7 +74,7 @@ def generate_creative_text(prompt, max_length=100):
74
  rep_rate = sum(t1 == t2 for t1, t2 in zip(tokens, tokens[1:])) / len(tokens) if len(tokens) > 1 else 0
75
 
76
  with torch.no_grad():
77
- input_ids = gen_tokenizer.encode(text, return_tensors="pt").to(gen_model.device)
78
  outputs = gen_model(input_ids, labels=input_ids)
79
  loss = outputs.loss
80
  perplexity = torch.exp(loss).item()
@@ -99,7 +99,6 @@ def generate_image(prompt):
99
  image_data.save(tmp_file.name)
100
  duration = round(time.time() - start, 2)
101
 
102
- # CLIP similarity
103
  image_input = clip_preprocess(image_data).unsqueeze(0).to(device)
104
  text_input = clip.tokenize([prompt]).to(device)
105
  with torch.no_grad():
@@ -111,7 +110,7 @@ def generate_image(prompt):
111
  except Exception as e:
112
  return None, None, f"Image generation failed: {str(e)}"
113
 
114
- # ---- Streamlit UI ----
115
  st.set_page_config(page_title="Tamil → English + AI Art", layout="centered")
116
  st.title("🧠 Tamil → English + 🎨 Creative Text + 🖼️ AI Image")
117
 
 
3
  import openai
4
  import os
5
  import time
6
+ import requests
7
  from PIL import Image
8
  import tempfile
9
+ import clip
10
  import torch.nn.functional as F
 
11
  from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
12
+ from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2LMHeadModel
13
  from rouge_score import rouge_scorer
 
14
 
15
  # Set device
16
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+
18
+ # OpenAI Key
19
  openai.api_key = os.getenv("OPENAI_API_KEY")
20
 
21
+ # ---- Load MBart (Translation) ----
22
  translator_model = MBartForConditionalGeneration.from_pretrained(
23
+ "facebook/mbart-large-50-many-to-many-mmt"
24
+ )
25
+ translator_tokenizer = MBart50TokenizerFast.from_pretrained(
26
+ "facebook/mbart-large-50-many-to-many-mmt"
27
  )
28
+ translator_model.to(device)
29
  translator_tokenizer.src_lang = "ta_IN"
30
 
31
+ # ---- GPT-2 ----
32
+ gen_model = GPT2LMHeadModel.from_pretrained("gpt2")
 
33
  gen_tokenizer = AutoTokenizer.from_pretrained("gpt2")
34
+ gen_model.to(device)
35
+ gen_model.eval()
36
 
37
+ # ---- CLIP ----
 
38
  clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
39
 
40
  # ---- Translation ----
41
  def translate_tamil_to_english(text, reference=None):
42
  start = time.time()
43
+ inputs = translator_tokenizer(text, return_tensors="pt").to(device)
 
44
  outputs = translator_model.generate(
45
  **inputs,
46
  forced_bos_token_id=translator_tokenizer.lang_code_to_id["en_XX"]
 
59
  # ---- Creative Text ----
60
  def generate_creative_text(prompt, max_length=100):
61
  start = time.time()
62
+ input_ids = gen_tokenizer.encode(prompt, return_tensors="pt").to(device)
 
 
63
  output = gen_model.generate(
64
  input_ids,
65
  max_length=max_length,
 
74
  rep_rate = sum(t1 == t2 for t1, t2 in zip(tokens, tokens[1:])) / len(tokens) if len(tokens) > 1 else 0
75
 
76
  with torch.no_grad():
77
+ input_ids = gen_tokenizer.encode(text, return_tensors="pt").to(device)
78
  outputs = gen_model(input_ids, labels=input_ids)
79
  loss = outputs.loss
80
  perplexity = torch.exp(loss).item()
 
99
  image_data.save(tmp_file.name)
100
  duration = round(time.time() - start, 2)
101
 
 
102
  image_input = clip_preprocess(image_data).unsqueeze(0).to(device)
103
  text_input = clip.tokenize([prompt]).to(device)
104
  with torch.no_grad():
 
110
  except Exception as e:
111
  return None, None, f"Image generation failed: {str(e)}"
112
 
113
+ # ---- UI ----
114
  st.set_page_config(page_title="Tamil → English + AI Art", layout="centered")
115
  st.title("🧠 Tamil → English + 🎨 Creative Text + 🖼️ AI Image")
116