#!/usr/bin/env python
# coding: utf-8
# Copyright 2021, IBM Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Flask API app and routes.
"""
__author__ = "Vagner Santana, Melina Alberio, Cassia Sanctos and Tiago Machado"
__copyright__ = "IBM Corporation 2024"
__credits__ = ["Vagner Santana, Melina Alberio, Cassia Sanctos, Tiago Machado"]
__license__ = "Apache 2.0"
__version__ = "0.0.1"
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS, cross_origin
from flask_restful import Resource, Api, reqparse
import control.recommendation_handler as recommendation_handler
from helpers import get_credentials, authenticate_api, save_model
import config as cfg
import logging
import uuid
import json
import os
import requests
app = Flask(__name__, static_folder='static')
# configure logging
logging.basicConfig(
filename='app.log', # Log file name
level=logging.INFO, # Log level (INFO, DEBUG, WARNING, ERROR, CRITICAL)
format='%(asctime)s - %(levelname)s - %(message)s' # Log message format
)
# access the app's logger
logger = app.logger
# create an id used to tag this app instance's log entries
user_id = str(uuid.uuid4())
# swagger configs
app.register_blueprint(cfg.SWAGGER_BLUEPRINT, url_prefix = cfg.SWAGGER_URL)
FRONT_LOG_FILE = 'front_log.json'
@app.route("/")
def index():
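    """Serve the static demo page and log the visitor's IP and session id."""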
    user_ip = request.remote_addr
    logger.info(f'USER {user_ip} - ID {user_id} - started the app')
    return app.send_static_file('demo/index.html')
@app.route("/recommend", methods=['GET'])
@cross_origin()
def recommend():
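    """Return prompt recommendations for the query-string prompt using the remote inference API."""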
user_ip = request.remote_addr
hf_token, hf_url = get_credentials.get_credentials()
api_url, headers = authenticate_api.authenticate_api(hf_token, hf_url)
prompt_json = recommendation_handler.populate_json()
args = request.args
prompt = args.get("prompt")
recommendation_json = recommendation_handler.recommend_prompt(prompt, prompt_json,
api_url, headers)
    logger.info(f'USER {user_ip} - ID {user_id} - accessed recommend route')
logger.info(f'RECOMMEND ROUTE - request: {prompt} response: {recommendation_json}')
return recommendation_json
@app.route("/get_thresholds", methods=['GET'])
@cross_origin()
def get_thresholds():
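    """Return recommendation thresholds for the given prompt, computed with the all-minilm-l6-v2 sentence-transformers model."""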
hf_token, hf_url = get_credentials.get_credentials()
api_url, headers = authenticate_api.authenticate_api(hf_token, hf_url)
prompt_json = recommendation_handler.populate_json()
model_id = 'sentence-transformers/all-minilm-l6-v2'
args = request.args
#print("args list = ", args)
prompt = args.get("prompt")
thresholds_json = recommendation_handler.get_thresholds(prompt, prompt_json, api_url,
headers, model_id)
return thresholds_json
@app.route("/recommend_local", methods=['GET'])
@cross_origin()
def recommend_local():
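    """Return prompt recommendations computed with a locally saved model instead of the remote API."""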
model_id, model_path = save_model.save_model()
prompt_json = recommendation_handler.populate_json()
args = request.args
print("args list = ", args)
prompt = args.get("prompt")
local_recommendation_json = recommendation_handler.recommend_local(prompt, prompt_json,
model_id, model_path)
return local_recommendation_json
@app.route("/log", methods=['POST'])
@cross_origin()
def log():
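    """Merge front-end log data from the request body into the JSON log file and return the result."""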
f_path = 'static/demo/log/'
new_data = request.get_json()
try:
with open(f_path+FRONT_LOG_FILE, 'r') as f:
existing_data = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        # start from an empty dict so the .update() below works on the first run
        existing_data = {}
    existing_data.update(new_data)
#log_data = request.json
with open(f_path+FRONT_LOG_FILE, 'w') as f:
json.dump(existing_data, f)
return jsonify({'message': 'Data added successfully', 'data': existing_data}), 201
@app.route("/demo_inference", methods=['GET'])
@cross_origin()
def demo_inference():
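    """Forward the prompt to the Hugging Face router chat-completions endpoint and return the model's reply."""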
args = request.args
# model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
    model_id = args.get('model_id', default="meta-llama/Llama-4-Scout-17B-16E-Instruct")
    temperature = args.get('temperature', default=0.5, type=float)
    max_new_tokens = args.get('max_new_tokens', default=1000, type=int)
hf_token, hf_url = get_credentials.get_credentials()
prompt = args.get('prompt')
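    # OpenAI-compatible chat-completions endpoint exposed by the Hugging Face router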
API_URL = "https://router.huggingface.co/together/v1/chat/completions"
headers = {
"Authorization": f"Bearer {hf_token}",
}
response = requests.post(
API_URL,
headers=headers,
json={
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt
},
]
}
],
"model": model_id,
'temperature': temperature,
'max_new_tokens': max_new_tokens,
        },
        timeout=120,  # avoid hanging indefinitely if the inference endpoint stalls
    )
    try:
        message = response.json()["choices"][0]["message"]
        message.update({
            'model_id': model_id,
            'temperature': temperature,
            'max_new_tokens': max_new_tokens,
        })
        return message
    except (KeyError, IndexError, ValueError):
        # the upstream call failed or returned an unexpected payload
        return response.text, response.status_code
if __name__ == '__main__':
    debug_mode = os.getenv('FLASK_DEBUG', 'True').lower() in ['true', '1', 't']
    app.run(host='0.0.0.0', port=7860, debug=debug_mode)