|
import json |
|
import argparse |
|
import torch |
|
import os |
|
import random |
|
import numpy as np |
|
import requests |
|
import logging |
|
import math |
|
import copy |
|
import string |
|
|
|
from tqdm import tqdm |
|
from time import time |
|
from flask import Flask, request, jsonify |
|
from flask_cors import CORS |
|
from tornado.wsgi import WSGIContainer |
|
from tornado.httpserver import HTTPServer |
|
from tornado.ioloop import IOLoop |
|
|
|
from simcse import SimCSE |
|
|
|
# Configure root logging once at import time: INFO level, timestamped lines.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)
# Module-level logger used by the demo server for startup messages.
logger = logging.getLogger(__name__)
|
|
|
def run_simcse_demo(port, args):
    """Build a SimCSE sentence index and serve the search demo over HTTP.

    Creates a Flask app (CORS enabled, static files under ``./static``),
    loads a SimCSE model, indexes the sentences file and serves the app
    through a Tornado ``HTTPServer``. This call blocks running the IOLoop.

    Routes:
        ``/``              -> ``static/index.html``.
        ``/api``           -> search; query params ``query``, ``topk`` (int)
                              and ``threshold`` (float). Responds with
                              ``{"ret": [{"sentence": ..., "score": ...}],
                              "time": "<seconds, 4 decimals>"}``, or 400 when a
                              parameter is missing or not numeric.
        ``/files/<path>``  -> ``static/files/<path>``.
        ``/get_examples``  -> JSON list of the stripped lines of the
                              example-query file.

    Args:
        port: Port to listen on, as an ``int`` or a numeric string such as
            ``'8888'``.
        args: Parsed CLI namespace. Reads ``model_name_or_path``,
            ``sentences_dir``, ``example_sentences``, ``example_query`` and
            ``ip``. ``ip`` is only used in the startup log message.
    """
    app = Flask(__name__, static_folder='./static')
    app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False
    CORS(app)

    sentence_path = os.path.join(args.sentences_dir, args.example_sentences)
    query_path = os.path.join(args.sentences_dir, args.example_query)
    embedder = SimCSE(args.model_name_or_path)
    embedder.build_index(sentence_path)

    @app.route('/')
    def index():
        return app.send_static_file('index.html')

    @app.route('/api', methods=['GET'])
    def api():
        # A missing key raises Flask's BadRequestKeyError, which is already a 400.
        query = request.args['query']
        # Non-numeric values would otherwise surface as a 500 from int()/float().
        try:
            top_k = int(request.args['topk'])
            threshold = float(request.args['threshold'])
        except ValueError:
            return jsonify({'error': "'topk' must be an integer and 'threshold' a number"}), 400

        start = time()
        results = embedder.search(query, top_k=top_k, threshold=threshold)
        # float() turns NumPy scalar scores (if the index backend returns them)
        # into plain Python floats that jsonify can serialize.
        ret = [{"sentence": sentence, "score": float(score)} for sentence, score in results]
        span = time() - start

        return jsonify({'ret': ret, 'time': "{:.4f}".format(span)})

    @app.route('/files/<path:path>')
    def static_files(path):
        return app.send_static_file('files/' + path)

    @app.route('/get_examples', methods=['GET'])
    def get_examples():
        with open(query_path, 'r', encoding='utf-8') as fp:
            examples = [line.strip() for line in fp]
        return jsonify(examples)

    # Log the port that is actually bound, not args.port.
    logger.info('Starting Index server at %s:%s', args.ip, port)
    http_server = HTTPServer(WSGIContainer(app))
    # Tornado's listen() is typed for an int port, while the CLI passes a string.
    http_server.listen(int(port))
    IOLoop.current().start()
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Serve the SimCSE sentence-search demo.')
    # The next four paths/names are used unconditionally by run_simcse_demo
    # (os.path.join / SimCSE), so a None default could only crash later.
    parser.add_argument('--model_name_or_path', required=True, type=str,
                        help='SimCSE model name or local path.')
    parser.add_argument('--device', default='cpu', type=str,
                        help='Parsed for compatibility; not read by run_simcse_demo.')
    parser.add_argument('--sentences_dir', required=True, type=str,
                        help='Directory containing the sentence and query files.')
    parser.add_argument('--example_query', required=True, type=str,
                        help='File (inside --sentences_dir) with example queries, one per line.')
    parser.add_argument('--example_sentences', required=True, type=str,
                        help='File (inside --sentences_dir) with sentences to index.')
    # Kept as str: run_simcse_demo builds its log address by string operations.
    parser.add_argument('--port', default='8888', type=str,
                        help='Port for the HTTP server.')
    parser.add_argument('--ip', default='http://127.0.0.1',
                        help='Address shown in the startup log message.')
    parser.add_argument('--load_light', default=False, action='store_true',
                        help='Parsed for compatibility; not read by run_simcse_demo.')
    args = parser.parse_args()

    run_simcse_demo(args.port, args)