File size: 2,367 Bytes
d18e30c
761f569
3b864a6
 
 
 
 
761f569
 
d18e30c
761f569
 
3b864a6
 
 
 
 
 
 
 
d18e30c
 
 
5861911
 
 
 
d18e30c
 
 
 
 
 
5861911
d18e30c
 
761f569
eb6d5d0
3b864a6
d18e30c
3b864a6
 
 
 
 
 
 
 
 
 
 
 
 
761f569
 
3b864a6
 
 
 
2ef6724
3b864a6
 
 
 
e6e4784
3b864a6
 
 
 
d18e30c
 
8f4aea3
3b864a6
 
 
527a5fe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
import uvicorn
from flask import Flask, request, jsonify
from huggingface_hub import InferenceClient

# Flask application instance; routes are registered against it below.
app = Flask(__name__)

# Fixed API URL
# Hugging Face model repo id used for every generation request.
MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"

# Use environment variable for the API key
# Hugging Face Inference API token; None if the variable is unset (requests
# will then be rejected with a 400 by the /generate_text handler).
API_KEY = os.getenv("API_KEY")

def format_prompt(message, custom_instructions=None):
    """Build a Mixtral-style instruction prompt.

    Args:
        message: The user message to wrap in an [INST] block.
        custom_instructions: Optional instructions placed in their own
            [INST] block before the message (skipped when falsy).

    Returns:
        The concatenated prompt string.
    """
    segments = []
    if custom_instructions:
        segments.append(f"[INST] {custom_instructions} [/INST]")
    segments.append(f"[INST] {message} [/INST]")
    return "".join(segments)

def normalize_text(text):
    """Normalize British spellings in *text* to their American forms.

    The input is lower-cased first so the replacements match regardless of
    the caller's casing, then sentence-capitalized exactly once at the end.

    Args:
        text: Arbitrary input string (e.g. a user prompt).

    Returns:
        The normalized, sentence-capitalized string ("" stays "").
    """
    # Normalize text to handle different spellings.
    # Replacement values are lowercase: the original used capitalized values
    # ('Summarize', 'Color', ...) and called .capitalize() inside the loop,
    # which lower-cased every mid-string capital inserted by earlier
    # iterations — the capitalized forms could never survive.
    replacements = {
        'summarise': 'summarize',
        'colour': 'color',
        'favour': 'favor',
        'centre': 'center',
        # Add more replacements as needed
    }

    text = text.lower()
    for british, american in replacements.items():
        text = text.replace(british, american)

    # Capitalize once, after all replacements are done.
    return text.capitalize()

def Mistral7B(prompt, instructions, api_key, temperature=0.2, max_new_tokens=18, top_p=0.9, repetition_penalty=1.0):
    """Generate text from the Mixtral model via the HF Inference API.

    Args:
        prompt: The (already normalized) user message.
        instructions: Custom instructions prepended via format_prompt.
        api_key: Hugging Face Inference API token.
        temperature: Sampling temperature; clamped to at least 1e-2.
        max_new_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens.

    Returns:
        The generated text, or the stringified exception message on any
        failure — callers treat the return value as the response either way.
    """
    try:
        # Clamp temperature away from zero before the API call.
        temperature = max(float(temperature), 1e-2)
        top_p = float(top_p)

        sampling_options = {
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": True,
            "seed": 69,  # fixed seed for reproducible sampling
        }
        formatted_prompt = format_prompt(prompt, instructions)

        hf_client = InferenceClient(token=api_key)
        return hf_client.text_generation(formatted_prompt, model=MODEL_ID, **sampling_options)
    except Exception as exc:
        # NOTE(review): returning the error message as if it were model
        # output hides failures from callers — consider logging instead.
        return str(exc)

@app.route("/generate_text", methods=["POST"])
def generate_text():
    """POST /generate_text — generate model output for a JSON request body.

    Expects a JSON object with "prompt" and "instructions" keys.

    Returns:
        200 with {"response": <generated text>} on success;
        400 with {"error": ...} when the body is missing/invalid JSON, a
        required field is absent, or the server-side API key is not set.
    """
    # request.json raises (HTTP 415/400 error page) on a missing or
    # malformed JSON body; get_json(silent=True) returns None instead so
    # we can answer with our own clean 400 below.
    data = request.get_json(silent=True) or {}
    prompt = data.get("prompt")
    instructions = data.get("instructions")
    api_key = API_KEY  # Use the API key from the environment variable

    if not prompt or not instructions or not api_key:
        return jsonify({"error": "Missing required fields"}), 400

    normalized_prompt = normalize_text(prompt)
    response = Mistral7B(normalized_prompt, instructions, api_key)
    return jsonify({"response": response}), 200

if __name__ == "__main__":
    # Flask is a WSGI app; uvicorn is an ASGI server and cannot serve it
    # directly (uvicorn.run(app) fails at request time). Use Flask's
    # built-in server here; front with gunicorn/waitress in production.
    app.run(host="0.0.0.0", port=int(os.getenv("PORT", 7860)))