# insightflowv2 / app.py
# (HuggingFace Space application file; non-code page chrome from the
#  "raw/history/blame" scrape removed.)
# ── DIAGNOSTICS & SHIM (must come before any BERTopic import) ─────────────
import pkgutil, sentence_transformers, bertopic, sys, json

# 1) Log library versions and the embedding-model modules that shipped with
#    this sentence-transformers install, so version mismatches are visible.
print("ST version:", sentence_transformers.__version__)
print("BERTopic version:", bertopic.__version__)
available = [mod.name for mod in pkgutil.iter_modules(sentence_transformers.models.__path__)]
print("ST models:", available)
sys.stdout.flush()

# 2) Older sentence-transformers releases have no StaticEmbedding class,
#    which newer BERTopic code imports.  Alias Transformer under that name
#    so the import resolves (an import shim, not a functional equivalent).
if "StaticEmbedding" not in available:
    import sentence_transformers.models as _st_mod
    from sentence_transformers.models import Transformer

    setattr(_st_mod, "StaticEmbedding", Transformer)
    print("🔧 Shim applied: StaticEmbedding → Transformer")
    sys.stdout.flush()
# ──────────────────────────────────────────────────────────────────────────────
# ── REST OF YOUR APP.PY ──────────────────────────────────────────────────────
import os, uuid
from typing import List
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from bertopic import BERTopic
from sentence_transformers import SentenceTransformer
# 0) Quick dump of the first ten environment variables for debugging.
# NOTE(review): this prints raw env values to the logs — if a secret
# (e.g. an HF token) sits among the first ten variables it leaks; confirm.
env_slice = dict(list(os.environ.items())[:10])
print("ENV-snapshot:", json.dumps(env_slice))
sys.stdout.flush()

# 1) Give numba a writable cache directory, then disable its on-disk
#    cache anyway (container filesystems make it unreliable).
os.environ.setdefault("NUMBA_CACHE_DIR", "/tmp/numba_cache")
os.makedirs(os.environ["NUMBA_CACHE_DIR"], exist_ok=True)
os.environ["NUMBA_DISABLE_CACHE"] = "1"

# 2) Runtime configuration, overridable through the environment.
MODEL_NAME = os.getenv("EMBED_MODEL", "Seznam/simcse-small-e-czech")
MIN_TOPIC = int(os.getenv("MIN_TOPIC_SIZE", "10"))
MAX_DOCS = int(os.getenv("MAX_DOCS", "5000"))

# 3) Load the embedding model and build the topic model once at startup.
embeddings = SentenceTransformer(MODEL_NAME, cache_folder="/tmp/hfcache")
topic_model = BERTopic(
    embedding_model=embeddings,
    min_topic_size=MIN_TOPIC,
    calculate_probabilities=True,
)
# 4) Schemas
class Sentence(BaseModel):
    """One input sentence with its time span and optional speaker label."""

    text: str
    start: float
    end: float
    speaker: str | None = None
class Segment(BaseModel):
    """A contiguous run of sentences assigned to the same topic."""

    topic_id: int
    label: str | None
    keywords: List[str]
    start: float
    end: float
    probability: float | None
    sentences: List[int]  # indices into the request's sentence list
class SegmentationResponse(BaseModel):
    """Response envelope: a unique run id plus the computed segments."""

    run_id: str
    segments: List[Segment]
# 5) FastAPI application instance; the /segment route is registered below.
app = FastAPI(title="CZ Topic Segmenter", version="1.0")
@app.post("/segment", response_model=SegmentationResponse)
def segment(sentences: List[Sentence]):
    """Fit BERTopic on the submitted sentences and return contiguous
    same-topic runs as segments.

    Raises:
        HTTPException(413): when more than MAX_DOCS sentences are submitted.
    """
    if len(sentences) > MAX_DOCS:
        raise HTTPException(413, f"Too many sentences ({len(sentences)} > {MAX_DOCS})")
    if not sentences:
        # BERTopic cannot fit on an empty corpus; short-circuit instead.
        return {"run_id": str(uuid.uuid4()), "segments": []}

    docs = [s.text for s in sentences]
    topics, probs = topic_model.fit_transform(docs)

    segments, cur = [], None
    for idx, t_id in enumerate(topics):
        # With calculate_probabilities=True, `probs` is an (n_docs, n_topics)
        # matrix, so `probs[idx]` is a row — `row or 0` on it would raise.
        # Use the best topic probability; fall back for 1-D or missing probs.
        if probs is None:
            prob = 0.0
        else:
            row = probs[idx]
            prob = float(row.max()) if hasattr(row, "max") else float(row or 0)

        if cur is None or t_id != cur["topic_id"]:
            if cur:
                segments.append(cur)
            # get_topic() returns False for unknown topic ids; guard before
            # slicing so we don't crash on `False[:5]`.
            topic_words = topic_model.get_topic(t_id) or []
            words = [w for w, _ in topic_words[:5]]
            cur = dict(
                topic_id=t_id,
                label=None if t_id == -1 else " ".join(words),
                keywords=words,
                start=sentences[idx].start,
                end=sentences[idx].end,
                probability=prob,
                sentences=[idx],
            )
        else:
            # Same topic as the running segment: extend it.
            cur["end"] = sentences[idx].end
            cur["sentences"].append(idx)
    if cur:
        segments.append(cur)
    return {"run_id": str(uuid.uuid4()), "segments": segments}
# ──────────────────────────────────────────────────────────────────────────────