import torch
import re
from transformers import AutoModelForCausalLM, AutoTokenizer
MODEL_NAME = "facebook/opt-1.3b"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME, torch_dtype=torch.float16, device_map="auto"
)
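# Note: device_map="auto" requires the `accelerate` package to place the model.
# A minimal fallback sketch for environments without a GPU (illustrative, not
# part of the original pipeline): load in default precision on whichever device
# is available.
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)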
def generate_answer_chat(query, options, retrieved_chunks, model=model, tokenizer=tokenizer):
"""
Generates an answer using the retrieved context, formatted as a conversation
to better suit Llama 2 7B Chat's conversational tuning.
"""
# Format each retrieved chunk as a numbered paragraph.
paragraphs = [f"Paragraph {idx+1}: {chunk}" for idx, chunk in enumerate(retrieved_chunks)]
context = "\n\n".join(paragraphs)
# Create a conversational prompt.
    system_message = (
        "System: You are a telecom regulations expert. Answer using the information "
        "provided in the context. Start by directly giving the best choice from the options."
    )
context_message = f"Context:\n{context}"
user_message = f"User: {query}\nOptions: " + " | ".join(options)
assistant_cue = "Assistant: "
prompt = "\n\n".join([system_message, context_message, user_message, assistant_cue])
# Determine the model type: seq2seq or causal.
model_type = "seq2seq" if getattr(model.config, "is_encoder_decoder", False) else "causal"
    # Move inputs to the same device as the model (placed by device_map="auto").
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=128,
num_return_sequences=1,
no_repeat_ngram_size=2
)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
if model_type == "causal":
# Attempt to extract only the assistant's response.
answer_start = generated_text.find("Assistant:")
if answer_start != -1:
answer = generated_text[answer_start + len("Assistant:"):].strip()
else:
answer = generated_text[len(prompt):].strip()
return answer
else:
return generated_text.strip()
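# Example usage (illustrative only; the query, options, and chunks below are
# hypothetical placeholders, not taken from the actual retrieval pipeline):
#
#   chunks = [
#       "Operators must notify the regulator 30 days before a tariff change.",
#       "Spectrum licences are renewable every 15 years subject to compliance.",
#   ]
#   answer = generate_answer_chat(
#       query="How often are spectrum licences renewed?",
#       options=["Every 5 years", "Every 10 years", "Every 15 years"],
#       retrieved_chunks=chunks,
#   )
#   print(answer)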
def generate_answer(query, retrieved_chunks, model=model, tokenizer=tokenizer):
"""
Generates an answer using the retrieved context.
For causal models, the prompt is included in the output so it must be removed.
For seq2seq models, the output is directly the generated answer.
"""
# Format each chunk as a separate paragraph with a numbered prefix.
paragraphs = [f"Paragraph {idx+1}: {chunk}" for idx, chunk in enumerate(retrieved_chunks)]
context = "\n\n".join(paragraphs)
prompt = (f"You are a telecom regulations expert. Using the following context, answer the question:\n\n"
f"Context:\n{context}\n\n"
f"Question: {query}\n\n")
model_type = "seq2seq" if getattr(model.config, "is_encoder_decoder", False) else "causal"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=128,  # Cap generation length, consistent with the other helpers.
        num_return_sequences=1,
        no_repeat_ngram_size=2
    )
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# For causal models, remove the prompt from the output.
if model_type == "causal":
return generated_text[len(prompt):].strip()
else:
return generated_text.strip()
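# Example usage (illustrative; the query and chunk text are hypothetical placeholders):
#
#   chunks = ["Paragraph text returned by the retriever about roaming tariffs."]
#   print(generate_answer("What rules apply to roaming tariffs?", chunks))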
def generate_norag(query, model, tokenizer):
"""
Generates an answer without additional context.
"""
prompt = f"Answer the question:\n\nQuestion: {query}\nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# Generate output with a specified maximum number of new tokens.
outputs = model.generate(
**inputs,
        max_new_tokens=128,  # Upper bound on the number of newly generated tokens.
num_return_sequences=1,
no_repeat_ngram_size=2
)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
model_type = "seq2seq" if getattr(model.config, "is_encoder_decoder", False) else "causal"
if model_type == "causal":
return generated_text[len(prompt):].strip()
else: # For seq2seq models
return generated_text.strip()
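# Example usage (illustrative; the query is a hypothetical placeholder). The
# module-level model and tokenizer are passed explicitly since this helper,
# unlike the two above, declares no default arguments for them:
#
#   print(generate_norag("What is spectrum refarming?", model, tokenizer))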