Spaces:
Sleeping
Sleeping
| from typing import Any, Optional | |
| from omagent_core.models.encoders.base import EncoderBase | |
| from omagent_core.utils.error import VQLError | |
| from omagent_core.utils.registry import registry | |
| from pydantic import field_validator | |
| from pymilvus import CollectionSchema, DataType, FieldSchema | |
| from .milvus_handler import MilvusHandler | |
class VideoHandler(MilvusHandler):
    """Milvus-backed store of text summaries for video segments.

    Each row holds a video's ``video_md5``, a text ``content`` string, the
    ``content_vector`` produced by ``text_encoder``, and the segment's
    ``start_time``/``end_time``. The collection named ``collection_name`` is
    created via ``make_collection`` when the handler is constructed.
    """

    collection_name: str
    # Encoder used to embed ``content``. May be given as an EncoderBase
    # instance or as a registry config dict (converted by ``init_encoder``).
    text_encoder: Optional[EncoderBase] = None
    # Dimension of ``content_vector``. Overwritten by ``text_encoder.dim``
    # when an encoder is configured; must be supplied explicitly otherwise.
    dim: Optional[int] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = "allow"
        arbitrary_types_allowed = True

    @field_validator("text_encoder", mode="before")
    @classmethod
    def init_encoder(cls, text_encoder):
        """Coerce the ``text_encoder`` field into an ``EncoderBase`` (or None).

        Args:
            text_encoder: An ``EncoderBase`` instance (returned unchanged),
                ``None`` (returned unchanged), or a config dict. For a dict,
                its ``"name"`` key selects the encoder class from the
                registry, and the whole dict is passed to that class as
                keyword arguments.

        Returns:
            The encoder instance, or ``None``.

        Raises:
            ValueError: If ``text_encoder`` is of any other type.
        """
        if text_encoder is None or isinstance(text_encoder, EncoderBase):
            return text_encoder
        if isinstance(text_encoder, dict):
            return registry.get_encoder(text_encoder.get("name"))(**text_encoder)
        raise ValueError("text_encoder must be EncoderBase or Dict")

    def __init__(self, **data: Any) -> None:
        """Validate the fields and create the video-summary collection.

        The vector dimension comes from ``text_encoder.dim`` when an encoder
        is configured; otherwise the explicitly passed ``dim`` is used.

        Raises:
            ValueError: If neither ``text_encoder`` nor ``dim`` is provided,
                because the vector field's dimension cannot be determined.
        """
        super().__init__(**data)
        if self.text_encoder is not None:
            self.dim = self.text_encoder.dim
        elif self.dim is None:
            raise ValueError("Either text_encoder or dim must be provided")

        _uid = FieldSchema(
            name="_uid", dtype=DataType.INT64, is_primary=True, auto_id=True
        )
        video_md5 = FieldSchema(
            name="video_md5", dtype=DataType.VARCHAR, max_length=100
        )
        content = FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=65535)
        content_vector = FieldSchema(
            name="content_vector", dtype=DataType.FLOAT_VECTOR, dim=self.dim
        )
        start_time = FieldSchema(
            name="start_time",
            dtype=DataType.FLOAT,
        )
        end_time = FieldSchema(
            name="end_time",
            dtype=DataType.FLOAT,
        )
        schema = CollectionSchema(
            fields=[_uid, video_md5, content, content_vector, start_time, end_time],
            description="video summary vector DB",
            enable_dynamic_field=True,
        )
        self.make_collection(self.collection_name, schema)

    def text_add(self, video_md5, content, start_time, end_time):
        """Embed ``content`` and insert it as a single row.

        Args:
            video_md5: Identifier of the source video (stored as VARCHAR,
                max length 100).
            content: Text to store and embed.
            start_time: Segment start, stored as a FLOAT field.
            end_time: Segment end, stored as a FLOAT field.

        Returns:
            Whatever ``do_add`` returns for the insert.

        Raises:
            VQLError: If no ``text_encoder`` is configured.
        """
        if self.text_encoder is None:
            raise VQLError(500, detail="Missing text_encoder")
        content_vector = self.text_encoder.infer([content])[0]
        upload_data = [
            {
                "video_md5": video_md5,
                "content": content,
                "content_vector": content_vector,
                "start_time": start_time,
                "end_time": end_time,
            }
        ]
        return self.do_add(self.collection_name, upload_data)

    @staticmethod
    def _build_filter_expr(video_md5, start_time, end_time, time_padding=10):
        """Build the Milvus scalar filter for ``text_match`` ("" when unfiltered).

        Clauses are joined with " and " so the expression stays valid whichever
        subset of arguments is provided. The time bounds are widened by
        ``time_padding``, and the lower bound is clamped at 0.
        """
        clauses = []
        if video_md5 is not None:
            # NOTE(review): video_md5 is interpolated unescaped; this assumes
            # it never contains a quote (true for hex MD5 digests).
            clauses.append(f"video_md5=='{video_md5}'")
        if start_time is not None:
            clauses.append(f"start_time>={max(0, start_time - time_padding)}")
        if end_time is not None:
            clauses.append(f"end_time<={end_time + time_padding}")
        return " and ".join(clauses)

    def text_match(
        self,
        video_md5,
        content,
        threshold: float,
        start_time=None,
        end_time=None,
        res_size: int = 100,
        time_padding: float = 10,
    ):
        """Return stored segments whose content vector matches ``content``.

        Args:
            video_md5: Restrict results to this video; ``None`` searches all.
            content: Query text, embedded with ``text_encoder``.
            threshold: Forwarded to ``match`` as its score threshold.
            start_time: If given, keep rows with
                ``start_time >= max(0, start_time - time_padding)``.
            end_time: If given, keep rows with
                ``end_time <= end_time + time_padding``.
            res_size: Maximum number of results requested from ``match``.
            time_padding: Slack applied to both time bounds, in the same
                units as the stored times. Defaults to the previously
                hard-coded 10.

        Returns:
            A list of each hit's ``"entity"`` entry. The request asks for the
            ``content``, ``start_time`` and ``end_time`` output fields.

        Raises:
            VQLError: If no ``text_encoder`` is configured.
        """
        if self.text_encoder is None:
            raise VQLError(500, detail="Missing text_encoder")
        filter_expr = self._build_filter_expr(
            video_md5, start_time, end_time, time_padding
        )
        content_vector = self.text_encoder.infer([content])[0]
        match_res = self.match(
            collection_name=self.collection_name,
            query_vectors=[content_vector],
            query_field="content_vector",
            output_fields=["content", "start_time", "end_time"],
            res_size=res_size,
            threshold=threshold,
            filter_expr=filter_expr,
        )
        # presumably match_res holds one hit list per query vector; only one
        # query vector is sent, so take the first — TODO confirm in MilvusHandler.
        return [hit["entity"] for hit in match_res[0]]