# meisaicheck-api/routes/predict.py
# Source: vumichien, revision 6830bc7 ("token store")
import os
import shutil
import time
import traceback
import uuid
from pathlib import Path

import pandas as pd
from fastapi import APIRouter, UploadFile, File, HTTPException, Depends, Body
from fastapi.responses import FileResponse

from custom_auth import get_current_user_from_token
from services.sentence_transformer_service import SentenceTransformerService, sentence_transformer_service
from data_lib.input_name_data import InputNameData
from data_lib.base_name_data import COL_NAME_SENTENCE
from mapping_lib.subject_mapper import SubjectMapper
from mapping_lib.name_mapper import NameMapper
from config import UPLOAD_DIR, OUTPUT_DIR
from models import (
    EmbeddingRequest,
    PredictRawRequest,
    PredictRawResponse,
    PredictRecord,
    PredictResult,
)
# Route table holding the /predict, /embeddings and /predict-raw endpoints
# below; presumably mounted by the application via include_router (not
# visible in this file).
router: APIRouter = APIRouter()
@router.post("/predict")
async def predict(
    current_user=Depends(get_current_user_from_token),
    file: UploadFile = File(...),
    sentence_service: SentenceTransformerService = Depends(
        lambda: sentence_transformer_service
    ),
):
    """
    Process an input CSV file and return standardized names (requires authentication)
    \f
    The upload is saved under UPLOAD_DIR, loaded into ``InputNameData``,
    subject-mapped with ``SubjectMapper`` (similarity threshold 0.9),
    processed, and name-mapped with ``NameMapper`` (top_count=3). The selected
    input columns plus 出力_科目 / 出力_項目名 / 出力_確率度 are written to
    OUTPUT_DIR as a ``utf_8_sig`` CSV and returned as a file download.

    Raises:
        HTTPException(400): the upload has no filename or is not a ``.csv``.
        HTTPException(500): any stage of loading, mapping or saving fails;
            the underlying error message is used as the detail.
    """
    # UploadFile.filename may be None; accept the extension case-insensitively.
    if not file.filename or not file.filename.lower().endswith(".csv"):
        raise HTTPException(status_code=400, detail="Only CSV files are supported")

    # Save uploaded file. The random suffix keeps two uploads by the same user
    # within the same second from overwriting each other's working files.
    timestamp = int(time.time())
    unique_id = uuid.uuid4().hex
    input_file_path = os.path.join(
        UPLOAD_DIR, f"input_{timestamp}_{current_user.username}_{unique_id}.csv"
    )
    output_file_path = os.path.join(
        OUTPUT_DIR, f"output_{timestamp}_{current_user.username}_{unique_id}.csv"
    )
    try:
        with open(input_file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
    finally:
        file.file.close()

    try:
        # Process input data
        start_time = time.time()
        try:
            inputData = InputNameData()
            inputData.load_data_from_csv(input_file_path)
        except Exception as e:
            print(f"Error processing load data: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

        try:
            subject_mapper = SubjectMapper(
                sentence_transformer_helper=sentence_service.sentenceTransformerHelper,
                dic_subject_map=sentence_service.dic_standard_subject,
                similarity_threshold=0.9,
            )
            dic_subject_map = subject_mapper.map_standard_subjects(inputData.dataframe)
        except Exception as e:
            print(f"Error processing SubjectMapper: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

        try:
            inputData.dic_standard_subject = dic_subject_map
            inputData.process_data()
        except Exception as e:
            print(f"Error processing inputData process_data: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

        # Map standard names
        try:
            nameMapper = NameMapper(
                sentence_service.sentenceTransformerHelper,
                sentence_service.standardNameMapData,
                top_count=3,
            )
            df_predicted = nameMapper.predict(inputData)
        except Exception as e:
            print(f"Error mapping standard names: {e}")
            traceback.print_exc()
            raise HTTPException(status_code=500, detail=str(e)) from e

        # Create output dataframe and save to CSV
        column_to_keep = ['シート名', '行', '科目', '中科目', '分類', '名称', '摘要', '備考', '確定']
        output_df = inputData.dataframe[column_to_keep].copy()
        output_df.reset_index(drop=False, inplace=True)
        output_df.loc[:, "出力_科目"] = df_predicted["標準科目"]
        output_df.loc[:, "出力_項目名"] = df_predicted["標準項目名"]
        output_df.loc[:, "出力_確率度"] = df_predicted["基準名称類似度"]
        # Save with utf_8_sig encoding for Japanese Excel compatibility
        output_df.to_csv(output_file_path, index=False, encoding="utf_8_sig")

        execution_time = time.time() - start_time
        print(f"Execution time: {execution_time} seconds")

        # Let FileResponse derive the headers: it sets Content-Type from
        # media_type and builds Content-Disposition from `filename`, using the
        # RFC 5987 `filename*=utf-8''...` form for non-ASCII (e.g. Japanese)
        # names. Hand-written headers here previously forced a wrong
        # Content-Type and failed latin-1 encoding for non-ASCII filenames.
        return FileResponse(
            path=output_file_path,
            filename=f"output_{Path(file.filename).stem}.csv",
            media_type="text/csv",
        )
    except HTTPException:
        # Already carries the intended status code and detail; re-wrapping it
        # would turn the detail into "500: <detail>".
        raise
    except Exception as e:
        print(f"Error processing file: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/embeddings")
async def create_embeddings(
    request: EmbeddingRequest,
    current_user=Depends(get_current_user_from_token),
    sentence_service: SentenceTransformerService = Depends(
        lambda: sentence_transformer_service
    ),
):
    """
    Create embeddings for a list of input sentences (requires authentication)
    \f
    Delegates to ``sentenceTransformerHelper.create_embeddings`` and returns
    ``{"embeddings": <result>.tolist()}`` so the payload is JSON-serializable
    (the helper is expected to return a NumPy array). Any failure is logged
    and reported as HTTP 500 with the error message as detail.
    """
    try:
        helper = sentence_service.sentenceTransformerHelper
        began = time.time()
        vectors = helper.create_embeddings(request.sentences)
        elapsed = time.time() - began
        print(f"Execution time: {elapsed} seconds")
        return {"embeddings": vectors.tolist()}
    except Exception as e:
        print(f"Error creating embeddings: {e}")
        raise HTTPException(status_code=500, detail=str(e))
def _raw_records_to_dataframe(records) -> pd.DataFrame:
    """Build the DataFrame that ``InputNameData`` ingests from raw API records.

    Missing ``abstract`` / ``memo`` values become empty strings. ``シート名``
    and ``行`` are filled with empty placeholders: they are required by
    BaseNameData but carry no information for raw (non-file) input.
    """
    count = len(records)
    return pd.DataFrame(
        {
            "科目": [record.subject for record in records],
            "中科目": [record.sub_subject for record in records],
            "分類": [record.name_category for record in records],
            "名称": [record.name for record in records],
            "摘要": [record.abstract or "" for record in records],
            "備考": [record.memo or "" for record in records],
            "シート名": [""] * count,  # Required by BaseNameData but not used
            "行": [""] * count,  # Required by BaseNameData but not used
        }
    )


def _fill_missing_prediction_columns(df_predicted: pd.DataFrame, input_df: pd.DataFrame) -> None:
    """Add placeholder columns (in place) so the response can always be built.

    Each frame is checked on its own, because the response reads ``確定`` from
    the input frame but the standard subject / name / similarity columns from
    the prediction frame. Existing columns are never overwritten, so real
    ``確定`` values in the input are preserved.
    """
    for column, default in (("標準科目", ""), ("標準項目名", ""), ("基準名称類似度", 0)):
        if column not in df_predicted.columns:
            df_predicted[column] = default
    if "確定" not in input_df.columns:
        input_df["確定"] = ""


@router.post("/predict-raw", response_model=PredictRawResponse)
async def predict_raw(
    request: PredictRawRequest,
    current_user=Depends(get_current_user_from_token),
    sentence_service: SentenceTransformerService = Depends(
        lambda: sentence_transformer_service
    ),
):
    """
    Process raw input records and return standardized names (requires authentication)
    \f
    Runs the same pipeline as ``/predict`` on JSON records instead of a CSV
    upload: ``SubjectMapper`` (similarity threshold 0.9), ``process_data`` and
    ``NameMapper`` (top_count=3). Returns one ``PredictResult`` per output row;
    an empty ``records`` list yields an empty result list.

    Raises:
        HTTPException(500): any stage of the pipeline fails; the underlying
            error message is used as the detail.
    """
    # Nothing to map: skip the model pipeline, which is not guaranteed to
    # accept an empty frame.
    if not request.records:
        return PredictRawResponse(results=[])

    try:
        # Convert input records to DataFrame
        df = _raw_records_to_dataframe(request.records)

        # Process input data
        try:
            inputData = InputNameData(sentence_service.dic_standard_subject)
            # Use _add_raw_data instead of direct assignment
            inputData._add_raw_data(df)
        except Exception as e:
            print(f"Error processing input data: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

        try:
            subject_mapper = SubjectMapper(
                sentence_transformer_helper=sentence_service.sentenceTransformerHelper,
                dic_subject_map=sentence_service.dic_standard_subject,
                similarity_threshold=0.9,
            )
            dic_subject_map = subject_mapper.map_standard_subjects(inputData.dataframe)
        except Exception as e:
            print(f"Error processing SubjectMapper: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

        try:
            inputData.dic_standard_subject = dic_subject_map
            inputData.process_data()
        except Exception as e:
            print(f"Error processing inputData process_data: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

        # Map standard names
        try:
            nameMapper = NameMapper(
                sentence_service.sentenceTransformerHelper,
                sentence_service.standardNameMapData,
                top_count=3,
            )
            df_predicted = nameMapper.predict(inputData)
        except Exception as e:
            print(f"Error mapping standard names: {e}")
            traceback.print_exc()
            raise HTTPException(status_code=500, detail=str(e)) from e

        _fill_missing_prediction_columns(df_predicted, inputData.dataframe)

        column_to_keep = ['シート名', '行', '科目', '中科目', '分類', '名称', '摘要', '備考', '確定']
        output_df = inputData.dataframe[column_to_keep].copy()
        output_df.reset_index(drop=False, inplace=True)
        output_df.loc[:, "出力_科目"] = df_predicted["標準科目"]
        output_df.loc[:, "出力_項目名"] = df_predicted["標準項目名"]
        output_df.loc[:, "出力_確率度"] = df_predicted["基準名称類似度"]

        # Convert results to response format (to_dict avoids per-row Series
        # construction done by iterrows).
        results = [
            PredictResult(
                subject=row["科目"],
                sub_subject=row["中科目"],
                name_category=row["分類"],
                name=row["名称"],
                abstract=row["摘要"],
                memo=row["備考"],
                confirmed=row["確定"],
                standard_subject=row["出力_科目"],
                standard_name=row["出力_項目名"],
                similarity_score=float(row["出力_確率度"]),
            )
            for row in output_df.to_dict(orient="records")
        ]
        return PredictRawResponse(results=results)
    except HTTPException:
        # Already carries the intended status code and detail; re-wrapping it
        # would turn the detail into "500: <detail>".
        raise
    except Exception as e:
        print(f"Error processing records: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e