acadiaway commited on
Commit
e96349e
·
1 Parent(s): b6144db
Files changed (1) hide show
  1. pipeline.py +0 -128
pipeline.py DELETED
@@ -1,128 +0,0 @@
1
- import os
2
- import requests
3
- import time
4
- import re
5
- import json
6
- from db_utils import get_schema, execute_sql
7
-
8
- def query_gemini_api(prompt, max_retries=3):
9
- """Query the Google Gemini API"""
10
- api_key = os.getenv("GOOGLE_API_KEY")
11
- if not api_key:
12
- raise ValueError("GOOGLE_API_KEY not found in environment variables")
13
-
14
- # Gemini API endpoint
15
- url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key={api_key}"
16
-
17
- headers = {
18
- "Content-Type": "application/json"
19
- }
20
-
21
- payload = {
22
- "contents": [{
23
- "parts": [{
24
- "text": prompt
25
- }]
26
- }],
27
- "generationConfig": {
28
- "temperature": 0.1,
29
- "topK": 1,
30
- "topP": 0.8,
31
- "maxOutputTokens": 200,
32
- "stopSequences": ["```", "\n\n"]
33
- },
34
- "safetySettings": [
35
- {
36
- "category": "HARM_CATEGORY_HARASSMENT",
37
- "threshold": "BLOCK_NONE"
38
- },
39
- {
40
- "category": "HARM_CATEGORY_HATE_SPEECH",
41
- "threshold": "BLOCK_NONE"
42
- },
43
- {
44
- "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
45
- "threshold": "BLOCK_NONE"
46
- },
47
- {
48
- "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
49
- "threshold": "BLOCK_NONE"
50
- }
51
- ]
52
- }
53
-
54
- for attempt in range(max_retries):
55
- try:
56
- response = requests.post(url, headers=headers, json=payload, timeout=30)
57
-
58
- if response.status_code == 200:
59
- result = response.json()
60
- if "candidates" in result and len(result["candidates"]) > 0:
61
- candidate = result["candidates"][0]
62
- # Check for safety ratings first to see if it was blocked
63
- if candidate.get('finishReason') == 'SAFETY':
64
- print("=== DEBUG: Gemini response blocked due to safety settings.")
65
- return "Error: The response was blocked by safety filters."
66
- if "content" in candidate and "parts" in candidate["content"]:
67
- generated_text = candidate["content"]["parts"][0]["text"].strip()
68
- return generated_text
69
-
70
- return "No valid response generated"
71
-
72
- elif response.status_code == 429:
73
- wait_time = 60 * (attempt + 1)
74
- time.sleep(wait_time)
75
- continue
76
- else:
77
- error_msg = f"Gemini API Error {response.status_code}: {response.text}"
78
- if attempt == max_retries - 1:
79
- raise Exception(error_msg)
80
-
81
- except requests.exceptions.Timeout:
82
- if attempt == max_retries - 1:
83
- raise Exception("Request timed out after multiple attempts")
84
- time.sleep(5)
85
-
86
- except Exception as e:
87
- if attempt == max_retries - 1:
88
- raise e
89
- time.sleep(5)
90
-
91
- raise Exception("Failed to get response after all retries")
92
-
93
- def extract_user_requested_limit(nl_query):
94
- """Extract user-requested number from natural language query"""
95
- patterns = [
96
- r'\b(\d+)\s+(?:ships?|vessels?|boats?|records?|results?|entries?|names?)\b',
97
- r'(?:show|list|find|get)\s+(?:me\s+)?(?:the\s+)?(?:top\s+|first\s+)?(\d+)',
98
- r'(?:names\s+of\s+)(\d+)\s+',
99
- r'\b(\d+)\s+(?:oldest|newest|biggest|smallest|largest)',
100
- ]
101
-
102
- for pattern in patterns:
103
- match = re.search(pattern, nl_query, re.IGNORECASE)
104
- if match:
105
- return int(match.group(1))
106
- return None
107
-
108
- def clean_sql_output(sql_text, user_limit=None):
109
- """Clean and validate SQL output from the model"""
110
- sql_text = sql_text.strip()
111
-
112
- if sql_text.startswith("```"):
113
- lines = sql_text.split('\n')
114
- sql_lines = []
115
- in_sql = False
116
- for line in lines:
117
- if line.strip().startswith("```"):
118
- in_sql = not in_sql
119
- continue
120
- if in_sql:
121
- sql_lines.append(line)
122
- sql_text = '\n'.join(sql_lines)
123
-
124
- lines = sql_text.split('\n')
125
- sql = ""
126
- for line in lines:
127
- line = line.strip()
128
- if line and (line.upper().startswith('SELECT') or sql