# Hugging Face Spaces page header (status: Running)
import logging
import os
import re

import uvicorn
from flask import Flask, request, jsonify
from huggingface_hub import InferenceClient
app = Flask(__name__)

# Hugging Face Hub repository id of the model queried by Mistral7B
# (a model id passed as ``model=`` to text_generation, not a URL).
MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"

# Hugging Face API token, read once at import time from the API_KEY
# environment variable; None when the variable is unset.
API_KEY = os.getenv("API_KEY")
def format_prompt(message, custom_instructions=None):
    """Build a single-turn Mistral/Mixtral instruct prompt.

    The instruct format wraps each user turn in ``[INST] ... [/INST]`` and
    expects the model's reply between consecutive user turns. Rather than
    emitting two back-to-back ``[INST]`` blocks (an empty assistant turn),
    the optional instructions are prepended to the message inside one turn.

    Args:
        message: The user's request text.
        custom_instructions: Optional preamble/instructions for the model;
            ignored when empty or None.

    Returns:
        The formatted prompt string. No ``<s>`` BOS token is added here —
        presumably the inference backend's tokenizer adds it (TODO confirm).
    """
    if custom_instructions:
        message = f"{custom_instructions}\n\n{message}"
    return f"[INST] {message} [/INST]"
# British -> American spellings applied by normalize_text.
# Keys must be lowercase; matching is case-insensitive and substring-based,
# so inflected forms ("colours", "summarised") are converted too.
_BRITISH_TO_AMERICAN = {
    "summarise": "summarize",
    "colour": "color",
    "favour": "favor",
    "centre": "center",
    # Add more replacements as needed
}
_BRITISH_SPELLING_RE = re.compile(
    "|".join(map(re.escape, _BRITISH_TO_AMERICAN)), re.IGNORECASE
)


def normalize_text(text):
    """Normalize British spellings to American ones, preserving casing.

    Each match keeps the casing style of the original word (``colour`` ->
    ``color``, ``Colour`` -> ``Color``, ``COLOUR`` -> ``COLOR``); all other
    text is left untouched, except that the first character of the result
    is upper-cased.

    Args:
        text: The input string.

    Returns:
        The normalized string ("" for "").
    """

    def _replace(match):
        found = match.group(0)
        american = _BRITISH_TO_AMERICAN[found.lower()]
        if found.isupper():
            return american.upper()
        if found[0].isupper():
            return american[0].upper() + american[1:]
        return american

    text = _BRITISH_SPELLING_RE.sub(_replace, text)
    # Only the first character is capitalized; unlike str.capitalize(),
    # the rest of the user's text (acronyms, names, code) keeps its case.
    return text[:1].upper() + text[1:]
def Mistral7B(prompt, instructions, api_key, temperature=0.2, max_new_tokens=18, top_p=0.9, repetition_penalty=1.0):
    """Generate text for *prompt* from the hosted model named by MODEL_ID.

    Despite its name, this queries whatever MODEL_ID names (currently
    Mixtral-8x7B-Instruct) through huggingface_hub's
    ``InferenceClient.text_generation`` with sampling enabled.

    Args:
        prompt: The user message (passed to format_prompt).
        instructions: Optional instructions prepended via format_prompt.
        api_key: Hugging Face API token used to build the client.
        temperature: Sampling temperature; clamped to at least 0.01.
        max_new_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens.

    Returns:
        The generated text on success. On any failure the exception message
        (``str(e)``) is returned instead of raised, so callers cannot tell an
        error from model output by type; the traceback is logged.
    """
    try:
        generate_kwargs = dict(
            # Small positive floor; presumably the backend rejects a
            # temperature <= 0 when sampling (TODO confirm).
            temperature=max(float(temperature), 1e-2),
            max_new_tokens=int(max_new_tokens),
            top_p=float(top_p),
            repetition_penalty=float(repetition_penalty),
            do_sample=True,
            # Fixed seed, presumably so identical requests are reproducible.
            seed=69,
        )
        formatted_prompt = format_prompt(prompt, instructions)
        client = InferenceClient(token=api_key)
        return client.text_generation(formatted_prompt, model=MODEL_ID, **generate_kwargs)
    except Exception as e:
        # Keep the existing contract (return the error text rather than
        # raise), but record the failure so it is not silently swallowed.
        logging.getLogger(__name__).exception("Text generation with %s failed", MODEL_ID)
        return str(e)
@app.route("/generate_text", methods=["POST"])
def generate_text():
    """Handle POST /generate_text: run the model on a JSON request.

    Expects a JSON object body with non-empty string fields ``prompt`` and
    ``instructions``. The prompt is passed through normalize_text before
    being sent to Mistral7B.

    Returns:
        200 with ``{"response": <text>}`` on completion;
        400 with ``{"error": ...}`` when the body is not a JSON object or a
        required field is missing/not a non-empty string;
        500 with ``{"error": ...}`` when the server has no API_KEY configured.
    """
    # silent=True yields None (instead of raising an HTML 400/415) for a
    # missing, non-JSON or malformed body, so the client gets our JSON 400.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Missing required fields"}), 400

    prompt = data.get("prompt")
    instructions = data.get("instructions")
    if not (isinstance(prompt, str) and prompt and isinstance(instructions, str) and instructions):
        return jsonify({"error": "Missing required fields"}), 400

    api_key = API_KEY  # Use the API key from the environment variable
    if not api_key:
        # A missing key is a server configuration problem, not a client error.
        return jsonify({"error": "Server misconfigured: API_KEY is not set"}), 500

    normalized_prompt = normalize_text(prompt)
    response = Mistral7B(normalized_prompt, instructions, api_key)
    # NOTE(review): Mistral7B returns the exception text on failure, so
    # upstream errors currently reach the client as a 200 "response".
    app.logger.info("Model response: %s", response)
    return jsonify({"response": response}), 200
if __name__ == "__main__":
    # Flask is a WSGI application. uvicorn's default ("auto") interface
    # detection does not recognize WSGI callables and would serve it as
    # ASGI, failing on the first request; interface="wsgi" wraps it in
    # uvicorn's WSGI adapter. (uvicorn documents this mode as deprecated in
    # favour of a2wsgi; app.run(...) is the dependency-free alternative.)
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.getenv("PORT", 7860)),
        interface="wsgi",
    )