24Sureshkumar committed
Commit 2bd9593 · verified · 1 Parent(s): 2c77ef8

Update app.py

Files changed (1)
  1. app.py +34 -31
app.py CHANGED
@@ -7,12 +7,12 @@ import io
 import os
 from typing import Tuple

-# Load HF token
-HF_API_KEY = os.getenv("HF_API_KEY") or "your_hf_token_here"  # Replace this with your token if local
+# Load Hugging Face token
+HF_API_KEY = os.getenv("HF_API_KEY") or "your_hf_token_here"
 if not HF_API_KEY:
     raise ValueError("HF_API_KEY is not set.")

-# Hugging Face image model
+# Hugging Face inference API endpoint
 IMAGE_GEN_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
 HEADERS = {"Authorization": f"Bearer {HF_API_KEY}"}

@@ -20,51 +20,54 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

 # Translation model (Tamil to English)
 translator_model = "Helsinki-NLP/opus-mt-mul-en"
-translator = MarianMTModel.from_pretrained(translator_model).to(device)
 translator_tokenizer = MarianTokenizer.from_pretrained(translator_model)
+translator = MarianMTModel.from_pretrained(translator_model).to(device)

 # Text generation model
-generator_model = "EleutherAI/gpt-neo-1.3B"
-generator = AutoModelForCausalLM.from_pretrained(generator_model).to(device)
-generator_tokenizer = AutoTokenizer.from_pretrained(generator_model)
-generator_tokenizer.pad_token = generator_tokenizer.eos_token
+text_model = "EleutherAI/gpt-neo-1.3B"
+text_tokenizer = AutoTokenizer.from_pretrained(text_model)
+text_generator = AutoModelForCausalLM.from_pretrained(text_model).to(device)
+text_tokenizer.pad_token = text_tokenizer.eos_token

+# Step 1: Tamil to English translation
 def translate_tamil_to_english(text: str) -> str:
     inputs = translator_tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
-    output = translator.generate(**inputs)
-    return translator_tokenizer.decode(output[0], skip_special_tokens=True)
+    outputs = translator.generate(**inputs)
+    return translator_tokenizer.decode(outputs[0], skip_special_tokens=True)

+# Step 2: Generate creative text
 def generate_text(prompt: str) -> str:
-    inputs = generator_tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(device)
-    output = generator.generate(**inputs, max_length=100, num_return_sequences=1)
-    return generator_tokenizer.decode(output[0], skip_special_tokens=True)
+    inputs = text_tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(device)
+    outputs = text_generator.generate(**inputs, max_length=100, num_return_sequences=1)
+    return text_tokenizer.decode(outputs[0], skip_special_tokens=True)

+# Step 3: Generate image
 def generate_image(prompt: str) -> Image.Image:
     response = requests.post(IMAGE_GEN_URL, headers=HEADERS, json={"inputs": prompt})
-    try:
-        if response.status_code == 200 and response.headers["content-type"].startswith("image"):
-            return Image.open(io.BytesIO(response.content))
-    except Exception as e:
-        print("Image generation failed:", e)
-    return Image.new("RGB", (300, 300), color="gray")
+    if response.status_code == 200 and response.headers.get("content-type", "").startswith("image"):
+        return Image.open(io.BytesIO(response.content))
+    else:
+        return Image.new("RGB", (512, 512), color="gray")

+# Master function
 def process_input(tamil_text: str) -> Tuple[str, str, Image.Image]:
-    english_text = translate_tamil_to_english(tamil_text)
-    creative_text = generate_text(english_text)
-    image = generate_image(english_text)
-    return english_text, creative_text, image
+    english = translate_tamil_to_english(tamil_text)
+    creative = generate_text(english)
+    image = generate_image(english)
+    return english, creative, image

-# Gradio app
+# Gradio UI using Blocks API
 with gr.Blocks() as demo:
-    gr.Markdown("## Tamil to English Translator with Text and Image Generator")
+    gr.Markdown("## 🌍 Tamil to English | Text & Image Generator")

-    tamil_input = gr.Textbox(label="Enter Tamil Text")
-    translate_btn = gr.Button("Translate & Generate")
+    with gr.Row():
+        tamil_input = gr.Textbox(label="📝 Enter Tamil Text", placeholder="உங்கள் உரையை இங்கே உள்ளிடவும்...", lines=2)
+        generate_btn = gr.Button("Translate & Generate")

-    english_output = gr.Textbox(label="Translated English")
-    creative_output = gr.Textbox(label="Creative Text")
-    image_output = gr.Image(label="Generated Image")
+    english_output = gr.Textbox(label="🇬🇧 Translated English")
+    creative_output = gr.Textbox(label=" Generated Text")
+    image_output = gr.Image(label="🖼️ Generated Image", type="pil")

-    translate_btn.click(fn=process_input, inputs=tamil_input, outputs=[english_output, creative_output, image_output])
+    generate_btn.click(fn=process_input, inputs=tamil_input, outputs=[english_output, creative_output, image_output])

 demo.launch()
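
As a quick sanity check of the updated pipeline (a sketch, not part of this commit), a snippet like the one below could be dropped into app.py just above demo.launch(); the Tamil sample string and output filename are illustrative, and HF_API_KEY must be set for the image step to succeed.

# Sketch: run the three steps once before the Gradio UI starts (illustrative only).
sample_english, sample_creative, sample_image = process_input("வணக்கம் உலகம்")  # illustrative Tamil input, "Hello, world"
print("Translation:", sample_english)
print("Generated text:", sample_creative)
sample_image.save("sample_output.png")  # generate_image() returns a PIL.Image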