File size: 1,882 Bytes
7e5b35f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import gradio as gr
from transformers import pipeline
from PIL import Image
import torch
from torchvision import transforms

# MODELES
INGREDIENT_MODEL_ID = "stchakman/Fridge_Items_Model"
RECIPE_MODEL_ID = "flax-community/t5-recipe-generation"

# PIPELINES
ingredient_classifier = pipeline(
    "image-classification",
    model=INGREDIENT_MODEL_ID,
    device=0 if torch.cuda.is_available() else -1,
    top_k=4
)

recipe_generator = pipeline(
    "text2text-generation",
    model=RECIPE_MODEL_ID,
    device=0 if torch.cuda.is_available() else -1
)

# AUGMENTATION
augment = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
])

# FONCTION PRINCIPALE
def generate_recipe(image: Image.Image):
    yield "🔄 Traitement de l'image... Veuillez patienter."

    # Augmentation
    image_aug = augment(image)

    # Classification
    results = ingredient_classifier(image_aug)
    ingredients = [res["label"] for res in results]
    ingredient_str = ", ".join(ingredients)

    yield f"🥕 Ingrédients détectés : {ingredient_str}\n\n🍳 Génération de la recette..."
    prompt = f"Ingredients: {ingredient_str}. Recipe:"
    recipe = recipe_generator(prompt, max_new_tokens=256, do_sample=True)[0]["generated_text"]
    yield f"### 🥕 Ingrédients détectés :\n{ingredient_str}\n\n### 🍽️ Recette générée :\n{recipe}"

# INTERFACE
interface = gr.Interface(
    fn=generate_recipe,
    inputs=gr.Image(type="pil", label="📷 Image de vos ingrédients"),
    outputs=gr.Markdown(),
    title="🥕 Générateur de Recettes 🧑‍🍳",
    description="Dépose une image d'ingrédients pour obtenir une recette automatiquement générée à partir d'un modèle IA.",
    allow_flagging="never"
)

if __name__ == "__main__":
    interface.launch(share=True)