|
import json |
|
import os |
|
import hashlib |
|
import time |
|
from json import JSONDecodeError |
|
|
|
import pandas as pd

# Directory that holds the span caches: one "<role>.json" file per role.
CACHE_DIR = "datasets/feature_spans_cache"

# Mixed into every cache key by _feat_hash(); changing it changes all keys.
CACHE_VERSION = 2

# Retry policy for the span-extraction call: number of attempts, and the base
# delay in seconds, which is doubled after every failed attempt.
MAX_ATTEMPTS = 3
WAIT_SECONDS = 2

# Import-time side effect: make sure the cache directory exists.
os.makedirs(CACHE_DIR, exist_ok=True)

# Import-time side effect: load the Gram2Vec feature names. split_features()
# uses this set to tell Gram2Vec features apart from all other features.
_g2v_df = pd.read_csv("datasets/gram2vec_feats.csv")
GRAM2VEC_SET = set(_g2v_df["gram2vec_feats"].unique())
|
|
|
def _feat_hash(feature: str, text: str) -> str:
    """Return the cache key for one (feature, text) pair.

    The key is the MD5 hex digest of a canonical JSON document holding
    ``CACHE_VERSION``, the text and the feature, so it changes whenever any
    of the three changes.

    The feature string is hashed as-is. Passing a ``str`` through
    ``sorted()`` yields its characters in sorted order, which gave the same
    key to any two features that are anagrams of each other (same
    characters, different order) and let one feature's cached spans be
    served for the other.

    NOTE: keys for string features therefore differ from the ones produced
    by the character-sorting variant; entries cached with the old keys are
    simply not found again.

    Args:
        feature: Feature description. A non-string collection of features
            is still accepted and hashed order-insensitively.
        text: The writing sample the spans are extracted from.

    Returns:
        A 32-character lowercase hexadecimal digest.
    """
    if not isinstance(feature, str):
        # Collections keep the order-insensitive behaviour sorted() gave them.
        feature = sorted(feature)

    payload = {
        "version": CACHE_VERSION,
        "text": text,
        "features": feature,
    }
    # sort_keys makes the serialisation, and hence the digest, deterministic.
    blob = json.dumps(payload, sort_keys=True).encode()
    # MD5 serves only as a compact fingerprint here, not as a security measure.
    return hashlib.md5(blob).hexdigest()
|
|
|
|
|
def generate_feature_spans(
    client,
    text: str,
    features: list[str],
    *,
    model: str = "gpt-4",
    temperature: float = 0.3,
) -> str:
    """Ask the chat-completions API for the text spans that show each feature.

    Builds a single user prompt containing the writing sample and the
    feature list, sends it through ``client.chat.completions.create`` and
    returns the raw reply. The reply is expected to be JSON but is neither
    parsed nor validated here.

    Args:
        client: Object exposing ``chat.completions.create(...)``.
        text: The writing sample to search for spans.
        features: Feature descriptions; interpolated into the prompt using
            the list's default string formatting.
        model: Model name passed to the API.
        temperature: Sampling temperature passed to the API.

    Returns:
        The content of the first choice's message.
    """
    print("Calling OpenAI to extract spans")
    prompt = f"""You are a linguistic specialist. Given a writing sample and a list of descriptive features, identify the exact text spans that demonstrate each feature.

Important:
- The headers like "Document 1:" etc are NOT part of the original text — ignore them.
- For each feature, even if there is no match, return an empty list.
- Only return exact phrases from the text.

Respond in JSON format like:
{{
"feature1": ["span1", "span2"],
"feature2": [],
…
}}

Text:
\"\"\"{text}\"\"\"

Style Features:
{features}
"""
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=temperature,
    )
    return response.choices[0].message.content
|
|
|
def generate_feature_spans_with_retries(client, text: str, features: list[str]) -> dict:
    """Call ``generate_feature_spans`` until it yields a JSON object.

    Makes up to ``MAX_ATTEMPTS`` calls. After a failed attempt (except the
    last one) it sleeps ``WAIT_SECONDS * 2 ** attempt`` seconds before
    trying again. An attempt fails when the reply is not valid JSON, is
    valid JSON but not an object (so the promised dict could not be
    returned), or is not a string at all (``json.loads`` raises
    ``TypeError``, e.g. for a ``None`` reply).

    Errors raised by the API call itself are not caught here and propagate
    to the caller unchanged.

    Args:
        client: Forwarded to ``generate_feature_spans``.
        text: The writing sample.
        features: Feature descriptions to extract spans for.

    Returns:
        The parsed JSON object mapping feature -> list of spans.

    Raises:
        RuntimeError: If every attempt failed; chained to the last error.
    """
    last_error = None
    for attempt in range(MAX_ATTEMPTS):
        try:
            response_str = generate_feature_spans(client, text, features)
            result = json.loads(response_str)
            if not isinstance(result, dict):
                # A bare list/string/number is valid JSON but useless to
                # callers, which index the result by feature name.
                raise ValueError(
                    f"expected a JSON object, got {type(result).__name__}"
                )
            return result
        except (JSONDecodeError, ValueError, TypeError) as e:
            last_error = e
            print(f"Attempt {attempt+1} failed: {e}")
            if attempt < MAX_ATTEMPTS - 1:
                # Exponential back-off between attempts.
                wait_sec = WAIT_SECONDS * (2 ** attempt)
                print(f"Retrying after {wait_sec} seconds...")
                time.sleep(wait_sec)
    raise RuntimeError("All retry attempts failed for OpenAI call.") from last_error
|
|
|
|
|
def _load_span_cache(cache_path: str) -> dict:
    """Read the cache file at *cache_path*; return ``{}`` if absent or unusable."""
    if not os.path.exists(cache_path):
        return {}
    print("Cache hit....")
    try:
        with open(cache_path, encoding="utf-8") as f:
            cache = json.load(f)
    except ValueError as e:
        # Covers JSONDecodeError and UnicodeDecodeError: a truncated or
        # corrupt file should cost a re-query, not crash the caller.
        print(f"Ignoring unreadable cache file {cache_path}: {e}")
        return {}
    return cache if isinstance(cache, dict) else {}


def _save_span_cache(cache_path: str, cache: dict) -> None:
    """Write *cache* to *cache_path* via a temporary file and ``os.replace``."""
    tmp_path = f"{cache_path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(cache, f, indent=2)
    # Swap the finished file in so an interrupted write cannot leave a
    # half-written cache behind.
    os.replace(tmp_path, cache_path)


def generate_feature_spans_cached(client, text: str, features: list[str], role: str = "mystery") -> dict:
    """Return the spans for each feature, using a per-role JSON cache on disk.

    The cache lives in ``CACHE_DIR/<role>.json`` (spaces, slashes and dashes
    in *role* become underscores) and maps ``_feat_hash(feature, text)`` to
    ``{"feature": ..., "spans": [...]}``. Features without a usable cache
    entry are requested together in one call to
    ``generate_feature_spans_with_retries``.

    Every value in the returned dict is a list:

    * the literal feature ``"None"`` always maps to ``[]`` and is never
      looked up or requested;
    * a feature the model left out of its reply (or answered with something
      other than a list) maps to ``[]`` and is NOT cached, so it is
      requested again on the next call;
    * a cached entry whose spans are not a list is treated as missing.

    Args:
        client: Forwarded to ``generate_feature_spans_with_retries``.
        text: The writing sample.
        features: Feature descriptions to look up.
        role: Name used to choose the cache file.

    Returns:
        Dict mapping each requested feature to its list of spans.

    Raises:
        RuntimeError: Propagated from the retry helper when all attempts fail.
    """
    print(f"Generating spans for ({role})")

    role = role.replace(" ", "_").replace("/", "_").replace("-", "_")
    print(f"Cache dir: {CACHE_DIR}")
    os.makedirs(CACHE_DIR, exist_ok=True)
    cache_path = os.path.join(CACHE_DIR, f"{role}.json")
    cache = _load_span_cache(cache_path)

    result: dict[str, list[str]] = {}
    # feature -> cache key, for everything that still has to be requested.
    missing: dict[str, str] = {}

    for feat in features:
        if feat == "None":
            result[feat] = []
            continue

        key = _feat_hash(feat, text)
        entry = cache.get(key)
        spans = entry.get("spans") if isinstance(entry, dict) else None
        if isinstance(spans, list):
            result[feat] = spans
        else:
            missing[feat] = key

    if missing:
        mapping = generate_feature_spans_with_retries(client, text, list(missing))

        cache_changed = False
        for feat, key in missing.items():
            spans = mapping.get(feat)
            if isinstance(spans, list):
                cache[key] = {"feature": feat, "spans": spans}
                cache_changed = True
            else:
                # Omitted or malformed answer: hand back an empty list but
                # do not persist it, so the feature is asked for again.
                spans = []
            result[feat] = spans

        if cache_changed:
            _save_span_cache(cache_path, cache)

    return result
|
|
|
|
|
def split_features(all_feats) -> tuple[list, list]:
    """Partition features into LLM features and Gram2Vec features.

    The input is walked exactly once, so any iterable works, including
    generators and other one-shot iterators (walking it twice would leave
    the second result empty for those). Relative order is preserved within
    each output list.

    Args:
        all_feats: Iterable of feature names.

    Returns:
        ``(llm_feats, g2v_feats)`` where ``g2v_feats`` are the features
        found in ``GRAM2VEC_SET`` and ``llm_feats`` are all the others.
    """
    llm_feats: list = []
    g2v_feats: list = []
    for feat in all_feats:
        if feat in GRAM2VEC_SET:
            g2v_feats.append(feat)
        else:
            llm_feats.append(feat)
    return llm_feats, g2v_feats