# meisaicheck-api/routes/predict.py

import os
import time
import shutil
import pandas as pd
import traceback
import sys
import numpy as np
from pathlib import Path
from fastapi import APIRouter, UploadFile, File, HTTPException, Depends, Body
from fastapi.responses import FileResponse
from custom_auth import get_current_user_from_token
from services.sentence_transformer_service import SentenceTransformerService, sentence_transformer_service
# Add the path to import modules from meisai-check-ai
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "meisai-check-ai"))
from mapping_lib.standard_subject_data_mapper import StandardSubjectDataMapper
from mapping_lib.subject_similarity_mapper import SubjectSimilarityMapper
from mapping_lib.sub_subject_similarity_mapper import SubSubjectSimilarityMapper
from mapping_lib.name_similarity_mapper import NameSimilarityMapper
from mapping_lib.sub_subject_and_name_data_mapper import SubSubjectAndNameDataMapper
from mapping_lib.sub_subject_location_data_mapper import SubSubjectLocationDataMapper
from mapping_lib.abstract_similarity_mapper import AbstractSimilarityMapper
from mapping_lib.name_and_abstract_mapper import NameAndAbstractDataMapper
from mapping_lib.unit_mapper import UnitMapper
from mapping_lib.base_dictionary_mapper import BaseDictionaryMapper
from common_lib.data_utilities import fillna_with_space
from common_lib.string_utilities import (
preprocess_text,
ConversionType,
ConversionSettings,
)
from config import UPLOAD_DIR, OUTPUT_DIR
from models import (
EmbeddingRequest,
PredictRawRequest,
PredictRawResponse,
PredictRecord,
PredictResult,
)
router = APIRouter()


@router.post("/predict")
async def predict(
current_user=Depends(get_current_user_from_token),
file: UploadFile = File(...),
sentence_service: SentenceTransformerService = Depends(
lambda: sentence_transformer_service
),
):
"""
Process an input CSV file and return standardized names (requires authentication)
"""
    if not file.filename or not file.filename.endswith(".csv"):
        raise HTTPException(status_code=400, detail="Only CSV files are supported")
# Save uploaded file
timestamp = int(time.time())
    input_file_path = os.path.join(
        UPLOAD_DIR, f"input_{timestamp}_{current_user.username}.csv"
    )
    output_file_path = os.path.join(
        OUTPUT_DIR, f"output_{timestamp}_{current_user.username}.csv"
    )
try:
with open(input_file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
finally:
file.file.close()
try:
# Load input data
start_time = time.time()
df_input_data = pd.read_csv(input_file_path)
# Preprocess data like in meisai-check-ai/predict.py
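        # ConversionType.Z2H normalizes full-width (zenkaku) kana, alphabet and
        # digits to half-width (hankaku), so equivalent spellings of the same
        # name compare equal in the similarity mappers below.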
df_input_data["元名称"] = df_input_data["名称"]
df_input_data["名称"] = df_input_data["名称"].apply(
lambda x: (
preprocess_text(
x,
convert_kana=ConversionType.Z2H,
convert_alphabet=ConversionType.Z2H,
convert_digit=ConversionType.Z2H,
)
if pd.notna(x)
else ""
)
)
# Ensure basic columns exist with default values
basic_columns = {
"シート名": "",
"行": "",
"科目": "",
"中科目": "",
"分類": "",
"名称": "",
"単位": "",
"摘要": "",
"備考": "",
}
for col, default_value in basic_columns.items():
if col not in df_input_data.columns:
df_input_data[col] = default_value
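        # The mappers below run as a pipeline, each enriching the frame in
        # place: subject -> standard subject -> sub-subject -> name ->
        # sub-subject+name -> unit -> abstract -> name+abstract -> location.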
# SubjectSimilarityMapper
try:
if sentence_service.df_subject_map_data is not None:
subject_similarity_mapper = SubjectSimilarityMapper(
cached_embedding_helper=sentence_service.subject_cached_embedding_helper,
df_map_data=sentence_service.df_subject_map_data,
)
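                # Predict on unique 科目 values only, then broadcast the mapper
                # output back to every row via a dict lookup, avoiding redundant
                # embedding work for repeated subjects.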
list_input_subject = df_input_data["科目"].unique()
df_subject_data = pd.DataFrame(list_input_subject, columns=["科目"])
subject_similarity_mapper.predict_input(df_input_data=df_subject_data)
                output_subject_map = dict(
                    zip(df_subject_data["科目"], df_subject_data["出力_科目"])
                )
df_input_data["標準科目"] = df_input_data["科目"].map(output_subject_map)
df_input_data["出力_科目"] = df_input_data["標準科目"]
fillna_with_space(df_input_data)
except Exception as e:
print(f"Error processing SubjectSimilarityMapper: {e}")
raise HTTPException(status_code=500, detail=str(e))
# StandardSubjectDataMapper
try:
if sentence_service.df_standard_subject_map_data is not None:
standard_subject_data_mapper = StandardSubjectDataMapper(
df_map_data=sentence_service.df_standard_subject_map_data
)
df_output_data = standard_subject_data_mapper.map_data(
df_input_data=df_input_data, input_key_columns=["出力_科目"], in_place=True
)
fillna_with_space(df_output_data)
else:
df_output_data = df_input_data.copy()
except Exception as e:
print(f"Error processing StandardSubjectDataMapper: {e}")
# Continue with original data if standard subject mapping fails
df_output_data = df_input_data.copy()
# SubSubjectSimilarityMapper
try:
if sentence_service.df_sub_subject_map_data is not None:
sub_subject_similarity_mapper = SubSubjectSimilarityMapper(
cached_embedding_helper=sentence_service.sub_subject_cached_embedding_helper,
df_map_data=sentence_service.df_sub_subject_map_data,
)
df_input_sub_subject = df_output_data[
["科目", "標準科目", "出力_科目", "中科目", "分類"]
].drop_duplicates()
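                # Same dedupe-then-broadcast pattern: run the similarity model
                # once per unique key combination, then join the results back
                # onto the full frame with a dictionary mapper.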
sub_subject_similarity_mapper.predict_input(df_input_data=df_input_sub_subject)
sub_subject_map_key_columns = ["科目", "標準科目", "出力_科目", "中科目", "分類"]
sub_subject_map_data_columns = [
"出力_基準中科目",
"出力_中科目類似度",
"出力_中科目",
"外部・内部区分",
]
sub_subject_data_mapper = BaseDictionaryMapper(
df_input_sub_subject, sub_subject_map_key_columns, sub_subject_map_data_columns
)
sub_subject_data_mapper.map_data(
df_input_data=df_output_data,
input_key_columns=sub_subject_map_key_columns,
in_place=True,
)
fillna_with_space(df_output_data)
except Exception as e:
print(f"Error processing SubSubjectSimilarityMapper: {e}")
raise HTTPException(status_code=500, detail=str(e))
# NameSimilarityMapper
try:
if sentence_service.df_name_map_data is not None:
name_sentence_mapper = NameSimilarityMapper(
cached_embedding_helper=sentence_service.name_cached_embedding_helper,
df_map_data=sentence_service.df_name_map_data,
)
name_sentence_mapper.predict_input(df_input_data=df_output_data)
fillna_with_space(df_output_data)
except Exception as e:
print(f"Error processing NameSimilarityMapper: {e}")
raise HTTPException(status_code=500, detail=str(e))
# SubSubjectAndNameDataMapper
try:
if sentence_service.df_sub_subject_and_name_map_data is not None:
sub_subject_and_name_data_mapper = SubSubjectAndNameDataMapper(
df_map_data=sentence_service.df_sub_subject_and_name_map_data
)
sub_subject_and_name_data_mapper.map_data(df_input_data=df_output_data)
except Exception as e:
print(f"Error processing SubSubjectAndNameDataMapper: {e}")
raise HTTPException(status_code=500, detail=str(e))
# UnitMapper
try:
if sentence_service.df_unit_map_data is not None:
unit_similarity_mapper = UnitMapper(
cached_embedding_helper=sentence_service.unit_cached_embedding_helper,
df_map_data=sentence_service.df_unit_map_data,
)
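                # Deduplicate on 単位 before predicting, then join the four unit
                # output columns back onto the full output frame.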
unit_map_key_columns = ["単位"]
df_input_unit = df_input_data[unit_map_key_columns].drop_duplicates()
unit_similarity_mapper.predict_input(df_input_data=df_input_unit)
                output_unit_data_columns = [
                    "出力_基準単位",
                    "出力_単位類似度",
                    "出力_集計用単位",
                    "出力_標準単位",
                ]
unit_data_mapper = BaseDictionaryMapper(
df_input_unit, unit_map_key_columns, output_unit_data_columns
)
_ = unit_data_mapper.map_data(
df_input_data=df_output_data, input_key_columns=unit_map_key_columns, in_place=True
)
fillna_with_space(df_output_data)
except Exception as e:
print(f"Error processing UnitMapper: {e}")
raise HTTPException(status_code=500, detail=str(e))
# AbstractSimilarityMapper
try:
if sentence_service.df_abstract_map_data is not None:
abstract_similarity_mapper = AbstractSimilarityMapper(
cached_embedding_helper=sentence_service.abstract_cached_embedding_helper,
df_map_data=sentence_service.df_abstract_map_data,
)
abstract_similarity_mapper.predict_input(df_input_data=df_output_data)
        except Exception as e:
            print(f"Error processing AbstractSimilarityMapper: {e}")
            print("DEBUG: Full error traceback:")
            traceback.print_exc()
            # Abstract mapping is best-effort: log the error and continue
            # instead of failing the whole request
            print("DEBUG: Continuing without AbstractSimilarityMapper...")
# NameAndAbstractDataMapper
try:
if sentence_service.df_name_and_subject_map_data is not None:
name_and_abstract_mapper = NameAndAbstractDataMapper(
df_map_data=sentence_service.df_name_and_subject_map_data
)
df_output_data["出力_項目名"] = df_output_data["出力_標準名称"]
_ = name_and_abstract_mapper.map_data(df_output_data)
fillna_with_space(df_output_data)
df_output_data["出力_項目名(中科目抜き)"] = df_output_data["出力_項目名"]
except Exception as e:
print(f"Error processing NameAndAbstractDataMapper: {e}")
raise HTTPException(status_code=500, detail=str(e))
# SubSubjectLocationDataMapper
try:
sub_subject_location_mapper = SubSubjectLocationDataMapper()
sub_subject_location_mapper.map_location(df_output_data)
df_output_data["名称"] = df_output_data["元名称"]
except Exception as e:
print(f"Error processing SubSubjectLocationDataMapper: {e}")
raise HTTPException(status_code=500, detail=str(e))
# Create output columns and ensure they have proper values
# Add ID column if not exists
if "ID" not in df_output_data.columns:
df_output_data.reset_index(drop=False, inplace=True)
df_output_data.rename(columns={"index": "ID"}, inplace=True)
df_output_data["ID"] = df_output_data["ID"] + 1 # Start from 1
# Ensure required columns exist with default values
required_columns = {
"シート名": "",
"行": "",
"科目": "",
"中科目": "",
"分類": "",
"名称": "",
"単位": "",
"摘要": "",
"備考": "",
"出力_科目": "",
"出力_中科目": "",
"出力_項目名": "",
"出力_標準単位": "",
"出力_集計用単位": "",
"出力_確率度": 0.0,
}
for col, default_value in required_columns.items():
if col not in df_output_data.columns:
df_output_data[col] = default_value
        # Map output columns to match the Excel structure
        # 出力_中科目 - fall back to alternative sub-subject columns if the
        # sub-subject mapper did not produce it
        if "出力_中科目" not in df_output_data.columns:
            if "出力_基準中科目" in df_output_data.columns:
                df_output_data["出力_中科目"] = df_output_data["出力_基準中科目"]
            elif "標準中科目" in df_output_data.columns:
                df_output_data["出力_中科目"] = df_output_data["標準中科目"]
        # 出力_項目名 - fall back to alternative name columns if the name and
        # abstract mapper did not produce it
        if "出力_項目名" not in df_output_data.columns:
            if "出力_標準名称" in df_output_data.columns:
                df_output_data["出力_項目名"] = df_output_data["出力_標準名称"]
            elif "出力_基準名称" in df_output_data.columns:
                df_output_data["出力_項目名"] = df_output_data["出力_基準名称"]
        # 出力_標準単位 comes directly from the unit mapper; no remapping needed
        # 出力_集計用単位 mapping - prefer an explicit 集計単位 value when present
        if "出力_集計用単位" in df_output_data.columns and "集計単位" in df_output_data.columns:
            df_output_data["出力_集計用単位"] = np.where(
                df_output_data["集計単位"] != "",
                df_output_data["集計単位"],
                df_output_data["出力_集計用単位"],
            )
# 出力_確率度 mapping - use the name similarity as main probability
if "出力_名称類似度" in df_output_data.columns:
df_output_data["出力_確率度"] = df_output_data["出力_名称類似度"]
elif "出力_中科目類似度" in df_output_data.columns:
df_output_data["出力_確率度"] = df_output_data["出力_中科目類似度"]
elif "出力_摘要類似度" in df_output_data.columns:
df_output_data["出力_確率度"] = df_output_data["出力_摘要類似度"]
elif "出力_単位類似度" in df_output_data.columns:
df_output_data["出力_確率度"] = df_output_data["出力_単位類似度"]
else:
df_output_data["出力_確率度"] = 0.0
# Fill NaN values and ensure all output columns have proper values
df_output_data = df_output_data.fillna("")
# Debug: Print available columns to see what we have
print(f"Available columns after processing: {list(df_output_data.columns)}")
# Define output columns in exact order as shown in Excel
output_columns = [
"ID",
"シート名",
"行",
"科目",
"中科目",
"分類",
"名称",
"単位",
"摘要",
"備考",
"出力_科目",
"出力_中科目",
"出力_項目名",
"出力_確率度",
"出力_標準単位",
"出力_集計用単位",
]
# Save with utf_8_sig encoding for Japanese Excel compatibility
df_output_data[output_columns].to_csv(
output_file_path, index=False, encoding="utf_8_sig"
)
# Save all caches
sentence_service.save_all_caches()
end_time = time.time()
execution_time = end_time - start_time
print(f"Execution time: {execution_time} seconds")
        return FileResponse(
            path=output_file_path,
            filename=f"output_{Path(file.filename).stem}.csv",
            # media_type already sets the Content-Type header to text/csv
            media_type="text/csv",
            headers={
                "Content-Disposition": f'attachment; filename="output_{Path(file.filename).stem}.csv"',
            },
        )
    except HTTPException:
        # Propagate status codes raised by the mapper stages unchanged
        raise
    except Exception as e:
        print(f"Error processing file: {e}")
        raise HTTPException(status_code=500, detail=str(e))
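
# Example /predict call (hypothetical host and token), uploading a CSV and
# saving the standardized result:
#   curl -X POST https://<host>/predict \
#        -H "Authorization: Bearer <token>" \
#        -F "file=@input.csv" -o output.csv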


@router.post("/embeddings")
async def create_embeddings(
request: EmbeddingRequest,
current_user=Depends(get_current_user_from_token),
sentence_service: SentenceTransformerService = Depends(
lambda: sentence_transformer_service
),
):
"""
Create embeddings for a list of input sentences (requires authentication)
"""
try:
start_time = time.time()
embeddings = sentence_service.sentenceTransformerHelper.create_embeddings(
request.sentences
)
end_time = time.time()
execution_time = end_time - start_time
print(f"Execution time: {execution_time} seconds")
# Convert numpy array to list for JSON serialization
embeddings_list = embeddings.tolist()
return {"embeddings": embeddings_list}
except Exception as e:
print(f"Error creating embeddings: {e}")
raise HTTPException(status_code=500, detail=str(e))
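
# Example /embeddings call (hypothetical host and token):
#   curl -X POST https://<host>/embeddings \
#        -H "Authorization: Bearer <token>" \
#        -H "Content-Type: application/json" \
#        -d '{"sentences": ["外壁塗装", "足場設置"]}'
# The response is {"embeddings": [[...], [...]]}, one vector per sentence.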


@router.post("/predict-raw", response_model=PredictRawResponse)
async def predict_raw(
request: PredictRawRequest,
current_user=Depends(get_current_user_from_token),
sentence_service: SentenceTransformerService = Depends(
lambda: sentence_transformer_service
),
):
"""
Process raw input records and return standardized names (requires authentication)
"""
try:
# Convert input records to DataFrame
records_dict = {
"科目": [],
"中科目": [],
"分類": [],
"名称": [],
"単位": [],
"摘要": [],
"備考": [],
"シート名": [], # Required by BaseNameData but not used
"行": [], # Required by BaseNameData but not used
}
for record in request.records:
records_dict["科目"].append(record.subject)
records_dict["中科目"].append(record.sub_subject)
records_dict["分類"].append(record.name_category)
records_dict["名称"].append(record.name)
records_dict["単位"].append("") # Default empty
records_dict["摘要"].append(record.abstract or "")
records_dict["備考"].append(record.memo or "")
records_dict["シート名"].append("") # Placeholder
records_dict["行"].append("") # Placeholder
df_input_data = pd.DataFrame(records_dict)
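        # The mappers expect the same Japanese column layout as the CSV flow
        # in /predict, so the raw records are converted into that shape first.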
# Process data similar to the main predict function
try:
# Subject mapping
if sentence_service.df_subject_map_data is not None:
subject_similarity_mapper = SubjectSimilarityMapper(
cached_embedding_helper=sentence_service.subject_cached_embedding_helper,
df_map_data=sentence_service.df_subject_map_data,
)
list_input_subject = df_input_data["科目"].unique()
df_subject_data = pd.DataFrame({"科目": list_input_subject})
subject_similarity_mapper.predict_input(df_input_data=df_subject_data)
output_subject_map = dict(
zip(df_subject_data["科目"], df_subject_data["出力_科目"])
)
df_input_data["標準科目"] = df_input_data["科目"].map(
output_subject_map
)
df_input_data["出力_科目"] = df_input_data["科目"].map(
output_subject_map
)
else:
df_input_data["標準科目"] = df_input_data["科目"]
df_input_data["出力_科目"] = df_input_data["科目"]
except Exception as e:
print(f"Error processing SubjectSimilarityMapper: {e}")
raise HTTPException(status_code=500, detail=str(e))
try:
# Name mapping (simplified for raw predict)
if sentence_service.df_name_map_data is not None:
name_sentence_mapper = NameSimilarityMapper(
cached_embedding_helper=sentence_service.name_cached_embedding_helper,
df_map_data=sentence_service.df_name_map_data,
)
name_sentence_mapper.predict_input(df_input_data=df_input_data)
except Exception as e:
print(f"Error processing NameSimilarityMapper: {e}")
raise HTTPException(status_code=500, detail=str(e))
try:
# Unit mapping
if sentence_service.df_unit_map_data is not None:
unit_mapper = UnitMapper(
cached_embedding_helper=sentence_service.unit_cached_embedding_helper,
df_map_data=sentence_service.df_unit_map_data,
)
unit_mapper.predict_input(df_input_data=df_input_data)
except Exception as e:
print(f"Error processing UnitMapper: {e}")
raise HTTPException(status_code=500, detail=str(e))
# Ensure required columns exist
for col in [
"確定",
"出力_標準名称",
"出力_名称類似度",
"出力_標準単位",
"出力_単位類似度",
]:
if col not in df_input_data.columns:
if col in ["出力_名称類似度", "出力_単位類似度"]:
df_input_data[col] = 0.0
else:
df_input_data[col] = ""
# Convert results to response format
results = []
for _, row in df_input_data.iterrows():
result = PredictResult(
subject=row["科目"],
sub_subject=row["中科目"],
name_category=row["分類"],
name=row["名称"],
abstract=row["摘要"],
memo=row["備考"],
confirmed=row.get("確定", ""),
standard_subject=row.get("出力_科目", row["科目"]),
standard_name=row.get("出力_標準名称", ""),
similarity_score=float(row.get("出力_名称類似度", 0.0)),
)
results.append(result)
# Save all caches
sentence_service.save_all_caches()
return PredictRawResponse(results=results)
    except HTTPException:
        # Propagate status codes raised by the mapper stages unchanged
        raise
    except Exception as e:
        print(f"Error processing records: {e}")
        raise HTTPException(status_code=500, detail=str(e))
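
# Example /predict-raw request body (hypothetical values):
#   {"records": [{"subject": "建築工事", "sub_subject": "外装", "name_category": "塗装",
#                 "name": "外壁塗装", "abstract": "", "memo": ""}]}
# Each result echoes the input fields plus standard_subject, standard_name and
# similarity_score.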