WhotookNima committed on
Commit
5c7fb6f
·
verified ·
1 Parent(s): f3ec334

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -78
app.py CHANGED
@@ -1,85 +1,27 @@
1
- from fastapi import FastAPI
2
- import json
3
- import difflib
4
- from flair.models import SequenceTagger
5
- from flair.data import Sentence
6
  import torch
7
 
8
  app = FastAPI()
9
 
10
- # Workaround för PyTorch 2.6 weights_only issue
11
- original_load = torch.load
12
- def patched_load(*args, **kwargs):
13
- kwargs["weights_only"] = False
14
- return original_load(*args, **kwargs)
15
 
16
- # Ladda Flair flerspråkig NER-modell
17
- try:
18
- torch.load = patched_load
19
- tagger = SequenceTagger.load("flair/ner-multi")
20
- torch.load = original_load # Återställ original torch.load
21
- except Exception as e:
22
- print(f"Error loading model: {str(e)}")
23
- raise e
24
 
25
- # Ladda entiteter från entities.json
26
- with open("entities.json") as f:
27
- entities = json.load(f)
28
- ITEMS = set(entities["items"])
29
- COLORS = set(entities["colors"])
30
- PRICES = set(entities["prices"])
31
 
32
- def correct_spelling(word, valid_words, threshold=0.7):
33
- """Korrigera stavfel."""
34
- normalized = word.rstrip("etn")
35
- matches = difflib.get_close_matches(normalized, valid_words, n=1, cutoff=threshold)
36
- return matches[0] if matches else word
37
-
38
- @app.post("/parse")
39
- async def parse_user_request(request: str):
40
- if not request or len(request) > 200:
41
- return {"error": "Ogiltig eller för lång begäran"}
42
- try:
43
- # Skapa Flair Sentence
44
- sentence = Sentence(request)
45
-
46
- # Prediktera NER-taggar
47
- tagger.predict(sentence)
48
-
49
- # Extrahera entiteter
50
- result_entities = {}
51
- # Kolla färger och priser i hela meningen
52
- words = request.lower().split()
53
- for word in words:
54
- if word in COLORS:
55
- result_entities["FÄRG"] = word
56
- elif word in PRICES:
57
- result_entities["PRIS"] = word
58
-
59
- # Extrahera varor från NER
60
- for entity in sentence.get_spans("ner"):
61
- if entity.tag in ["MISC", "ORG", "LOC"]: # Diverse, organisationer, platser som potentiella objekt
62
- corrected = correct_spelling(entity.text.lower(), ITEMS)
63
- if corrected in ITEMS:
64
- result_entities["VARA"] = corrected
65
- elif not result_entities.get("VARA"):
66
- result_entities["VARA"] = entity.text.lower()
67
-
68
- # Om ingen vara hittades
69
- if "VARA" not in result_entities:
70
- return {"result": "error:ingen vara"}
71
-
72
- # Skapa strukturerad sträng
73
- result_parts = [f"vara:{result_entities['VARA']}"]
74
- if "FÄRG" in result_entities:
75
- result_parts.append(f"färg:{result_entities['FÄRG']}")
76
- if "PRIS" in result_entities:
77
- result_parts.append(f"pris:{result_entities['PRIS']}")
78
-
79
- return {"result": ",".join(result_parts)}
80
- except Exception as e:
81
- return {"error": f"Fel vid parsning: {str(e)}"}
82
-
83
- @app.get("/")
84
- async def root():
85
- return {"message": "Request Parser API is running!"}
 
1
+ from fastapi import FastAPI, Request
2
+ from pydantic import BaseModel
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
4
  import torch
5
 
6
  app = FastAPI()
7
 
8
+ # Ladda modellen
9
+ model_id = "AI-Sweden/gpt-sw3-126m"
10
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
11
+ model = AutoModelForCausalLM.from_pretrained(model_id)
 
12
 
13
+ # Om du kör på CPU – lägg till detta
14
+ device = torch.device("cpu")
15
+ model.to(device)
 
 
 
 
 
16
 
17
+ # Input-modell
18
+ class Prompt(BaseModel):
19
+ text: str
20
+ max_new_tokens: int = 50
 
 
21
 
22
+ @app.post("/generate")
23
+ async def generate_text(prompt: Prompt):
24
+ inputs = tokenizer(prompt.text, return_tensors="pt").to(device)
25
+ outputs = model.generate(**inputs, max_new_tokens=prompt.max_new_tokens)
26
+ generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
27
+ return {"response": generated}