# explainability-tool-for-aa / utils / llm_feat_utils.py
# Author: Milad Alshomary
# (history: "updates", commit ac7facf)
import json
import os
import hashlib
import time
from json import JSONDecodeError
# On-disk cache of LLM-extracted feature spans: one JSON file per "role".
CACHE_DIR = "datasets/feature_spans_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
import pandas as pd
# Read and create the Gram2Vec feature set once at import time.
# NOTE(review): this runs on import and resolves the CSV path relative to the
# current working directory — presumably the project root; confirm callers.
_g2v_df = pd.read_csv("datasets/gram2vec_feats.csv")
GRAM2VEC_SET = set(_g2v_df['gram2vec_feats'].unique())
# Retry policy for the OpenAI call: up to MAX_ATTEMPTS tries, with an
# exponential backoff of WAIT_SECONDS * 2**attempt seconds between tries.
MAX_ATTEMPTS = 3
WAIT_SECONDS = 2
# Bump this whenever there is a change prompt, feature space, etc...
# (the version is folded into every cache key, so bumping it invalidates
# all previously cached spans)
CACHE_VERSION = 2
def _feat_hash(feature: str, text: str) -> str:
    """Return a stable MD5 cache key for one (feature, text) pair.

    The key also folds in CACHE_VERSION so bumping the version invalidates
    every previously cached entry.

    Args:
        feature: A single descriptive feature string.
        text: The document the feature's spans were extracted from.

    Returns:
        Hex MD5 digest of the canonical JSON blob (md5 is fine here: the
        hash is a cache key, not a security boundary).
    """
    blob = json.dumps({
        "version": CACHE_VERSION,
        "text": text,
        # Bug fix: the old code stored sorted(feature), which sorts the
        # *characters* of the string — anagram features (e.g. "abc" vs
        # "cab") collided and shared cached spans. Hash the feature verbatim.
        "feature": feature,
    }, sort_keys=True).encode()
    return hashlib.md5(blob).hexdigest()
def generate_feature_spans(client, text: str, features: list[str],
                           model: str = "gpt-4") -> str:
    """
    Call to OpenAI to extract spans. Returns a JSON string.

    Args:
        client: OpenAI-compatible client exposing ``chat.completions.create``.
        text: Writing sample to scan for feature-demonstrating spans.
        features: Descriptive style features to locate in ``text``.
        model: Chat model name. Keyword with the previous hard-coded default
            ("gpt-4"), so existing callers are unaffected.

    Returns:
        The raw model response content — expected (but not guaranteed) to be
        a JSON object mapping each feature to a list of exact spans; callers
        are responsible for parsing/validation.
    """
    print("Calling OpenAI to extract spans")
    prompt = f"""You are a linguistic specialist. Given a writing sample and a list of descriptive features, identify the exact text spans that demonstrate each feature.
Important:
- The headers like "Document 1:" etc are NOT part of the original text — ignore them.
- For each feature, even if there is no match, return an empty list.
- Only return exact phrases from the text.
Respond in JSON format like:
{{
"feature1": ["span1", "span2"],
"feature2": [],
}}
Text:
\"\"\"{text}\"\"\"
Style Features:
{features}
"""
    response = client.chat.completions.create(
        model=model,
        messages=[{"role":"user","content":prompt}],
        # Low temperature: span extraction should be near-deterministic.
        temperature=0.3,
    )
    return response.choices[0].message.content
def generate_feature_spans_with_retries(client, text: str, features: list[str]) -> dict:
    """
    Calls `generate_feature_spans` with retries on failure.

    Retries up to MAX_ATTEMPTS times with exponential backoff
    (WAIT_SECONDS * 2**attempt seconds) when the model's response is not
    valid JSON.

    Returns the parsed JSON dict mapping feature->list[spans].

    Raises:
        RuntimeError: if every attempt fails. Chained to the last parse
            error so the root cause survives in the traceback (the original
            code discarded it).
    """
    last_error = None
    for attempt in range(MAX_ATTEMPTS):
        try:
            response_str = generate_feature_spans(client, text, features)
            # json.JSONDecodeError subclasses ValueError, so ValueError
            # alone covers both exception types the original caught.
            return json.loads(response_str)
        except ValueError as e:
            last_error = e
            print(f"Attempt {attempt+1} failed: {e}")
            if attempt < MAX_ATTEMPTS - 1:
                wait_sec = WAIT_SECONDS * (2 ** attempt)
                print(f"Retrying after {wait_sec} seconds...")
                time.sleep(wait_sec)
    raise RuntimeError("All retry attempts failed for OpenAI call.") from last_error
def generate_feature_spans_cached(client, text: str, features: list[str], role: str = "mystery" ) -> dict:
    """
    Computes a cache key from text + feature list,
    then either loads or calls the API and saves to disk.
    Returns the parsed JSON dict mapping feature->list[spans].

    Args:
        client: OpenAI-compatible client, forwarded to the retry wrapper.
        text: Document the spans are extracted from.
        features: Feature strings; the literal string "None" is treated as
            a placeholder and always maps to [].
        role: Label used as the per-role cache filename (sanitized below).
    """
    print(f"Generating spans for ({role})")
    # Sanitize role so it is safe as a filename component.
    role = role.replace(" ", "_").replace("/", "_").replace("-", "_")
    print(f"Cache dir: {CACHE_DIR}")
    os.makedirs(CACHE_DIR, exist_ok=True)
    cache_path = os.path.join(CACHE_DIR, f"{role}.json")
    # Load the per-role cache: feature-hash -> {"feature": ..., "spans": [...]}.
    if os.path.exists(cache_path):
        print(f"Cache hit....")
        with open(cache_path) as f:
            cache: dict[str, dict] = json.load(f)
    else:
        cache = {}
    result: dict[str, list[str]] = {}
    missing_feats: list[str] = []
    # Serve what we can from the cache; collect the rest for one API call.
    for feat in features:
        if feat == "None":
            result[feat] = []
            continue
        h = _feat_hash(feat, text)
        if h in cache:
            result[feat] = cache[h]["spans"]
        else:
            missing_feats.append(feat)
    if missing_feats:
        mapping = generate_feature_spans_with_retries(client, text, missing_feats)
        # Update cache & result for each missing feature.
        for feat in missing_feats:
            h = _feat_hash(feat, text)
            # Bug fix: the model may omit a requested feature from its JSON
            # response; .get(feat) then returned None, which was written to
            # the cache and leaked to callers expecting list[str]. Default
            # to an empty span list instead.
            spans = mapping.get(feat, [])
            cache[h] = {
                "feature": feat,
                "spans": spans
            }
            result[feat] = spans
        # Write back the combined cache (only needed when something changed).
        with open(cache_path, "w") as f:
            json.dump(cache, f, indent=2)
    return result
def split_features(all_feats):
    """
    Partition a mixed feature list into (llm_feats, g2v_feats).

    Features present in the Gram2Vec CSV set go into g2v_feats; everything
    else is assumed to be LLM-generated and goes into llm_feats. Relative
    ordering within each list matches the input, and duplicates are kept.
    """
    llm_feats = []
    g2v_feats = []
    for feat in all_feats:
        bucket = g2v_feats if feat in GRAM2VEC_SET else llm_feats
        bucket.append(feat)
    return llm_feats, g2v_feats