Santiagogrz6 commited on
Commit
dcdbe16
verified
1 Parent(s): 94ca4fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -9,10 +9,10 @@ torch.set_num_interop_threads(1)
9
  tokenizer = T5Tokenizer.from_pretrained("cssupport/t5-small-awesome-text-to-sql")
10
  model = T5ForConditionalGeneration.from_pretrained("cssupport/t5-small-awesome-text-to-sql")
11
  # Esquema de base de datos
12
- SCHEMA = """
13
  Database schema:
14
- Table bodegas(Id, Nombre, Encargado, Telefono, Email, Direccion, Horario, Regional, Latitud, Longitud)
15
- Table maestra(CodigoSap, Descripcion, Grupo, Agrupador, Marca, Parte, Operacion, Componente)
16
  """
17
 
18
  # Función principal
@@ -21,7 +21,7 @@ def generar_sql(pregunta_espanol):
21
  # Traducir pregunta a inglés
22
  pregunta_ingles = GoogleTranslator(source="es", target="en").translate(pregunta_espanol)
23
  # Crear prompt
24
- prompt = f"{SCHEMA}\ntranslate English to SQL: {pregunta_ingles}"
25
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids
26
  output = model.generate(input_ids, max_length=128)
27
  sql = tokenizer.decode(output[0], skip_special_tokens=True)
 
9
  tokenizer = T5Tokenizer.from_pretrained("cssupport/t5-small-awesome-text-to-sql")
10
  model = T5ForConditionalGeneration.from_pretrained("cssupport/t5-small-awesome-text-to-sql")
11
  # Esquema de base de datos
12
+ SCHEMA_EN = """
13
  Database schema:
14
+ Table bodegas(Id, Name, Manager, Phone, Email, Address, Schedule, Region, Latitude, Longitude)
15
+ Table master(CodigoSap, Description, Group, Aggregator, Brand, Part, Operation, Component)
16
  """
17
 
18
  # Función principal
 
21
  # Traducir pregunta a inglés
22
  pregunta_ingles = GoogleTranslator(source="es", target="en").translate(pregunta_espanol)
23
  # Crear prompt
24
+ prompt = f"Given the following schema:\n{SCHEMA_EN}\nTranslate this question into SQL: {pregunta_espanol}"
25
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids
26
  output = model.generate(input_ids, max_length=128)
27
  sql = tokenizer.decode(output[0], skip_special_tokens=True)