import hashlib
import json
import os
import time

import pandas as pd

CACHE_DIR = "datasets/feature_spans_cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Read the Gram2Vec feature set once at import time.
_g2v_df = pd.read_csv("datasets/gram2vec_feats.csv")
GRAM2VEC_SET = set(_g2v_df["gram2vec_feats"].unique())

MAX_ATTEMPTS = 3
WAIT_SECONDS = 2

# Bump this whenever the prompt, feature space, etc. changes, so that stale
# cache entries are invalidated.
CACHE_VERSION = 2


def _feat_hash(feature: str, text: str) -> str:
    """Stable cache key for a single (feature, text) pair."""
    blob = json.dumps({
        "version": CACHE_VERSION,
        "text": text,
        "feature": feature,
    }, sort_keys=True).encode()
    return hashlib.md5(blob).hexdigest()
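# Example (illustrative inputs): calling
#   _feat_hash("uses parenthetical asides", "I went (reluctantly) home.")
# twice yields the same 32-character hex digest, so spans cached in one run
# are found again in the next; bumping CACHE_VERSION changes every key.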


def generate_feature_spans(client, text: str, features: list[str]) -> str:
    """
    Single call to OpenAI to extract spans for each feature.
    Returns the raw response string, which is expected to be JSON.
    """
    print("Calling OpenAI to extract spans")
    prompt = f"""You are a linguistic specialist. Given a writing sample and a list of descriptive features, identify the exact text spans that demonstrate each feature.
    
    Important:
    - The headers like "Document 1:" etc are NOT part of the original text — ignore them.
    - For each feature, even if there is no match, return an empty list.
    - Only return exact phrases from the text.

    Respond in JSON format like:
    {{
      "feature1": ["span1", "span2"],
      "feature2": [],

    }}

    Text:
    \"\"\"{text}\"\"\"

    Style Features:
    {features}
    """
    response = client.chat.completions.create(
        model="gpt-4",
        messages=[{"role":"user","content":prompt}],
        temperature=0.3,
    )
    return response.choices[0].message.content

def generate_feature_spans_with_retries(client, text: str, features: list[str]) -> dict:
    """
    Calls `generate_feature_spans` with exponential backoff on failure.
    Returns the parsed JSON dict mapping feature -> list[spans].
    """
    for attempt in range(MAX_ATTEMPTS):
        try:
            response_str = generate_feature_spans(client, text, features)
            return json.loads(response_str)
        except ValueError as e:  # json.JSONDecodeError is a subclass of ValueError
            print(f"Attempt {attempt + 1} failed: {e}")
            if attempt < MAX_ATTEMPTS - 1:
                wait_sec = WAIT_SECONDS * (2 ** attempt)
                print(f"Retrying after {wait_sec} seconds...")
                time.sleep(wait_sec)
    raise RuntimeError("All retry attempts failed for OpenAI call.")
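# With the defaults above (MAX_ATTEMPTS = 3, WAIT_SECONDS = 2), a run of bad
# responses waits 2 s after attempt 1 and 4 s after attempt 2; a third
# failure raises RuntimeError.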


def generate_feature_spans_cached(client, text: str, features: list[str], role: str = "mystery") -> dict:
    """
    Computes a cache key per (feature, text) pair, serves spans from the
    on-disk cache where possible, and calls the API only for the features
    that are missing. Returns the parsed JSON dict mapping
    feature -> list[spans].
    """
    print(f"Generating spans for ({role})")
    role = role.replace(" ", "_").replace("/", "_").replace("-", "_")
    cache_path = os.path.join(CACHE_DIR, f"{role}.json")
    if os.path.exists(cache_path):
        print(f"Loading cache file {cache_path}")
        with open(cache_path) as f:
            cache: dict[str, dict] = json.load(f)
    else:
        cache = {}

    result: dict[str, list[str]] = {}
    missing_feats: list[str] = []

    for feat in features:
        # "None" is a sentinel for "no feature" and never has spans.
        if feat == "None":
            result[feat] = []
            continue

        h = _feat_hash(feat, text)
        if h in cache:
            result[feat] = cache[h]["spans"]
        else:
            missing_feats.append(feat)

    if missing_feats:
        mapping = generate_feature_spans_with_retries(client, text, missing_feats)

        # Update cache and result for each missing feature, falling back to
        # an empty list if the model omitted a feature from its response.
        for feat in missing_feats:
            h = _feat_hash(feat, text)
            spans = mapping.get(feat, [])
            cache[h] = {
                "feature": feat,
                "spans": spans,
            }
            result[feat] = spans

        # Write back the combined cache.
        with open(cache_path, "w") as f:
            json.dump(cache, f, indent=2)

    return result


def split_features(all_feats: list[str]) -> tuple[list[str], list[str]]:
    """
    Given a mixed list of features, returns two lists:
    - llm_feats: features NOT in the Gram2Vec CSV
    - g2v_feats: features present in the CSV
    """
    g2v_feats = [feat for feat in all_feats if feat in GRAM2VEC_SET]
    llm_feats = [feat for feat in all_feats if feat not in GRAM2VEC_SET]
    return llm_feats, g2v_feats
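

# --- Usage sketch (illustrative, not part of the module API) ---------------
# A minimal end-to-end example, assuming the openai>=1.x client and an
# OPENAI_API_KEY in the environment. The feature strings and sample text
# below are made up; real runs would draw features from the pipeline's
# feature lists.
if __name__ == "__main__":
    from openai import OpenAI

    client = OpenAI()  # reads OPENAI_API_KEY from the environment

    feats = ["uses parenthetical asides", "sentence-initial conjunctions"]
    llm_feats, g2v_feats = split_features(feats)
    print(f"{len(llm_feats)} LLM features, {len(g2v_feats)} Gram2Vec features")

    spans = generate_feature_spans_cached(
        client,
        text="But I went home. (Reluctantly, of course.)",
        features=llm_feats,
        role="demo",
    )
    print(json.dumps(spans, indent=2))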