TheVera commited on
Commit
d18e30c
·
verified ·
1 Parent(s): 3b864a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -11
app.py CHANGED
@@ -1,11 +1,16 @@
1
  import uvicorn
2
-
3
  from flask import Flask, request, jsonify
4
  from huggingface_hub import InferenceClient
5
 
6
  app = Flask(__name__)
7
 
8
- API_URL = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"
 
 
 
 
 
9
 
10
  def format_prompt(message, custom_instructions=None):
11
  prompt = ""
@@ -14,12 +19,26 @@ def format_prompt(message, custom_instructions=None):
14
  prompt += f"[INST] {message} [/INST]"
15
  return prompt
16
 
17
- def Mistral7B(prompt, instructions, api, temperature=0.1, max_new_tokens=2, top_p=0.95, repetition_penalty=1.0):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  global API_URL
19
  try:
20
- temperature = float(temperature)
21
- if temperature < 1e-2:
22
- temperature = 1e-2
23
  top_p = float(top_p)
24
 
25
  generate_kwargs = dict(
@@ -33,26 +52,27 @@ def Mistral7B(prompt, instructions, api, temperature=0.1, max_new_tokens=2, top_
33
  custom_instructions = instructions
34
  formatted_prompt = format_prompt(prompt, custom_instructions)
35
 
36
- head = {"Authorization": f"Bearer {api}"}
37
- client = InferenceClient(API_URL, headers=head)
38
  response = client.text_generation(formatted_prompt, **generate_kwargs)
39
  return response
40
  except Exception as e:
41
  return str(e)
42
 
 
43
  @app.route("/generate-text", methods=["POST"])
44
  def generate_text():
45
  data = request.json
46
  prompt = data.get("prompt")
47
  instructions = data.get("instructions")
48
- api_key = data.get("api_key")
49
 
50
  if not prompt or not instructions or not api_key:
51
  return jsonify({"error": "Missing required fields"}), 400
52
 
53
- response = Mistral7B(prompt, instructions, api_key)
 
54
 
55
  return jsonify({"response": response}), 200
56
 
57
  if __name__ == "__main__":
58
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
  import uvicorn
2
+ import os
3
  from flask import Flask, request, jsonify
4
  from huggingface_hub import InferenceClient
5
 
6
  app = Flask(__name__)
7
 
8
+ # API_URL = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"
9
+ # Use environment variables for configuration
10
+ API_URL = os.getenv("API_URL", "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1")
11
+ # API_KEY = os.getenv("API_KEY", "your_default_api_key") # Default API_KEY can be set here
12
+ API_KEY = os.getenv("API_KEY")
13
+
14
 
15
  def format_prompt(message, custom_instructions=None):
16
  prompt = ""
 
19
  prompt += f"[INST] {message} [/INST]"
20
  return prompt
21
 
22
+ def normalize_text(text):
23
+ # Normalize text to handle different spellings
24
+ replacements = {
25
+ 'summarise': 'summarize',
26
+ 'colour': 'color',
27
+ 'favour': 'favor',
28
+ 'centre': 'center',
29
+ # Add more replacements as needed
30
+ }
31
+
32
+ text = text.lower()
33
+ for british, american in replacements.items():
34
+ text = text.replace(british, american)
35
+
36
+ return text
37
+
38
+ def Mistral7B(prompt, instructions, api_key, temperature=0.1, max_new_tokens=2, top_p=0.95, repetition_penalty=1.0):
39
  global API_URL
40
  try:
41
+ temperature = max(float(temperature), 1e-2)
 
 
42
  top_p = float(top_p)
43
 
44
  generate_kwargs = dict(
 
52
  custom_instructions = instructions
53
  formatted_prompt = format_prompt(prompt, custom_instructions)
54
 
55
+ client = InferenceClient(api_url=API_URL, token=api_key)
 
56
  response = client.text_generation(formatted_prompt, **generate_kwargs)
57
  return response
58
  except Exception as e:
59
  return str(e)
60
 
61
+
62
  @app.route("/generate-text", methods=["POST"])
63
  def generate_text():
64
  data = request.json
65
  prompt = data.get("prompt")
66
  instructions = data.get("instructions")
67
+ api_key = data.get("api_key", API_KEY) # Use provided API key or default to environment variable
68
 
69
  if not prompt or not instructions or not api_key:
70
  return jsonify({"error": "Missing required fields"}), 400
71
 
72
+ normalized_prompt = normalize_text(prompt)
73
+ response = Mistral7B(normalized_prompt, instructions, api_key)
74
 
75
  return jsonify({"response": response}), 200
76
 
77
  if __name__ == "__main__":
78
+ uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 8000)))