#!/usr/bin/python3
# -*- coding: utf-8 -*-
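"""
Evaluate a Gemini model on Vertex AI against a JSONL tag-extraction dataset.

For every sample, the script sends the dataset's system/user prompts to the
model, parses a `tag_name_list` out of the JSON reply, and scores it against
the reference tags with recall / precision / F1. Per-sample results carry the
running totals and are appended to an output JSONL file, so interrupted runs
can resume where they left off.
"""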
import argparse
from datetime import datetime
import json
import os
from pathlib import Path
import re
import sys
import time
import tempfile
from zoneinfo import ZoneInfo  # bundled with Python 3.9+, no extra install needed

# make the project root importable when the script is run directly
pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../"))

from google import genai
from google.genai import types

from project_settings import environment, project_path
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        # default="gemini-2.5-pro",  # The model does not support setting thinking_budget to 0.
        default="gemini-2.5-flash",
        # default="gemini-2.5-flash-lite-preview-06-17",
        # default="llama-4-maverick-17b-128e-instruct-maas",
        # default="llama-4-scout-17b-16e-instruct-maas",
        type=str
    )
    parser.add_argument(
        "--eval_dataset_name",
        default="agent-bingoplus-ph-25-summary.jsonl",
        type=str
    )
    parser.add_argument(
        "--eval_dataset_dir",
        default=(project_path / "data/dataset").as_posix(),
        type=str
    )
    parser.add_argument(
        "--eval_data_dir",
        default=(project_path / "data/eval_data").as_posix(),
        type=str
    )
    parser.add_argument(
        "--client",
        default="shenzhen_sase",
        type=str
    )
    parser.add_argument(
        "--service",
        default="google_potent_veld_462405_t3",
        type=str
    )
    parser.add_argument(
        "--create_time_str",
        default="null",
        # default="20250731_162116",
        type=str
    )
    parser.add_argument(
        "--interval",
        # seconds to sleep between requests, to stay under rate limits
        default=1,
        type=int
    )
    args = parser.parse_args()
    return args
def main():
    args = get_args()

    # the service-account key is stored in the environment as a JSON string
    service = environment.get(args.service, dtype=json.loads)
    project_id = service["project_id"]

    # materialize the key to a temp file and point the Google SDK at it
    google_application_credentials = Path(tempfile.gettempdir()) / f"llm_eval_system/{project_id}.json"
    google_application_credentials.parent.mkdir(parents=True, exist_ok=True)
    with open(google_application_credentials.as_posix(), "w", encoding="utf-8") as f:
        content = json.dumps(service, ensure_ascii=False, indent=4)
        f.write(f"{content}\n")
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = google_application_credentials.as_posix()
    eval_dataset_dir = Path(args.eval_dataset_dir)
    eval_dataset_dir.mkdir(parents=True, exist_ok=True)
    eval_data_dir = Path(args.eval_data_dir)
    eval_data_dir.mkdir(parents=True, exist_ok=True)

    # a fresh run gets a Shanghai-time timestamp; passing --create_time_str
    # points the script back at an existing run directory to resume it
    if args.create_time_str == "null":
        tz = ZoneInfo("Asia/Shanghai")
        now = datetime.now(tz)
        create_time_str = now.strftime("%Y%m%d_%H%M%S")
        # create_time_str = "20250729-interval-5"
    else:
        create_time_str = args.create_time_str

    eval_dataset = eval_dataset_dir / args.eval_dataset_name
    # the output path encodes model, client, service and timestamp so runs never collide
    output_file = eval_data_dir / f"gemini_google/google/{args.model_name}/{args.client}/{args.service}/{create_time_str}/{args.eval_dataset_name}"
    output_file.parent.mkdir(parents=True, exist_ok=True)
    client = genai.Client(
        vertexai=True,
        project=project_id,
        # location="global",
        location="us-east5",
    )

    generate_content_config = types.GenerateContentConfig(
        top_p=0.95,
        temperature=0.6,
        # max_output_tokens=1,
        response_modalities=["TEXT"],
        thinking_config=types.ThinkingConfig(
            thinking_budget=0
        )
    )
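    # Sampling is mildly stochastic (top_p=0.95, temperature=0.6), and
    # thinking_budget=0 turns off the model's internal reasoning stage to cut
    # latency and cost; per the note in get_args, gemini-2.5-pro rejects a
    # zero thinking budget.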
    total = 0
    total_score = 0

    # resume support: rows already in the output file carry running totals,
    # so re-reading them restores both the finished set and the aggregates
    finished_idx_set = set()
    if os.path.exists(output_file.as_posix()):
        with open(output_file.as_posix(), "r", encoding="utf-8") as f:
            for row in f:
                row = json.loads(row)
                idx = row["idx"]
                total = row["total"]
                total_score = row["total_score"]
                finished_idx_set.add(idx)
    print(f"finished count: {len(finished_idx_set)}")
    with open(eval_dataset.as_posix(), "r", encoding="utf-8") as fin, open(output_file.as_posix(), "a+", encoding="utf-8") as fout:
        for row in fin:
            row = json.loads(row)
            idx = row["idx"]
            system_prompt: str = row["system_prompt"]
            user_prompt: str = row["user_prompt"]
            response = row["response"]

            # skip samples already scored in a previous run
            if idx in finished_idx_set:
                continue
            finished_idx_set.add(idx)

            contents = [
                types.Content(
                    role="model",
                    parts=[
                        types.Part.from_text(text=system_prompt)
                    ]
                ),
                types.Content(
                    role="user",
                    parts=[
                        types.Part.from_text(text=user_prompt)
                    ]
                )
            ]
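            # Note: the system prompt is passed as a leading "model" turn. The
            # google-genai SDK also accepts a system_instruction field on
            # GenerateContentConfig; this script inlines the prompt into the
            # conversation instead.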
            # throttle requests to avoid rate limiting
            time.sleep(args.interval)
            print(f"sleep: {args.interval}")

            time_begin = time.time()
            llm_response: types.GenerateContentResponse = client.models.generate_content(
                model=args.model_name,
                contents=contents,
                config=generate_content_config,
            )
            time_cost = time.time() - time_begin
            print(f"time_cost: {time_cost}")

            try:
                prediction = llm_response.candidates[0].content.parts[0].text
            except TypeError as e:
                # candidates / parts can be None (e.g. an empty or blocked
                # response), which makes the subscript raise TypeError
                print(f"request failed, error type: {type(e)}, error text: {str(e)}")
                continue
            response_ = json.loads(response)
            response_tag_name_list = response_["tag_name_list"]
            # print(response_tag_name_list)

            # strip a ```json ... ``` fence if the model wrapped its output in one
            if prediction.startswith("```json") and prediction.endswith("```"):
                prediction_ = prediction[7:-3]
            else:
                prediction_ = prediction

            prediction_tag_name_list = list()
            try:
                prediction_ = json.loads(prediction_)
                prediction_tag_name_list = prediction_["tag_name_list"]
            except (json.JSONDecodeError, KeyError):
                # the reply may be non-JSON, or valid JSON without the expected key
                pass
            # print(prediction_tag_name_list)
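            # If the reply cannot be parsed, prediction_tag_name_list stays
            # empty, so the sample scores ~0 recall/precision but still counts
            # toward the running totals.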
            # recall: fraction of reference tags the model recovered
            recall_count = 0
            for tag in response_tag_name_list:
                if tag in prediction_tag_name_list:
                    recall_count += 1
            recall = recall_count / (len(response_tag_name_list) + 1e-7)

            # precision: fraction of predicted tags that are correct
            precision_count = 0
            for tag in prediction_tag_name_list:
                if tag in response_tag_name_list:
                    precision_count += 1
            precision = precision_count / (len(prediction_tag_name_list) + 1e-7)

            # f1: harmonic mean of recall and precision; the 1e-7 terms guard
            # against division by zero when either list is empty
            f1 = 2 * (recall * precision) / (recall + precision + 1e-7)
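            # e.g. reference tags {a, b, c} vs. predicted {b, c, d}:
            # recall = 2/3, precision = 2/3, f1 ≈ 0.667 (up to the epsilons)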
            total += 1
            total_score += f1
            score = total_score / total  # running mean f1 over scored samples
            row_ = {
                "idx": idx,
                "system_prompt": system_prompt,
                "user_prompt": user_prompt,
                "response": response,
                "prediction": prediction,
                "recall": recall,
                "precision": precision,
                "f1": f1,
                "total": total,
                "total_score": total_score,
                "score": score,
                "time_cost": time_cost,
            }
            row_ = json.dumps(row_, ensure_ascii=False)
            fout.write(f"{row_}\n")
            # flush so a killed run can resume from the last written row
            fout.flush()
    return
if __name__ == "__main__":
    main()
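# Example invocation (the file name eval_gemini.py is illustrative; every
# argument falls back to the defaults defined in get_args):
#   python3 eval_gemini.py --model_name gemini-2.5-flash --interval 1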