24Sureshkumar commited on
Commit
c3b581c
Β·
verified Β·
1 Parent(s): 792fbb6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -35
app.py CHANGED
@@ -5,36 +5,36 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
5
  from diffusers import StableDiffusionPipeline
6
  from rouge_score import rouge_scorer
7
  from PIL import Image
 
8
  import tempfile
9
  import os
 
10
  import time
11
- import clip
12
- import torchvision.transforms as transforms
13
 
14
- # Use CUDA if available
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
 
17
- # Load translation model (Tamil to English)
18
  translator_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt").to(device)
19
  translator_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
20
  translator_tokenizer.src_lang = "ta_IN"
21
 
22
- # Load GPT-2 for creative text generation
23
  gen_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
24
  gen_tokenizer = AutoTokenizer.from_pretrained("gpt2")
25
 
26
- # Load a lightweight image generation model
27
  pipe = StableDiffusionPipeline.from_pretrained(
28
- "OFA-Sys/small-stable-diffusion-v0",
29
  torch_dtype=torch.float32,
30
- use_auth_token=os.getenv("HF_TOKEN") # Set in Hugging Face Space secrets
31
  ).to(device)
32
- pipe.safety_checker = None # Optional: disable for speed
33
 
34
- # Load CLIP model for image-text similarity
35
  clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
36
 
37
- # Translation Function
38
  def translate_tamil_to_english(text, reference=None):
39
  start = time.time()
40
  inputs = translator_tokenizer(text, return_tensors="pt").to(device)
@@ -53,56 +53,58 @@ def translate_tamil_to_english(text, reference=None):
53
 
54
  return translated, duration, rouge_l
55
 
56
- # Creative Text Generator with Perplexity
57
  def generate_creative_text(prompt, max_length=100):
58
  start = time.time()
59
  input_ids = gen_tokenizer.encode(prompt, return_tensors="pt").to(device)
60
- output = gen_model.generate(input_ids, max_length=max_length, do_sample=True, top_k=50, temperature=0.9)
 
 
 
61
  text = gen_tokenizer.decode(output[0], skip_special_tokens=True)
62
  duration = round(time.time() - start, 2)
63
 
64
  tokens = text.split()
65
  repetition_rate = sum(t1 == t2 for t1, t2 in zip(tokens, tokens[1:])) / len(tokens)
66
 
67
- # Perplexity calculation
68
  with torch.no_grad():
69
- input_ids = gen_tokenizer.encode(text, return_tensors="pt").to(device)
70
  outputs = gen_model(input_ids, labels=input_ids)
71
  loss = outputs.loss
72
- perplexity = torch.exp(loss).item()
73
 
74
- return text, duration, len(tokens), round(repetition_rate, 4), round(perplexity, 4)
75
 
76
- # AI Image Generator with CLIP Similarity
77
  def generate_image(prompt):
78
  try:
79
  start = time.time()
80
  result = pipe(prompt)
81
  image = result.images[0].resize((256, 256))
 
82
 
 
83
  tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
84
  image.save(tmp_file.name)
85
 
86
  # CLIP similarity
87
  image_input = clip_preprocess(image).unsqueeze(0).to(device)
88
- text_input = clip.tokenize([prompt]).to(device)
89
-
90
  with torch.no_grad():
91
  image_features = clip_model.encode_image(image_input)
92
  text_features = clip_model.encode_text(text_input)
93
  similarity = torch.cosine_similarity(image_features, text_features).item()
94
 
95
- return tmp_file.name, round(time.time() - start, 2), round(similarity, 4)
96
-
97
  except Exception as e:
98
- return None, f"Image generation failed: {str(e)}", None
99
 
100
  # Streamlit UI
101
  st.set_page_config(page_title="Tamil β†’ English + AI Art", layout="centered")
102
  st.title("🧠 Tamil β†’ English + 🎨 Creative Text + AI Image")
103
 
104
  tamil_input = st.text_area("✍️ Enter Tamil text here", height=150)
105
- reference_input = st.text_input("πŸ“˜ Optional: Reference English translation for ROUGE-L")
106
 
107
  if st.button("πŸš€ Generate Output"):
108
  if not tamil_input.strip():
@@ -115,27 +117,23 @@ if st.button("πŸš€ Generate Output"):
115
  st.markdown(f"**πŸ“ English Translation:** `{english_text}`")
116
  if rouge_l is not None:
117
  st.markdown(f"πŸ“Š **ROUGE-L Score:** `{rouge_l}`")
118
- else:
119
- st.info("ℹ️ ROUGE-L not calculated. Reference not provided.")
120
 
121
- with st.spinner("🎨 Generating image..."):
122
- image_path, img_time, clip_score = generate_image(english_text)
123
 
124
- if image_path:
125
  st.success(f"πŸ–ΌοΈ Image generated in {img_time} seconds")
126
  st.image(Image.open(image_path), caption="AI-Generated Image", use_column_width=True)
127
- st.markdown(f"πŸ” **CLIP Text-Image Similarity:** `{clip_score}`")
128
  else:
129
- st.error(image_path)
130
 
131
  with st.spinner("πŸ’‘ Generating creative text..."):
132
  creative, c_time, tokens, rep_rate, perplexity = generate_creative_text(english_text)
133
 
134
  st.success(f"✨ Creative text generated in {c_time} seconds")
135
- st.markdown("**🧠 Creative Output:**")
136
- st.text(creative)
137
- st.markdown(f"πŸ“Œ Tokens: `{tokens}`")
138
- st.markdown(f"πŸ” Repetition Rate: `{rep_rate}`")
139
  st.markdown(f"πŸ“‰ Perplexity: `{perplexity}`")
140
 
141
  st.markdown("---")
 
5
  from diffusers import StableDiffusionPipeline
6
  from rouge_score import rouge_scorer
7
  from PIL import Image
8
+ import clip
9
  import tempfile
10
  import os
11
+ import math
12
  import time
 
 
13
 
14
+ # Device setup
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
 
17
+ # Translation model
18
  translator_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt").to(device)
19
  translator_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
20
  translator_tokenizer.src_lang = "ta_IN"
21
 
22
+ # GPT-2 for creative text
23
  gen_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
24
  gen_tokenizer = AutoTokenizer.from_pretrained("gpt2")
25
 
26
+ # Stable Diffusion v1.4
27
  pipe = StableDiffusionPipeline.from_pretrained(
28
+ "stabilityai/stable-diffusion-1-4",
29
  torch_dtype=torch.float32,
30
+ use_auth_token=os.getenv("HF_TOKEN") # set this on Hugging Face Spaces
31
  ).to(device)
32
+ pipe.safety_checker = None # Optional
33
 
34
+ # Load CLIP for image-text similarity
35
  clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
36
 
37
+ # Translation function
38
  def translate_tamil_to_english(text, reference=None):
39
  start = time.time()
40
  inputs = translator_tokenizer(text, return_tensors="pt").to(device)
 
53
 
54
  return translated, duration, rouge_l
55
 
56
+ # Text generation with repetition & perplexity
57
  def generate_creative_text(prompt, max_length=100):
58
  start = time.time()
59
  input_ids = gen_tokenizer.encode(prompt, return_tensors="pt").to(device)
60
+ output = gen_model.generate(
61
+ input_ids, max_length=max_length,
62
+ do_sample=True, top_k=50, temperature=0.9
63
+ )
64
  text = gen_tokenizer.decode(output[0], skip_special_tokens=True)
65
  duration = round(time.time() - start, 2)
66
 
67
  tokens = text.split()
68
  repetition_rate = sum(t1 == t2 for t1, t2 in zip(tokens, tokens[1:])) / len(tokens)
69
 
70
+ # Perplexity
71
  with torch.no_grad():
 
72
  outputs = gen_model(input_ids, labels=input_ids)
73
  loss = outputs.loss
74
+ perplexity = math.exp(loss.item())
75
 
76
+ return text, duration, len(tokens), round(repetition_rate, 4), round(perplexity, 3)
77
 
78
+ # Image generation + CLIP similarity
79
  def generate_image(prompt):
80
  try:
81
  start = time.time()
82
  result = pipe(prompt)
83
  image = result.images[0].resize((256, 256))
84
+ duration = round(time.time() - start, 2)
85
 
86
+ # Save image
87
  tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
88
  image.save(tmp_file.name)
89
 
90
  # CLIP similarity
91
  image_input = clip_preprocess(image).unsqueeze(0).to(device)
92
+ text_input = clip.tokenize(prompt).to(device)
 
93
  with torch.no_grad():
94
  image_features = clip_model.encode_image(image_input)
95
  text_features = clip_model.encode_text(text_input)
96
  similarity = torch.cosine_similarity(image_features, text_features).item()
97
 
98
+ return tmp_file.name, duration, round(similarity, 4)
 
99
  except Exception as e:
100
+ return None, 0, f"Image generation failed: {str(e)}"
101
 
102
  # Streamlit UI
103
  st.set_page_config(page_title="Tamil β†’ English + AI Art", layout="centered")
104
  st.title("🧠 Tamil β†’ English + 🎨 Creative Text + AI Image")
105
 
106
  tamil_input = st.text_area("✍️ Enter Tamil text here", height=150)
107
+ reference_input = st.text_input("πŸ“˜ Optional: Reference English translation for ROUGE")
108
 
109
  if st.button("πŸš€ Generate Output"):
110
  if not tamil_input.strip():
 
117
  st.markdown(f"**πŸ“ English Translation:** `{english_text}`")
118
  if rouge_l is not None:
119
  st.markdown(f"πŸ“Š **ROUGE-L Score:** `{rouge_l}`")
 
 
120
 
121
+ with st.spinner("πŸ–ΌοΈ Generating image..."):
122
+ image_path, img_time, similarity = generate_image(english_text)
123
 
124
+ if isinstance(similarity, float):
125
  st.success(f"πŸ–ΌοΈ Image generated in {img_time} seconds")
126
  st.image(Image.open(image_path), caption="AI-Generated Image", use_column_width=True)
127
+ st.markdown(f"🎯 **CLIP Text-Image Similarity:** `{similarity}`")
128
  else:
129
+ st.error(similarity)
130
 
131
  with st.spinner("πŸ’‘ Generating creative text..."):
132
  creative, c_time, tokens, rep_rate, perplexity = generate_creative_text(english_text)
133
 
134
  st.success(f"✨ Creative text generated in {c_time} seconds")
135
+ st.markdown(f"**🧠 Creative Output:** `{creative}`")
136
+ st.markdown(f"πŸ“Œ Tokens: `{tokens}`, πŸ” Repetition Rate: `{rep_rate}`")
 
 
137
  st.markdown(f"πŸ“‰ Perplexity: `{perplexity}`")
138
 
139
  st.markdown("---")