import json
import argparse
import torch
import os
import random
import numpy as np
import requests
import logging
import math
import copy
import string
from tqdm import tqdm
from time import time
from flask import Flask, request, jsonify
from flask_cors import CORS
from tornado.wsgi import WSGIContainer
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
from simcse import SimCSE

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)


def run_simcse_demo(port, args):
    """Build a SimCSE sentence index and serve a search demo over HTTP.

    The Flask app is wrapped in a Tornado ``HTTPServer`` and this call blocks
    in the Tornado IO loop until the loop is stopped.

    Args:
        port: Port to listen on. May be an ``int`` or a numeric string (the
            CLI passes a string); it is converted with ``int()``.
        args: Parsed command-line namespace. Reads ``model_name_or_path``,
            ``sentences_dir``, ``example_sentences``, ``example_query`` and
            ``ip``.

    Raises:
        ValueError: If ``port`` cannot be converted to an integer.
    """
    # Fail before the (potentially slow) model load if the port is unusable.
    port = int(port)

    app = Flask(__name__, static_folder='./static')
    app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False
    CORS(app)

    sentence_path = os.path.join(args.sentences_dir, args.example_sentences)
    query_path = os.path.join(args.sentences_dir, args.example_query)

    embedder = SimCSE(args.model_name_or_path)
    embedder.build_index(sentence_path)

    @app.route('/')
    def index():
        """Serve the demo front page from the static folder."""
        return app.send_static_file('index.html')

    @app.route('/api', methods=['GET'])
    def api():
        """Search the index for the ``query`` parameter.

        Expects the query-string parameters ``query``, ``topk`` (int) and
        ``threshold`` (float). Responds with JSON holding ``ret`` (a list of
        ``{"sentence", "score"}`` objects) and ``time`` (elapsed seconds
        formatted with four decimals). Missing or malformed parameters
        produce a JSON error with HTTP status 400.
        """
        try:
            query = request.args['query']
            top_k = int(request.args['topk'])
            threshold = float(request.args['threshold'])
        except (KeyError, ValueError):
            # A bad number would otherwise surface as an opaque HTTP 500.
            return jsonify({
                'error': "parameters 'query', 'topk' (int) and "
                         "'threshold' (float) are required"
            }), 400

        start = time()
        results = embedder.search(query, top_k=top_k, threshold=threshold)
        # float() is a no-op for built-in floats; it keeps jsonify working if
        # the embedder hands back NumPy scalars (TODO confirm score type).
        ret = [{"sentence": sentence, "score": float(score)}
               for sentence, score in results]
        span = time() - start

        return jsonify({'ret': ret, 'time': f"{span:.4f}"})

    # The URL must declare the ``path`` variable, otherwise Flask calls the
    # view without arguments and it fails with a TypeError.
    @app.route('/files/<path:path>')
    def static_files(path):
        """Serve a file from the ``files/`` subfolder of the static folder."""
        return app.send_static_file('files/' + path)

    @app.route('/get_examples', methods=['GET'])
    def get_examples():
        """Return the example queries, one per line of the query file."""
        with open(query_path, 'r') as fp:
            examples = [line.strip() for line in fp]
        return jsonify(examples)

    # NOTE: args.ip only appears in this log message; listen() below is
    # called without an explicit bind address.
    addr = f'{args.ip}:{port}'
    logger.info(f'Starting Index server at {addr}')

    http_server = HTTPServer(WSGIContainer(app))
    http_server.listen(port)
    IOLoop.instance().start()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', default=None, type=str)
    # NOTE(review): --device and --load_light are parsed but never read in
    # this file; kept so existing command lines continue to work.
    parser.add_argument('--device', default='cpu', type=str)
    # The three path arguments are joined with os.path.join, which raises a
    # TypeError on None, so require them up front for a clear usage error.
    parser.add_argument('--sentences_dir', required=True, type=str)
    parser.add_argument('--example_query', required=True, type=str)
    parser.add_argument('--example_sentences', required=True, type=str)
    parser.add_argument('--port', default='8888', type=str)
    parser.add_argument('--ip', default='http://127.0.0.1')
    parser.add_argument('--load_light', default=False, action='store_true')
    args = parser.parse_args()

    run_simcse_demo(args.port, args)