File size: 3,979 Bytes
a02884f
 
 
 
14beccb
 
a02884f
 
06cace1
98c2919
a02884f
 
98c2919
a02884f
 
 
 
 
2e4795a
a290aa5
a02884f
 
a48a75f
 
 
 
 
7b5322f
a02884f
 
 
0d226e8
a02884f
 
 
 
0d226e8
a02884f
0d226e8
 
 
 
a02884f
3133525
0d226e8
 
 
 
 
 
a02884f
0d226e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a02884f
0d226e8
 
 
 
 
a02884f
0d226e8
 
 
 
 
 
 
 
 
a02884f
0d226e8
a02884f
 
 
 
 
 
 
0d226e8
 
 
 
 
 
 
 
a02884f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# ── DIAGNOSTICS & SHIM (must come before any BERTopic import) ─────────────
# Compatibility shim: some BERTopic/sentence-transformers version pairs
# expect `sentence_transformers.models.StaticEmbedding`, which older ST
# releases do not ship. If it is missing, alias the Transformer module
# class under that name so attribute lookups succeed.
# NOTE(review): bertopic itself is imported on the next line (to print its
# version) *before* the shim is installed — if BERTopic resolves
# StaticEmbedding at import time rather than lazily, this shim runs too
# late. Confirm against the installed BERTopic version.
import pkgutil, sentence_transformers, bertopic, sys, json

# 1) Print versions & model‐list
# Diagnostic output flushed immediately so it survives a crash during
# the imports further down the file.
print("ST version:", sentence_transformers.__version__)
print("BERTopic version:", bertopic.__version__)
models = [m.name for m in pkgutil.iter_modules(sentence_transformers.models.__path__)]
print("ST models:", models)
sys.stdout.flush()

# 2) If StaticEmbedding is missing, alias Transformer β†’ StaticEmbedding
# NOTE(review): Transformer and StaticEmbedding do not share a constructor
# signature — the alias only papers over attribute lookups, not actual
# instantiation with StaticEmbedding-specific arguments.
if "StaticEmbedding" not in models:
    from sentence_transformers.models import Transformer
    import sentence_transformers.models as _st_mod
    setattr(_st_mod, "StaticEmbedding", Transformer)
    print("πŸ”§ Shim applied: StaticEmbedding β†’ Transformer")
    sys.stdout.flush()
# ──────────────────────────────────────────────────────────────────────────────


# ── REST OF YOUR APP.PY ──────────────────────────────────────────────────────
import os, uuid
from typing import List
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from bertopic import BERTopic
from sentence_transformers import SentenceTransformer

# 0) Quick env dump
# NOTE(review): this prints the *values* of the first 10 environment
# variables to stdout. If any of them carry credentials or tokens they
# end up in the logs — consider printing key names only, or redacting.
print("ENV-snapshot:", json.dumps({k: os.environ[k] for k in list(os.environ)[:10]}))
sys.stdout.flush()

# 1) Tidy numba cache
# Numba (pulled in transitively by BERTopic/UMAP) needs a writable cache
# directory; /tmp is typically the only writable path in a read-only
# container filesystem.
os.environ.setdefault("NUMBA_CACHE_DIR", "/tmp/numba_cache")
os.makedirs(os.environ["NUMBA_CACHE_DIR"], exist_ok=True)
# Disable on-disk caching entirely — presumably read lazily at JIT time,
# since bertopic was already imported above; confirm with the numba docs.
os.environ["NUMBA_DISABLE_CACHE"] = "1"

# 2) Config from ENV
# EMBED_MODEL: HF model id for sentence embeddings (Czech SimCSE default).
# MIN_TOPIC_SIZE: smallest cluster BERTopic is allowed to emit.
# MAX_DOCS: request-size cap enforced by the /segment endpoint.
MODEL_NAME = os.getenv("EMBED_MODEL", "Seznam/simcse-small-e-czech")
MIN_TOPIC  = int(os.getenv("MIN_TOPIC_SIZE", "10"))
MAX_DOCS   = int(os.getenv("MAX_DOCS", "5000"))

# 3) Initialise once
# Both models are built at import time and shared by every request; with
# calculate_probabilities=True, fit_transform returns per-document
# probability vectors alongside topic assignments.
embeddings = SentenceTransformer(MODEL_NAME, cache_folder="/tmp/hfcache")
topic_model = BERTopic(
    embedding_model=embeddings,
    min_topic_size=MIN_TOPIC,
    calculate_probabilities=True,
)

# 4) Schemas
class Sentence(BaseModel):
    """One transcript sentence posted to /segment."""
    text: str
    # start/end look like timestamps in seconds — unit not stated here,
    # confirm with the producer of the request payload.
    start: float
    end: float
    speaker: str | None = None  # optional speaker label

class Segment(BaseModel):
    """A contiguous run of consecutive sentences assigned to one topic."""
    topic_id: int              # BERTopic topic id; -1 is the outlier topic
    label: str | None          # None for outliers, else joined top keywords
    keywords: List[str]        # top topic words (up to 5)
    start: float               # start of the first sentence in the run
    end: float                 # end of the last sentence in the run
    probability: float | None  # probability recorded at the run's first sentence
    sentences: List[int]       # indices into the posted sentence list

class SegmentationResponse(BaseModel):
    """Response envelope for /segment."""
    run_id: str              # fresh UUID4 generated per request
    segments: List[Segment]  # topic runs in document order

# 5) FastAPI
# Single shared application instance; routes attach to it via decorators.
app = FastAPI(title="CZ Topic Segmenter", version="1.0")

def _to_probability(prob) -> float:
    """Collapse a BERTopic probability value to a single float.

    With calculate_probabilities=True, fit_transform yields one probability
    *vector* per document; report its maximum. A plain scalar or None is
    handled too. (The previous `float(prob or 0)` raised "truth value of an
    array is ambiguous" on vectors.)
    """
    if prob is None:
        return 0.0
    try:
        return float(max(prob))   # per-document probability vector
    except TypeError:
        return float(prob)        # already a scalar


@app.post("/segment", response_model=SegmentationResponse)
def segment(sentences: List[Sentence]):
    """Fit BERTopic on the posted sentences and merge consecutive
    same-topic sentences into contiguous segments.

    Returns a fresh run_id plus the segments in document order.
    Raises HTTP 413 when more than MAX_DOCS sentences are posted.
    """
    if len(sentences) > MAX_DOCS:
        raise HTTPException(413, f"Too many sentences ({len(sentences)} > {MAX_DOCS})")
    if not sentences:
        # fit_transform cannot handle an empty corpus — short-circuit.
        return {"run_id": str(uuid.uuid4()), "segments": []}

    docs = [s.text for s in sentences]
    topics, probs = topic_model.fit_transform(docs)
    if probs is None:
        # Defensive: some configurations return no probabilities at all.
        probs = [None] * len(topics)

    segments, cur = [], None
    for idx, (t_id, prob) in enumerate(zip(topics, probs)):
        if cur is None or t_id != cur["topic_id"]:
            if cur:
                segments.append(cur)
            # get_topic returns False (not a list) for unknown topic ids —
            # guard so slicing cannot blow up the request.
            topic_words = topic_model.get_topic(t_id) or []
            words = [w for w, _ in topic_words[:5]]
            cur = dict(
                topic_id   = t_id,
                label      = None if t_id == -1 else " ".join(words),
                keywords   = words,
                start      = sentences[idx].start,
                end        = sentences[idx].end,
                probability= _to_probability(prob),
                sentences  = [idx],
            )
        else:
            # Same topic as the running segment: extend it.
            cur["end"] = sentences[idx].end
            cur["sentences"].append(idx)
    if cur:
        segments.append(cur)

    return {"run_id": str(uuid.uuid4()), "segments": segments}
# ──────────────────────────────────────────────────────────────────────────────
# ──────────────────────────────────────────────────────────────────────────────