from transformers import pipeline
from huggingface_hub import InferenceClient
import os

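# System prompts keyed by response type: "STRICT" forces bare Yes/No answers,
# "HELP" asks for a short explanatory paragraph, and "PITFALL" screens a file
# for thirteen common machine-learning pitfalls.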
system_messages = {
    "STRICT": """You are a chatbot evaluating GitHub repositories, their Python code and corresponding README files.
                 Keep in mind that the code you are provided is only one of many files in the repository.
                 Strictly answer the questions with "Yes" or "No".
                 Don't use any punctuation either.""",
    "HELP": """You are a chatbot evaluating GitHub repositories, their Python code and corresponding README files.
               Please help me answer the following question.
               Keep your answers short and informative.
               Your answer should be a single paragraph.""",
    "PITFALL": """You are a chatbot evaluating GitHub repositories, their Python code and corresponding README files.
                  You are looking for common pitfalls in the code.
                  Keep in mind that the code you are provided is only one of many files in the repository.
                  Keep your answer short and informative.
                  Only report serious flaws. If you don't find any, don't mention it.
                  Answer using only a single, short paragraph.
                  Only point out pitfalls if you are certain about them!
                  Pitfall #1 Design flaws with regards to the data collection in the code.
                  Pitfall #2 Dataset shift (e.g. sampling bias, imbalanced populations, imbalanced labels, non-stationary environments).
                  Pitfall #3 Confounders.
                  Pitfall #4 Measurement errors (labelling mistakes, noisy measurements, inappropriate proxies).
                  Pitfall #5 Historical biases in the data used.
                  Pitfall #6 Information leaking between the training and testing data.
                  Pitfall #7 Model-problem mismatch (e.g. over-complicated/simplistic model, computational challenges).
                  Pitfall #8 Overfitting in the code (e.g. high variance, high complexity, low bias).
                  Pitfall #9 Misused metrics in the code (e.g. poor metric selection, poor implementations).
                  Pitfall #10 Black box models in the code (e.g. lack of interpretability, lack of transparency).
                  Pitfall #11 Baseline comparison issues (e.g. if the testing data does not fit the training data).
                  Pitfall #12 Insufficient reporting in the code (e.g. missing hyperparameters, missing evaluation metrics).
                  Pitfall #13 Faulty interpretations of the reported results.""",
}

class LocalLLM:
  """Runs the model locally through a transformers text-generation pipeline."""

  def __init__(self, model_name):
    # device=0 pins the pipeline to the first GPU; pad_token_id=128001 is the
    # Llama-3 tokenizer's <|end_of_text|> token, so a Llama-3-family model is
    # evidently expected here.
    self.pipe = pipeline("text-generation", model=model_name, max_new_tokens=1000, device=0, pad_token_id=128001)

  def predict(self, response_type, prompt):
    messages = [
        {"role": "system", "content": system_messages[response_type]},
        {"role": "user", "content": prompt},
    ]
    res = self.pipe(messages)
    # The pipeline returns the whole conversation under "generated_text";
    # keep only the assistant's reply.
    res = res[0]["generated_text"]
    res = [response for response in res if response["role"] == "assistant"][0]["content"]
    return res.strip()

class RemoteLLM:
  """Calls the model through the Hugging Face Inference API instead of loading it locally."""

  def __init__(self, model_name):
    # Expects a Hugging Face access token in the hfToken environment variable.
    token = os.getenv("hfToken")
    self.model_name = model_name
    self.client = InferenceClient(api_key=token)

  def predict(self, response_type, prompt):
    message = self.client.chat_completion(
        model=self.model_name, max_tokens=500, stream=False,
        messages=[{"role": "system", "content": system_messages[response_type]},
                  {"role": "user", "content": prompt}])
    return message.choices[0].message.content
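

# --- Hypothetical usage sketch, not part of the original file ---
# A minimal illustration of how these wrappers might be driven. The model id
# and the prompt text below are illustrative assumptions, not values taken
# from this repository; RemoteLLM additionally needs hfToken set in the
# environment.
if __name__ == "__main__":
  llm = RemoteLLM("meta-llama/Meta-Llama-3-8B-Instruct")  # assumed model id
  code_snippet = "X_train, X_test = X[:800], X[800:]"  # assumed file contents
  print(llm.predict("STRICT", f"Does this code use a train/test split?\n{code_snippet}"))
  print(llm.predict("PITFALL", f"Review this code:\n{code_snippet}"))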