24Sureshkumar committed on
Commit 96b6780 · verified · 1 Parent(s): d078cab

Update app.py

Files changed (1)
  1. app.py  +25 -14
app.py CHANGED
@@ -1,38 +1,50 @@
+# app.py
 import streamlit as st
 import torch
 import openai
 import os
 import time
+import requests
 from PIL import Image
 import tempfile
 import clip  # from OpenAI CLIP repo
 import torch.nn.functional as F
-from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
-from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2LMHeadModel
+from transformers import (
+    MBartForConditionalGeneration,
+    MBart50TokenizerFast,
+    AutoTokenizer,
+    AutoModelForCausalLM,
+)
 from rouge_score import rouge_scorer
 from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
-openai.api_key = os.getenv("OPENAI_API_KEY")  # Set this from env
+openai.api_key = os.getenv("OPENAI_API_KEY")  # Make sure this is set in your environment
 
-# Load MBart
+# Load MBart model
 translator_model = MBartForConditionalGeneration.from_pretrained(
-    "facebook/mbart-large-50-many-to-many-mmt"
-).to(device)
+    "facebook/mbart-large-50-many-to-many-mmt",
+    device_map="auto",
+    low_cpu_mem_usage=True
+)
 translator_tokenizer = MBart50TokenizerFast.from_pretrained(
     "facebook/mbart-large-50-many-to-many-mmt"
 )
 translator_tokenizer.src_lang = "ta_IN"
 
-# GPT-2
-gen_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
+# Load GPT-2 model
+gen_model = AutoModelForCausalLM.from_pretrained(
+    "gpt2",
+    device_map="auto",
+    low_cpu_mem_usage=True
+)
 gen_model.eval()
 gen_tokenizer = AutoTokenizer.from_pretrained("gpt2")
 
-# CLIP
+# Load CLIP model
 clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
 
-# ---- Translation ----
+# ---- Translation Function ----
 def translate_tamil_to_english(text, reference=None):
     start = time.time()
     inputs = translator_tokenizer(text, return_tensors="pt").to(device)
@@ -51,7 +63,7 @@ def translate_tamil_to_english(text, reference=None):
 
     return translated, duration, rouge_l
 
-# ---- Creative Text ----
+# ---- Creative Text Generation ----
 def generate_creative_text(prompt, max_length=100):
     start = time.time()
     input_ids = gen_tokenizer.encode(prompt, return_tensors="pt").to(device)
@@ -88,10 +100,9 @@ def generate_image(prompt):
            n=1
        )
        image_url = response.data[0].url
-       image_data = Image.open(tempfile.NamedTemporaryFile(delete=False, suffix=".png"))
        image_data = Image.open(requests.get(image_url, stream=True).raw).resize((256, 256))
 
-       # Save locally
+       # Save to temporary file
        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
        image_data.save(tmp_file.name)
        duration = round(time.time() - start, 2)
@@ -108,7 +119,7 @@ def generate_image(prompt):
    except Exception as e:
        return None, None, f"Image generation failed: {str(e)}"
 
-# ---- UI ----
+# ---- Streamlit UI ----
 st.set_page_config(page_title="Tamil → English + AI Art", layout="centered")
 st.title("🧠 Tamil → English + 🎨 Creative Text + 🖼️ AI Image")
 
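
The hunks above elide most of the body of translate_tamil_to_english. For context, here is a minimal standalone sketch of the MBart-50 many-to-many translation call the app appears to build on; the sample sentence, max_length value, and variable names are illustrative and not taken from the commit:

    from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

    model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
    tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
    tokenizer.src_lang = "ta_IN"  # source language: Tamil

    # Illustrative input: "Hello, how are you?" in Tamil
    inputs = tokenizer("வணக்கம், நீங்கள் எப்படி இருக்கிறீர்கள்?", return_tensors="pt")
    generated = model.generate(
        **inputs,
        forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"],  # force English output
        max_length=128,
    )
    print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])

Note that once a model is loaded with device_map="auto", its parameters may not sit on the module-level device variable; inputs are typically moved with inputs.to(model.device) rather than a hard-coded device string.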