# Suppress warnings that cannot be fixed on our side (see
# https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md)
# as well as import-time warnings from the transformers package.
import warnings

warnings.filterwarnings("ignore")

import re
import logging
import traceback

import numpy as np
import openai
from tqdm import tqdm
from sentence_transformers import util
from openai.error import (APIError, RateLimitError, ServiceUnavailableError,
                          Timeout, APIConnectionError, InvalidRequestError)
from tenacity import (before_sleep_log, retry, retry_if_exception_type,
                      wait_random_exponential, stop_after_attempt)

from .utils import break_down2scenes
from .prompt import build_fact_prompt
from .openai_api import openai_api_response

logger = logging.getLogger(__name__)


class OpenAIEmbedding:
    """Thin wrapper around the OpenAI embedding endpoint."""

    def __init__(self, api_key, model="text-embedding-3-large"):
        self.api_key = api_key
        self.model = model
        openai.api_key = api_key

    @retry(retry=retry_if_exception_type((APIError, Timeout, RateLimitError,
                                          ServiceUnavailableError, APIConnectionError)),
           wait=wait_random_exponential(max=60),
           stop=stop_after_attempt(10),
           before_sleep=before_sleep_log(logger, logging.WARNING))
    def encode(self, texts, **kwargs):
        """Embed a string or a list of strings; returns an (n, dim) numpy array."""
        if isinstance(texts, str):
            texts = [texts]
        try:
            response = openai.Embedding.create(
                model=self.model,
                input=texts,
            )
            # Extract embeddings from the response
            embeddings = [item["embedding"] for item in response["data"]]
            return np.array(embeddings)
        except (APIError, Timeout, RateLimitError, ServiceUnavailableError,
                APIConnectionError):
            # Re-raise transient API failures so @retry can handle them; the
            # original blanket `except` silently swallowed them, so the retry
            # decorator never fired.
            raise
        except Exception as e:
            logger.error(f"Embedding API failed: {str(e)}")
            return None
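# Illustrative note (not in the original module): `encode` returns a plain
# numpy array, so its output can be passed straight to sentence_transformers'
# cosine-similarity helper, e.g.:
#
#     embedder = OpenAIEmbedding(api_key="YOUR_OPENAI_API_KEY")  # placeholder key
#     emb = embedder.encode(["scene one", "scene two"])          # shape (2, dim)
#     sim = util.cos_sim(emb[0], emb[1])                         # 1x1 tensor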
class NarrativeFactScore():
    """Scores the factual consistency of a summary against its source document."""

    def __init__(self, model="gpt-4o-mini", split_type="fast", checkpoint=None,
                 api_key=None, model_id="gpt-4"):
        self.sent_model = OpenAIEmbedding(api_key=api_key)
        self.split_type = split_type
        self.checkpoint = checkpoint
        self.api_key = api_key
        self.model_id = model_id
        openai.api_key = api_key
        if model == "gptscore":
            self.metric = GPTScore(model=self.model_id, api_key=self.api_key)
            self.metric_function = self.metric.gpt_score
        else:
            raise ValueError("NarrativeFactScore currently only supports GPTScore")

    def get_surrounding_sentences(self, sentence_array, ii):
        if 0 < ii < len(sentence_array) - 1:
            # Previous, current, and next sentence (the original slice stopped
            # one sentence short of "surrounding").
            sents = " ".join(sentence_array[ii - 1 : ii + 2])
        elif ii == 0:
            sents = " ".join(sentence_array[:2])
        else:  # ii == len(sentence_array) - 1
            sents = " ".join(sentence_array[ii - 1 :])
        return sents

    def group_into_sections(self, sentence_array, num_sent):
        sectioned_sents = []
        for ii in range(0, len(sentence_array), num_sent):
            # Join each chunk of num_sent sentences; the original joined first
            # and then sliced the resulting string by character, which was a bug.
            sectioned_sents.append(" ".join(sentence_array[ii : ii + num_sent]))
        return sectioned_sents

    def split_sent(self, text):
        # Break a summary into smaller claim-like units according to split_type.
        text_list = []
        if self.split_type == "fast":
            for t in text.split('.'):
                if len(t) == 0:
                    continue
                text_list.append(t)
            return text_list
        elif self.split_type == "fast_comma":
            for t in re.split(r'[.,]', text):
                if len(t) == 0:
                    continue
                text_list.append(t)
            return text_list
        elif self.split_type == "gpt":
            prompt = build_fact_prompt(
                prompt_template='./templates/atomic_fact.txt',
                input_text_list=[text],
            )
            response = openai_api_response(prompt, model=self.model_id,
                                           api_key=self.api_key)
            text_list = []
            for res in response.split('\n'):
                text_list.append(res.strip())
            return text_list
        else:
            return None

    def score_src_hyp_long(self, srcs, hyps, kgs):
        # srcs: list of source documents; hyps: list of generated summaries;
        # kgs: list of knowledge-graph triples extracted from the source.
        all_scores = []
        all_scores_per_sent = []
        all_relevant_scenes = []
        all_summary_chunks = []
        all_feedback_list = []
        total_score = 0
        for global_idx, (src, hyp) in enumerate(zip(tqdm(srcs), hyps)):
            src_sents = break_down2scenes(src)
            # Get embeddings using the OpenAI API
            sentence_embeddings_src = self.sent_model.encode(src_sents)
            sentence_embeddings_kg = self.sent_model.encode(kgs)
            doc_scores = []
            relevant_scenes = []
            feedbacks = []
            hyp_array = self.split_sent(hyp)
            for hyp_sentence in hyp_array:
                # Embed the hypothesis sentence, then rank source scenes and
                # KG triples by cosine similarity.
                sentence_embeddings_hyp = self.sent_model.encode(hyp_sentence)
                scores = util.cos_sim(sentence_embeddings_hyp, sentence_embeddings_src)[0]
                scores_kg = util.cos_sim(sentence_embeddings_hyp, sentence_embeddings_kg)[0]
                sorted_idxs = np.argsort(-1 * scores)  # descending order
                sorted_idxs_kg = np.argsort(-1 * scores_kg)  # descending order
                # Keep the single most similar KG triple and source scene.
                similar_src_sentences = []
                triple = ''
                for sorted_idx, ii in enumerate(sorted_idxs_kg[0:1]):
                    if sorted_idx == 0:
                        triple += f'{kgs[ii]}'
                    else:
                        triple += f', {kgs[ii]}'
                for ii in sorted_idxs[0:1]:
                    similar_src_sentences.append(src_sents[ii])
                scores, feedback_list = self.metric_function(
                    similar_src_sentences,
                    [hyp_sentence for _ in range(len(similar_src_sentences))],
                    triple,
                )
                score = np.max(scores)
                max_scene_idx = np.argmax(scores)
                max_scene = similar_src_sentences[max_scene_idx]
                feedback = feedback_list[max_scene_idx]
                doc_scores.append(int(score))
                relevant_scenes.append(max_scene)
                feedbacks.append(feedback)
            doc_score = np.mean(doc_scores)
            all_scores_per_sent.append(doc_scores)
            all_scores.append(doc_score)
            all_relevant_scenes.append(relevant_scenes)
            all_summary_chunks.append(hyp_array)
            all_feedback_list.append(feedbacks)
            total_score += doc_score
            if global_idx % 100 == 99:
                print(f"Mean score over {global_idx + 1} documents: "
                      f"{total_score / (global_idx + 1)}")
        return (all_scores, all_scores_per_sent, all_relevant_scenes,
                all_summary_chunks, all_feedback_list)


class GPTScore():
    """Binary (0/1) factuality judgment of a summary sentence via a GPT prompt."""

    def __init__(self, model="gpt-4o", api_key=None,
                 prompt='./templates/fact_score_kg.txt'):
        self.max_length = 1024
        self.model = model
        self.api_key = api_key
        self.prompt = prompt
        openai.api_key = api_key

    @retry(retry=retry_if_exception_type((APIError, Timeout, RateLimitError,
                                          ServiceUnavailableError, APIConnectionError)),
           wait=wait_random_exponential(max=60),
           stop=stop_after_attempt(10),
           before_sleep=before_sleep_log(logger, logging.WARNING))
    def gpt_inference(self, prompt):
        prompt_messages = [{"role": "user", "content": prompt}]
        try:
            response = openai.ChatCompletion.create(
                model=self.model,
                messages=prompt_messages,
                temperature=0,
                api_key=self.api_key,
            )
            response = response.choices[0].message.content
        except InvalidRequestError:
            # Unrecoverable request errors (e.g. context overflow) are treated
            # as faithful. Return the string '1' so gpt_score's `'1' in score`
            # check works; the original returned the int 1, which would raise
            # a TypeError there.
            response = '1'
        return response

    def gpt_score(self, srcs, tgts, kgs, batch_size=4):
        score_list = []
        feedback_list = []
        for src, tgt in zip(srcs, tgts):
            prompt = build_fact_prompt(
                prompt_template=self.prompt,
                input_text_list=[src, kgs, tgt],
            )
            try:
                score = self.gpt_inference(prompt)
                if '1' in score:
                    # Judged faithful: score 1, no feedback.
                    score_list.append(1.0)
                    feedback_list.append('')
                else:
                    # Judged unfaithful: score 0; keep the model's explanation
                    # as feedback.
                    score_list.append(0.0)
                    feedback_list.append(score)
            except RuntimeError:
                traceback.print_exc()
                print(f"source: {src}")
                print(f"target: {tgt}")
                exit(0)
        return score_list, feedback_list
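# Minimal end-to-end usage sketch (not part of the original source). The API
# key, documents, summary, and triples below are placeholders, and
# split_type="fast" avoids the extra GPT call in split_sent. Because this
# module uses relative imports, run it as a package module, e.g.
# `python -m <package>.<this_module>`.
if __name__ == "__main__":
    scorer = NarrativeFactScore(
        model="gptscore",                # the only metric currently supported
        split_type="fast",
        api_key="YOUR_OPENAI_API_KEY",   # placeholder
        model_id="gpt-4",
    )
    srcs = ["Scene 1. Alice meets Bob at the station. Scene 2. They board the train."]
    hyps = ["Alice meets Bob. They take the train together."]
    kgs = ["(Alice, meets, Bob)", "(Alice, boards, train)"]
    scores, per_sent, scenes, chunks, feedback = scorer.score_src_hyp_long(srcs, hyps, kgs)
    print(f"Per-document scores: {scores}")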