24Sureshkumar committed on
Commit 39c8921 (verified)
1 Parent(s): 635ae0d

update_app.py

Files changed (1):
  1. app.py  +20 -28
app.py CHANGED
@@ -1,53 +1,44 @@
-# app.py
 import streamlit as st
 import torch
 import openai
 import os
 import time
-import requests
 from PIL import Image
 import tempfile
 import clip # from OpenAI CLIP repo
 import torch.nn.functional as F
-from transformers import (
-    MBartForConditionalGeneration,
-    MBart50TokenizerFast,
-    AutoTokenizer,
-    AutoModelForCausalLM,
-)
+import requests
+from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
+from transformers import AutoTokenizer, AutoModelForCausalLM
 from rouge_score import rouge_scorer
 from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

-device = "cuda" if torch.cuda.is_available() else "cpu"
-openai.api_key = os.getenv("OPENAI_API_KEY") # Make sure this is set in your environment
+# Set device
+openai.api_key = os.getenv("OPENAI_API_KEY")

-# Load MBart model
+# Load MBart
 translator_model = MBartForConditionalGeneration.from_pretrained(
     "facebook/mbart-large-50-many-to-many-mmt",
     device_map="auto",
     low_cpu_mem_usage=True
 )
-translator_tokenizer = MBart50TokenizerFast.from_pretrained(
-    "facebook/mbart-large-50-many-to-many-mmt"
-)
+translator_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
 translator_tokenizer.src_lang = "ta_IN"

-# Load GPT-2 model
-gen_model = AutoModelForCausalLM.from_pretrained(
-    "gpt2",
-    device_map="auto",
-    low_cpu_mem_usage=True
-)
+# Load GPT-2
+gen_model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto", low_cpu_mem_usage=True)
 gen_model.eval()
 gen_tokenizer = AutoTokenizer.from_pretrained("gpt2")

-# Load CLIP model
+# Load CLIP
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)

-# ---- Translation Function ----
+# ---- Translation ----
 def translate_tamil_to_english(text, reference=None):
     start = time.time()
-    inputs = translator_tokenizer(text, return_tensors="pt").to(device)
+    inputs = translator_tokenizer(text, return_tensors="pt")
+    inputs = {k: v.to(translator_model.device) for k, v in inputs.items()}
     outputs = translator_model.generate(
         **inputs,
         forced_bos_token_id=translator_tokenizer.lang_code_to_id["en_XX"]
@@ -63,10 +54,12 @@ def translate_tamil_to_english(text, reference=None):

     return translated, duration, rouge_l

-# ---- Creative Text Generation ----
+# ---- Creative Text ----
 def generate_creative_text(prompt, max_length=100):
     start = time.time()
-    input_ids = gen_tokenizer.encode(prompt, return_tensors="pt").to(device)
+    input_ids = gen_tokenizer.encode(prompt, return_tensors="pt")
+    input_ids = input_ids.to(gen_model.device)
+
     output = gen_model.generate(
         input_ids,
         max_length=max_length,
@@ -81,14 +74,14 @@ def generate_creative_text(prompt, max_length=100):
     rep_rate = sum(t1 == t2 for t1, t2 in zip(tokens, tokens[1:])) / len(tokens) if len(tokens) > 1 else 0

     with torch.no_grad():
-        input_ids = gen_tokenizer.encode(text, return_tensors="pt").to(device)
+        input_ids = gen_tokenizer.encode(text, return_tensors="pt").to(gen_model.device)
         outputs = gen_model(input_ids, labels=input_ids)
         loss = outputs.loss
         perplexity = torch.exp(loss).item()

     return text, duration, len(tokens), round(rep_rate, 4), round(perplexity, 4)

-# ---- Image Generation using DALL·E 3 ----
+# ---- Image Generation ----
 def generate_image(prompt):
     try:
         start = time.time()
@@ -102,7 +95,6 @@ def generate_image(prompt):
         image_url = response.data[0].url
         image_data = Image.open(requests.get(image_url, stream=True).raw).resize((256, 256))

-        # Save to temporary file
         tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
         image_data.save(tmp_file.name)
         duration = round(time.time() - start, 2)
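The Streamlit UI that calls these helpers sits outside the hunks shown above. As a quick illustration of the behavioural change in this commit, here is a minimal, hypothetical smoke test of the refactored functions; it assumes the definitions from app.py above are in scope, and the sample Tamil sentence plus the names tamil_text, t_time, g_time, story are illustrative only, not from the commit. The point it exercises is that inputs are now moved to each model's own device (translator_model.device, gen_model.device) rather than a single global device, which matters when device_map="auto" places MBart or GPT-2 on a different device than CLIP.

    # Hypothetical smoke test (not part of this commit) for the refactored helpers.
    tamil_text = "வணக்கம், நீங்கள் எப்படி இருக்கிறீர்கள்?"  # illustrative Tamil input

    # translate_tamil_to_english() now tokenizes on CPU and moves the tensors
    # to translator_model.device before calling generate().
    translated, t_time, rouge_l = translate_tamil_to_english(tamil_text)

    # generate_creative_text() likewise sends input_ids to gen_model.device.
    story, g_time, n_tokens, rep_rate, perplexity = generate_creative_text(translated)

    print(f"Translation ({t_time}s): {translated}")
    print(f"GPT-2 continuation ({g_time}s, perplexity {perplexity}): {story}")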