AFischer1985 committed on
Commit
66cb21c
·
verified ·
1 Parent(s): a2398ee

Update to Jina-embeddings

Browse files
Files changed (1) hide show
  1. run.py +57 -37
run.py CHANGED
@@ -2,10 +2,9 @@
2
  # Title: BERUFENET.AI
3
  # Author: Andreas Fischer
4
  # Date: January 4th, 2024
5
- # Last update: February 8th, 2024
6
  #############################################################################
7
 
8
- import os
9
  dbPath="/home/af/Schreibtisch/Code/gradio/BERUFENET/db"
10
  if(os.path.exists(dbPath)==False): dbPath="/home/user/app/db"
11
 
@@ -15,42 +14,52 @@ print(dbPath)
15
  #-----------
16
 
17
  import chromadb
18
- #client = chromadb.Client()
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  path=dbPath
20
  client = chromadb.PersistentClient(path=path)
21
  print(client.heartbeat())
22
  print(client.get_version())
23
  print(client.list_collections())
24
  from chromadb.utils import embedding_functions
25
- default_ef = embedding_functions.DefaultEmbeddingFunction()
26
- sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="T-Systems-onsite/cross-en-de-roberta-sentence-transformer")
27
  #instructor_ef = embedding_functions.InstructorEmbeddingFunction(model_name="hkunlp/instructor-large", device="cuda")
 
 
28
  print(str(client.list_collections()))
29
 
 
30
  global collection
31
- if("name=BerufenetDB1" in str(client.list_collections())): #(False):
32
  print("BerufenetDB1 found!")
33
- collection = client.get_collection(name="BerufenetDB1", embedding_function=sentence_transformer_ef)
34
 
35
  print("Database ready!")
36
  print(collection.count())
37
 
38
 
39
- # Model
40
- #-------
41
-
42
- from huggingface_hub import InferenceClient
43
- import gradio as gr
44
-
45
- client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
46
-
47
-
48
  # Gradio-GUI
49
  #------------
50
 
 
51
  import gradio as gr
52
  import json
53
 
 
 
54
  def format_prompt(message, history):
55
  prompt = "" #"<s>"
56
  #for user_prompt, bot_response in history:
@@ -59,26 +68,17 @@ def format_prompt(message, history):
59
  prompt += f"[INST] {message} [/INST]"
60
  return prompt
61
 
62
- def response(
63
- prompt, history, temperature=0.9, max_new_tokens=500, top_p=0.95, repetition_penalty=1.0,
64
- ):
65
- temperature = float(temperature)
66
- if temperature < 1e-2: temperature = 1e-2
67
- top_p = float(top_p)
68
- generate_kwargs = dict(
69
- temperature=temperature,
70
- max_new_tokens=max_new_tokens,
71
- top_p=top_p,
72
- repetition_penalty=repetition_penalty,
73
- do_sample=True,
74
- seed=42,
75
- )
76
  addon=""
77
  results=collection.query(
78
  query_texts=[prompt],
79
- n_results=5,
80
- #where={"source": "google-docs"}
81
- #where_document={"$contains":"search_string"}
82
  )
83
  dists=["<br><small>(relevance: "+str(round((1-d)*100)/100)+";" for d in results['distances'][0]]
84
  sources=["source: "+s["source"]+")</small>" for s in results['metadatas'][0]]
@@ -89,14 +89,34 @@ def response(
89
  if(len(results)>1):
90
  addon=" Bitte berücksichtige bei deiner Antwort ggf. folgende Auszüge aus unserer Datenbank, sofern sie für die Antwort erforderlich sind. Beantworte die Frage knapp und präzise. Ignoriere unpassende Datenbank-Auszüge OHNE sie zu kommentieren, zu erwähnen oder aufzulisten:\n"+"\n".join(results)
91
  system="Du bist ein deutschsprachiges KI-basiertes Assistenzsystem, das zu jedem Anliegen möglichst geeignete Berufe empfiehlt."+addon+"\n\nUser-Anliegen:"
92
- formatted_prompt = format_prompt(system+"\n"+prompt, history)
93
- stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
94
  output = ""
95
- for response in stream:
 
 
 
96
  output += response.token.text
97
  yield output
 
 
 
 
 
 
98
  output=output+"\n\n<br><details open><summary><strong>Sources</strong></summary><br><ul>"+ "".join(["<li>" + s + "</li>" for s in combination])+"</ul></details>"
99
  yield output
100
 
101
- gr.ChatInterface(response, chatbot=gr.Chatbot(value=[[None,"Herzlich willkommen! Ich bin ein KI-basiertes Assistenzsystem, das für jede Anfrage die am besten passenden Berufe empfiehlt.<br>Erzähle mir, was du gerne tust!"]],render_markdown=True),title="German BERUFENET-RAG-Interface to the Hugging Face Hub").queue().launch(share=True) #False, server_name="0.0.0.0", server_port=7864)
 
 
 
 
 
 
 
 
 
 
102
  print("Interface up and running!")
 
 
 
2
  # Title: BERUFENET.AI
3
  # Author: Andreas Fischer
4
  # Date: January 4th, 2024
5
+ # Last update: October 15th, 2024
6
  #############################################################################
7
 
 
8
  dbPath="/home/af/Schreibtisch/Code/gradio/BERUFENET/db"
9
  if(os.path.exists(dbPath)==False): dbPath="/home/user/app/db"
10
 
 
14
  #-----------
15
 
16
  import chromadb
17
+ from chromadb import Documents, EmbeddingFunction, Embeddings
18
+ import torch # chromaDB
19
+ from transformers import AutoTokenizer, AutoModel # chromaDB
20
+ jina = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-de', trust_remote_code=True, torch_dtype=torch.bfloat16)
21
+ #jina.save_pretrained("jinaai_jina-embeddings-v2-base-de")
22
+ device='cuda:0' if torch.cuda.is_available() else 'cpu'
23
+ jina.to(device) #cuda:0
24
+ print(device)
25
+
26
class JinaEmbeddingFunction(EmbeddingFunction):
    """Chroma embedding function backed by the module-level Jina model.

    Encodes a batch of documents with the globally loaded
    `jinaai/jina-embeddings-v2-base-de` model and returns plain
    Python lists, the format ChromaDB expects for embeddings.
    """

    def __call__(self, input: Documents) -> Embeddings:
        # encode() yields an array/tensor; convert to nested lists for Chroma.
        vectors = jina.encode(input)  # max_length=2048
        return vectors.tolist()
30
+
31
  path=dbPath
32
  client = chromadb.PersistentClient(path=path)
33
  print(client.heartbeat())
34
  print(client.get_version())
35
  print(client.list_collections())
36
  from chromadb.utils import embedding_functions
37
+ #default_ef = embedding_functions.DefaultEmbeddingFunction()
38
+ #sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="T-Systems-onsite/cross-en-de-roberta-sentence-transformer")
39
  #instructor_ef = embedding_functions.InstructorEmbeddingFunction(model_name="hkunlp/instructor-large", device="cuda")
40
+ jina_ef=JinaEmbeddingFunction()
41
+ embeddingFunction=jina_ef
42
  print(str(client.list_collections()))
43
 
44
+
45
  global collection
46
+ if("name=BerufenetDB1" in str(client.list_collections())):
47
  print("BerufenetDB1 found!")
48
+ collection = client.get_collection(name="BerufenetDB1", embedding_function=embeddingFunction)
49
 
50
  print("Database ready!")
51
  print(collection.count())
52
 
53
 
 
 
 
 
 
 
 
 
 
54
  # Gradio-GUI
55
  #------------
56
 
57
+ from huggingface_hub import InferenceClient
58
  import gradio as gr
59
  import json
60
 
61
+ myModel="mistralai/Mixtral-8x7B-Instruct-v0.1"
62
+
63
  def format_prompt(message, history):
64
  prompt = "" #"<s>"
65
  #for user_prompt, bot_response in history:
 
68
  prompt += f"[INST] {message} [/INST]"
69
  return prompt
70
 
71
+ def response(prompt, history, hfToken):
72
+ inferenceClient=""
73
+ if(hfToken.startswith("hf_")): # use HF-hub with custom token if token is provided
74
+ inferenceClient = InferenceClient(model=myModel, token=hfToken)
75
+ else:
76
+ inferenceClient = InferenceClient(myModel)
77
+ generate_kwargs = dict(temperature=float(0.9), max_new_tokens=500, top_p=0.95, repetition_penalty=1.0, do_sample=True, seed=42)
 
 
 
 
 
 
 
78
  addon=""
79
  results=collection.query(
80
  query_texts=[prompt],
81
+ n_results=5
 
 
82
  )
83
  dists=["<br><small>(relevance: "+str(round((1-d)*100)/100)+";" for d in results['distances'][0]]
84
  sources=["source: "+s["source"]+")</small>" for s in results['metadatas'][0]]
 
89
  if(len(results)>1):
90
  addon=" Bitte berücksichtige bei deiner Antwort ggf. folgende Auszüge aus unserer Datenbank, sofern sie für die Antwort erforderlich sind. Beantworte die Frage knapp und präzise. Ignoriere unpassende Datenbank-Auszüge OHNE sie zu kommentieren, zu erwähnen oder aufzulisten:\n"+"\n".join(results)
91
  system="Du bist ein deutschsprachiges KI-basiertes Assistenzsystem, das zu jedem Anliegen möglichst geeignete Berufe empfiehlt."+addon+"\n\nUser-Anliegen:"
92
+ formatted_prompt = format_prompt(system+"\n"+prompt, history)
 
93
  output = ""
94
+ print(""+str(inferenceClient))
95
+ try:
96
+ stream = inferenceClient.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
97
+ for response in stream:
98
  output += response.token.text
99
  yield output
100
+ except Exception as e:
101
+ output = "Für weitere Antworten von der KI gebe bitte einen gültigen HuggingFace-Token an."
102
+ if(len(combination)>0):
103
+ output += "\nBis dahin helfen dir hoffentlich die folgenden Quellen weiter:"
104
+ yield output
105
+ print(str(e))
106
  output=output+"\n\n<br><details open><summary><strong>Sources</strong></summary><br><ul>"+ "".join(["<li>" + s + "</li>" for s in combination])+"</ul></details>"
107
  yield output
108
 
109
+
110
+ gr.ChatInterface(
111
+ response,
112
+ chatbot=gr.Chatbot(value=[[None,"Herzlich willkommen! Ich bin ein KI-basiertes Assistenzsystem, das für jede Anfrage die am besten passenden Berufe empfiehlt.<br>Erzähle mir, was du gerne tust!"]],render_markdown=True),
113
+ title="BERUFENET.AI (Jina-Embeddings)",
114
+ additional_inputs=[
115
+ gr.Textbox(
116
+ value="",
117
+ label="HF_token"),
118
+ ]
119
+ ).queue().launch(share=True) #False, server_name="0.0.0.0", server_port=7864)
120
  print("Interface up and running!")
121
+
122
+