# t5-ft-demo/content_gen.py
import random
import numpy as np
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
import spaces

# Available models for content generation
MODEL_OPTIONS_CONTENT = {
    "MX02 (mixed)": {
        "model_id": "alakxender/flan-t5-corpora-mixed",
        "default_prompt": "Tell me about: "
    },
    "MX01 (articles)": {
        "model_id": "alakxender/flan-t5-news-articles",
        "default_prompt": "Create an article about: "
    }
}
# Cache for loaded models/tokenizers
MODEL_CACHE = {}

def get_model_and_tokenizer(model_choice):
    # Load the model/tokenizer on first use, then serve subsequent calls from the cache
    model_dir = MODEL_OPTIONS_CONTENT[model_choice]["model_id"]
    if model_dir not in MODEL_CACHE:
        print(f"Loading model: {model_dir}")
        tokenizer = T5Tokenizer.from_pretrained(model_dir)
        model = T5ForConditionalGeneration.from_pretrained(model_dir)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Moving model to device: {device}")
        model.to(device)
        MODEL_CACHE[model_dir] = (tokenizer, model)
    return MODEL_CACHE[model_dir]

def get_default_prompt(model_choice):
    return MODEL_OPTIONS_CONTENT[model_choice]["default_prompt"]

@spaces.GPU()
def generate_content(prompt, max_new_tokens, num_beams, repetition_penalty, no_repeat_ngram_size, do_sample, model_choice):
    tokenizer, model = get_model_and_tokenizer(model_choice)
    # Prepend the model-specific instruction prefix to the user prompt
    prompt = get_default_prompt(model_choice) + prompt
    inputs = tokenizer(prompt, return_tensors="pt")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Guard against a device mismatch: a cached model may have been loaded
    # outside the GPU context, so make sure it sits on the same device as the inputs
    model.to(device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        do_sample=do_sample,
        early_stopping=False
    )
    output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Trim any trailing fragment so the output ends on a complete sentence
    if '.' in output_text:
        last_period = output_text.rfind('.')
        output_text = output_text[:last_period + 1]
    return output_text
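

# Minimal smoke test (an addition for illustration, not part of the original Space;
# the sample prompt and parameter values below are arbitrary assumptions).
# On Hugging Face ZeroGPU Spaces the @spaces.GPU() decorator requests a GPU for the
# call; outside a Space the `spaces` package should leave the function unchanged,
# so this also runs on plain CPU, just more slowly.
if __name__ == "__main__":
    sample = generate_content(
        prompt="the history of the Maldives",
        max_new_tokens=128,
        num_beams=4,
        repetition_penalty=1.2,
        no_repeat_ngram_size=3,
        do_sample=False,
        model_choice="MX01 (articles)",
    )
    print(sample)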