import logging
import random

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

class TextGenerator:
    def __init__(
        self,
        model_name="gpt2",
        device="cuda" if torch.cuda.is_available() else "cpu",
        max_new_tokens=50,
        temperature=1.0,
        top_p=0.95,
        seed=None,
    ):
        self.model_name = model_name
        self.device = device
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.seed = seed

        logging.info(f"[TextGenerator] Loading model {model_name} on {device} ...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

        if seed is not None:
            set_seed(seed)
            logging.info(f"[TextGenerator] Generation seed set via transformers.set_seed({seed})")
        else:
            logging.info("[TextGenerator] Generation seed not set (seed=None)")
        # --- Few-shot examples: (emotion, partial text, emotionally-aligned continuation) ---
        self.fewshot_examples = [
            ("happy", "We finally made it!", "We finally made it! I’ve never felt so alive and proud of what we accomplished."),
            ("sad", "He didn't come back.", "He didn't come back. I waited all night, hoping to see him again."),
            ("anger", "Why would you do that?", "Why would you do that? You had no right to interfere!"),
            ("fear", "Did you hear that?", "Did you hear that? Something’s moving outside the window..."),
            ("surprise", "Oh wow, really?", "Oh wow, really? I didn’t see that coming at all!"),
            ("disgust", "That smell is awful.", "That smell is awful. I feel like I’m going to be sick."),
            ("neutral", "Let's meet at noon.", "Let's meet at noon. We’ll have plenty of time to talk then."),
        ]
    def build_prompt(self, emotion: str, partial_text: str) -> str:
        # Sample two few-shot examples to keep the prompt short while varying the demonstrations.
        few_shot = random.sample(self.fewshot_examples, 2)
        examples_str = ""
        for emo, text, cont in few_shot:
            examples_str += (
                "Example:\n"
                f"Emotion: {emo}\n"
                f"Text: {text}\n"
                f"Continuation: {cont}\n\n"
            )
        prompt = (
            "You are a helpful assistant that generates emotionally-aligned sentence continuations.\n"
            "You must include the original sentence in the output, and then continue it in a fluent and emotionally appropriate way.\n\n"
            f"{examples_str}"
            "Now try:\n"
            f"Emotion: {emotion}\n"
            f"Text: {partial_text}\n"
            "Continuation:"
        )
        return prompt
    def generate_text(self, emotion: str, partial_text: str = "") -> str:
        prompt = self.build_prompt(emotion, partial_text)
        logging.debug(f"[TextGenerator] prompt:\n{prompt}")

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        output_ids = self.model.generate(
            **inputs,
            max_new_tokens=self.max_new_tokens,
            do_sample=True,
            top_p=self.top_p,
            temperature=self.temperature,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        full_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        logging.debug(f"[TextGenerator] decoded:\n{full_text}")

        # Extract whatever follows the last "Continuation:" marker in the decoded text.
        if "Continuation:" in full_text:
            result = full_text.split("Continuation:")[-1].strip()
        else:
            result = full_text.strip()

        # Keep only the first line: sampled models often ramble into further "Example:" blocks.
        result = result.split("\n")[0].strip()
        return result
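

# --- Usage sketch (illustrative only; the model name, seed, and inputs below are
# assumptions for demonstration, not part of the module itself) ---
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    # Instantiate with a fixed seed for reproducible sampling.
    generator = TextGenerator(model_name="gpt2", seed=42)
    # Generate an emotionally-aligned continuation of a partial sentence.
    print(generator.generate_text("happy", "We won the game."))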