Spaces:
Running
Running
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer, AutoModel | |
import torch | |
import numpy as np | |
import random | |
import json | |
from fastapi import FastAPI | |
from fastapi.responses import JSONResponse | |
from pydantic import BaseModel | |
from datetime import datetime, timedelta | |
# RecipeBERT checkpoint: tokenizer and encoder used by get_embedding().
bert_model_name = "alexdseo/RecipeBERT"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = AutoModel.from_pretrained(bert_model_name)

# T5 recipe-generation checkpoint (Flax model class) used by generate_recipe_with_t5().
MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
t5_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)

# Tokens removed from decoded text by target_postprocessing().
special_tokens = t5_tokenizer.all_special_tokens
# Markers in the decoded text and the plain-text delimiters that replace them:
# "--" separates items inside a section, "\n" separates sections.
tokens_map = {
    "<sep>": "--",
    "<section>": "\n"
}
# --- RecipeBERT-specific functions ---
def get_embedding(text):
    """Return a mean-pooled RecipeBERT embedding for ``text``.

    The text is tokenized with truncation and padding, passed through
    ``bert_model`` without gradient tracking, and the last hidden states are
    averaged over the positions where the attention mask is set. The result
    has its leading dimension squeezed (``squeeze(0)``).
    """
    encoded = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        model_output = bert_model(**encoded)

    hidden_states = model_output.last_hidden_state
    # Expand the mask over the hidden dimension so masked positions add zero to the sum.
    mask = encoded['attention_mask'].unsqueeze(-1).expand(hidden_states.size()).float()
    masked_sum = (hidden_states * mask).sum(1)
    # Clamp the per-sequence token count so an all-zero mask cannot divide by zero.
    token_count = mask.sum(1).clamp(min=1e-9)
    pooled = masked_sum / token_count
    return pooled.squeeze(0)
def format_ingredients_for_bert(ingredients_list):
    """Build the BERT input string: ``"Ingredients: "`` followed by the comma-joined names."""
    joined_names = ", ".join(ingredients_list)
    return "Ingredients: " + joined_names
def normalize_ingredient_name(name):
    """Return ``name`` with surrounding whitespace removed and lower-cased."""
    trimmed = name.strip()
    return trimmed.lower()
def get_cosine_similarity(vec1, vec2):
    """Return the cosine similarity of two vectors as a Python ``float``.

    Args:
        vec1: A torch tensor, NumPy array, or array-like sequence of numbers.
        vec2: Same kinds of input as ``vec1``. Both inputs are flattened, so
            they must contain the same number of elements.

    Returns:
        ``dot(vec1, vec2) / (|vec1| * |vec2|)``, or ``0.0`` when either
        vector has zero norm.
    """
    def _as_flat_array(vec):
        if torch.is_tensor(vec):
            # .cpu() is a no-op for CPU tensors and makes accelerator tensors
            # convertible; .numpy() alone raises for non-CPU tensors.
            vec = vec.detach().cpu().numpy()
        # asarray lets plain lists/tuples through, which have no .flatten().
        return np.asarray(vec).flatten()

    flat_a = _as_flat_array(vec1)
    flat_b = _as_flat_array(vec2)

    dot_product = np.dot(flat_a, flat_b)
    norm_a = np.linalg.norm(flat_a)
    norm_b = np.linalg.norm(flat_b)

    if norm_a == 0 or norm_b == 0:
        return 0.0
    # Convert the NumPy scalar so every path returns the same type.
    return float(dot_product / (norm_a * norm_b))
def calculate_age_bonus(date_added_str: str, category: str) -> float:
    """Return a score bonus that grows with the age of an ingredient.

    The bonus is 2% per full day for the category ``"vegetables"``
    (case-insensitive) and 0.5% per full day for everything else, capped at
    10% (``0.10``).

    Args:
        date_added_str: ISO-8601 timestamp; a trailing ``Z`` is accepted as UTC.
        category: Ingredient category; may be empty or ``None``.

    Returns:
        The bonus as a fraction in ``[0.0, 0.10]``. ``0.0`` is returned when
        the timestamp cannot be parsed or lies in the future.
    """
    try:
        # fromisoformat() before Python 3.11 rejects a trailing 'Z', so spell
        # the UTC offset out explicitly.
        date_added = datetime.fromisoformat(date_added_str.replace('Z', '+00:00'))
    except (ValueError, TypeError, AttributeError):
        print(f"Warning: Could not parse date_added_str: {date_added_str}. Returning 0 bonus.")
        return 0.0

    # Naive and aware datetimes cannot be subtracted. Take "now" in the
    # timestamp's own timezone; tzinfo is None for naive input, which yields
    # naive local time exactly as before.
    now = datetime.now(tz=date_added.tzinfo)
    days_since_added = (now - date_added).days

    if days_since_added < 0:  # Added in the future: treat as invalid.
        return 0.0

    if category and category.lower() == "vegetables":
        daily_bonus = 0.02  # 2% per day for vegetables
    else:
        daily_bonus = 0.005  # 0.5% per day for everything else

    return min(days_since_added * daily_bonus, 0.10)  # capped at 10%
def find_best_ingredients(required_ingredients_names, available_ingredients_details, max_ingredients=6):
    """Greedily extend the required ingredients with the best-matching available ones.

    In each round every remaining candidate is scored as
    ``cosine_similarity(embedding(current combination), embedding(candidate))
    + age bonus`` and the highest-scoring candidate is added, until
    ``max_ingredients`` is reached or no candidates remain.

    Args:
        required_ingredients_names: List of ingredient names (strings) that
            must be part of the result. Duplicates are dropped, keeping the
            first occurrence.
        available_ingredients_details: List of objects exposing ``name``,
            ``dateAdded`` and ``category`` (e.g. ``IngredientDetail``).
        max_ingredients: Upper bound on the size of the returned list.

    Returns:
        A list of ingredient names. When candidates were added the list is
        shuffled; on the early-return paths the required names are returned
        in their given order (truncated to ``max_ingredients``).
    """
    # dict.fromkeys de-duplicates while preserving order; set() would make
    # both the order and the truncation below vary from run to run.
    required_ingredients_names = list(dict.fromkeys(required_ingredients_names))
    required_lookup = set(required_ingredients_names)

    # Required ingredients must not show up again as candidates.
    available_ingredients_filtered_details = [
        item for item in available_ingredients_details
        if item.name not in required_lookup
    ]

    # Without any required ingredient, promote a random available one.
    if not required_ingredients_names and available_ingredients_filtered_details:
        random_item = random.choice(available_ingredients_filtered_details)
        required_ingredients_names = [random_item.name]
        available_ingredients_filtered_details = [
            item for item in available_ingredients_filtered_details
            if item.name != random_item.name
        ]
        print(f"No required ingredients provided. Randomly selected: {required_ingredients_names[0]}")

    if not required_ingredients_names or len(required_ingredients_names) >= max_ingredients:
        return required_ingredients_names[:max_ingredients]

    if not available_ingredients_filtered_details:
        return required_ingredients_names

    print(f"\n=== Suche passende Zutaten für Basis: {required_ingredients_names} ===")
    print(f"Verfügbare Zutaten: {[item.name for item in available_ingredients_filtered_details]}")
    print("-" * 50)

    current_combination = required_ingredients_names.copy()

    # Keep only the first candidate per name.
    seen_names = set()
    remaining_ingredients_details = []
    for item in available_ingredients_filtered_details:
        if item.name not in seen_names:
            seen_names.add(item.name)
            remaining_ingredients_details.append(item)

    # A candidate's embedding and age bonus do not depend on the current
    # combination, so compute them once instead of once per round.
    candidate_embeddings = {
        item.name: get_embedding(format_ingredients_for_bert([item.name]))
        for item in remaining_ingredients_details
    }
    candidate_bonuses = {
        item.name: calculate_age_bonus(item.dateAdded, item.category)
        for item in remaining_ingredients_details
    }

    num_to_add = min(max_ingredients - len(required_ingredients_names), len(remaining_ingredients_details))

    for round_num in range(num_to_add):
        best_ingredient_detail = None
        # -inf so that even the lowest possible score can be selected.
        best_score = float("-inf")
        best_similarity = 0.0
        best_age_bonus = 0.0

        current_embedding = get_embedding(format_ingredients_for_bert(current_combination))

        print(f"\nRunde {round_num + 1} - Aktuelle Kombination: {current_combination}")
        print("Teste verbleibende Zutaten:")

        for ingredient_detail in remaining_ingredients_details:
            similarity = get_cosine_similarity(
                current_embedding, candidate_embeddings[ingredient_detail.name]
            )
            age_bonus = candidate_bonuses[ingredient_detail.name]
            final_score = similarity + age_bonus

            print(f" - '{ingredient_detail.name}': Ähnlichkeit = {similarity:.4f}, Altersbonus = {age_bonus:.4f}, Gesamt = {final_score:.4f}")

            if final_score > best_score:
                best_score = final_score
                best_similarity = similarity
                best_age_bonus = age_bonus
                best_ingredient_detail = ingredient_detail

        if best_ingredient_detail is None:
            # Only reachable when no score compared greater (e.g. NaN scores).
            print("Keine weiteren passenden Zutaten gefunden.")
            break

        current_combination.append(best_ingredient_detail.name)
        remaining_ingredients_details.remove(best_ingredient_detail)

        print(f"\n-> Runde {round_num + 1} abgeschlossen: Beste Zutat ist '{best_ingredient_detail.name}' mit Gesamtscore {best_score:.4f}")
        print(f" (Ähnlichkeit: {best_similarity:.4f} + Altersbonus: {best_age_bonus:.4f})")
        print(f" Neue Kombination: {current_combination}")
        print("-" * 50)

    random.shuffle(current_combination)
    print(f"\nEndgültige Zutatenkombination: {current_combination}")
    return current_combination
# --- Chef Transformer-specific functions ---
def skip_special_tokens(text, special_tokens):
    """Return ``text`` with every occurrence of each token in ``special_tokens`` removed."""
    cleaned = text
    for marker in special_tokens:
        cleaned = cleaned.replace(marker, "")
    return cleaned
def target_postprocessing(texts, special_tokens):
    """Clean decoded model output.

    A non-list argument is wrapped into a one-element list. For each text the
    special tokens are removed first, then every key of ``tokens_map`` is
    replaced by its mapped value. Returns the list of cleaned texts.
    """
    batch = texts if isinstance(texts, list) else [texts]

    cleaned_batch = []
    for raw_text in batch:
        cleaned = raw_text
        # Strip special tokens (same effect as skip_special_tokens()).
        for marker in special_tokens:
            cleaned = cleaned.replace(marker, "")
        # Turn structural markers into plain-text delimiters.
        for placeholder, replacement in tokens_map.items():
            cleaned = cleaned.replace(placeholder, replacement)
        cleaned_batch.append(cleaned)
    return cleaned_batch
def validate_recipe_ingredients(recipe_ingredients, expected_ingredients, tolerance=0):
    """Check that the recipe has roughly as many ingredients as expected.

    Args:
        recipe_ingredients: Ingredient strings from the generated recipe;
            empty and whitespace-only entries are not counted.
        expected_ingredients: The ingredients that were requested.
        tolerance: Largest accepted absolute difference between the two counts.

    Returns:
        ``True`` when the counts differ by at most ``tolerance``.
    """
    recipe_count = sum(1 for ing in recipe_ingredients if ing and ing.strip())
    expected_count = len(expected_ingredients)
    # "<=" rather than "==": a tolerance of N must also accept closer matches.
    return abs(recipe_count - expected_count) <= tolerance
def _fallback_recipe(ingredients):
    """Build the placeholder recipe returned when generation fails."""
    return {
        "title": f"Rezept mit {ingredients[0] if ingredients else 'Zutaten'}",
        "ingredients": ingredients,
        "directions": ["Fehler beim Generieren der Rezeptanweisungen"]
    }


def _split_section_items(section_body):
    """Split a "--"-delimited section body into capitalized, non-empty items."""
    return [part.strip().capitalize() for part in section_body.split("--") if part.strip()]


def _parse_generated_recipe(generated_text, current_ingredients):
    """Parse post-processed model output into a recipe dict, filling in missing sections."""
    recipe = {}
    for section in generated_text.split("\n"):
        section = section.strip()
        # Only the leading label is removed; str.replace() would also delete
        # the same word where it appears later inside the section text.
        if section.startswith("title:"):
            recipe["title"] = section.removeprefix("title:").strip().capitalize()
        elif section.startswith("ingredients:"):
            recipe["ingredients"] = _split_section_items(section.removeprefix("ingredients:").strip())
        elif section.startswith("directions:"):
            recipe["directions"] = _split_section_items(section.removeprefix("directions:").strip())

    if "title" not in recipe:
        recipe["title"] = f"Rezept mit {', '.join(current_ingredients[:3])}"
    if "ingredients" not in recipe:
        recipe["ingredients"] = current_ingredients
    if "directions" not in recipe:
        recipe["directions"] = ["Keine Anweisungen generiert"]
    return recipe


def generate_recipe_with_t5(ingredients_list, max_retries=5):
    """Generate a recipe with the T5 model, retrying until the ingredient count matches.

    Each attempt samples a recipe from ``t5_model`` for the prompt
    ``"items: <comma-joined ingredients>"``; from the second attempt on the
    ingredient order is shuffled. An attempt succeeds when
    ``validate_recipe_ingredients`` accepts the parsed ingredient list.

    Args:
        ingredients_list: Ingredient names to build the prompt from. The list
            is not modified.
        max_retries: Maximum number of generation attempts.

    Returns:
        A dict with the keys ``"title"``, ``"ingredients"`` and
        ``"directions"``. If no attempt validates, the recipe of the last
        attempt is returned; if the last attempt raised (or no attempt was
        made), a placeholder recipe listing the input ingredients is returned.
    """
    original_ingredients = ingredients_list.copy()
    prefix = "items: "
    # Sampling settings are the same for every attempt.
    generation_kwargs = {
        "max_length": 512,
        "min_length": 64,
        "do_sample": True,
        "top_k": 60,
        "top_p": 0.95
    }

    for attempt in range(max_retries):
        try:
            # Work on a copy so the caller's list is never shuffled or aliased.
            current_ingredients = original_ingredients.copy()
            if attempt > 0:
                random.shuffle(current_ingredients)

            ingredients_string = ", ".join(current_ingredients)
            print(f"Attempt {attempt + 1}: {prefix + ingredients_string}")  # Debug print

            inputs = t5_tokenizer(
                prefix + ingredients_string,
                max_length=256,
                padding="max_length",
                truncation=True,
                return_tensors="jax"
            )
            output_ids = t5_model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                **generation_kwargs
            )
            # Decode with special tokens kept: target_postprocessing() needs
            # the structural markers to rebuild sections and item separators.
            generated_text = target_postprocessing(
                t5_tokenizer.batch_decode(output_ids.sequences, skip_special_tokens=False),
                special_tokens
            )[0]

            recipe = _parse_generated_recipe(generated_text, current_ingredients)

            if validate_recipe_ingredients(recipe["ingredients"], original_ingredients):
                print(f"Success on attempt {attempt + 1}: Recipe has correct number of ingredients")  # Debug print
                return recipe

            print(f"Attempt {attempt + 1} failed: Expected {len(original_ingredients)} ingredients, got {len(recipe['ingredients'])}")  # Debug print
            if attempt == max_retries - 1:
                print("Max retries reached, returning last generated recipe")  # Debug print
                return recipe
        except Exception as e:
            # Best-effort boundary: a failed attempt must not abort the retries.
            print(f"Error in recipe generation attempt {attempt + 1}: {str(e)}")  # Debug print
            if attempt == max_retries - 1:
                return _fallback_recipe(original_ingredients)

    # Reached only when max_retries <= 0.
    return _fallback_recipe(original_ingredients)
def process_recipe_request_logic(required_ingredients, available_ingredients_details, max_ingredients, max_retries):
    """Run ingredient selection followed by recipe generation.

    ``available_ingredients_details`` is a list of ``IngredientDetail``
    objects. Returns a dict with ``title``, ``ingredients``, ``directions``
    and ``used_ingredients``; when no ingredients were supplied, or when any
    step raises, a dict with a single ``error`` message is returned instead.
    """
    if not (required_ingredients or available_ingredients_details):
        return {"error": "Keine Zutaten angegeben"}

    try:
        chosen_ingredients = find_best_ingredients(
            required_ingredients, available_ingredients_details, max_ingredients
        )
        generated = generate_recipe_with_t5(chosen_ingredients, max_retries)
        response = {field: generated[field] for field in ('title', 'ingredients', 'directions')}
        response['used_ingredients'] = chosen_ingredients
        return response
    except Exception as exc:
        # Print the full traceback to the console; the caller only gets the message.
        import traceback
        traceback.print_exc()
        return {"error": f"Fehler bei der Rezeptgenerierung: {str(exc)}"}
# --- FastAPI implementation ---
# ASGI application object for the recipe-generation API.
app = FastAPI(title="AI Recipe Generator API")
class IngredientDetail(BaseModel):
    """An available ingredient together with the metadata used for scoring."""

    # Ingredient name; compared verbatim against the required ingredient names.
    name: str
    # Timestamp string passed to calculate_age_bonus(), which parses it as ISO-8601.
    dateAdded: str
    # Category passed to calculate_age_bonus(); "vegetables" earns the higher daily bonus.
    category: str
class RecipeRequest(BaseModel):
    """Request body accepted by generate_recipe_api()."""

    # Names of ingredients that must be part of the recipe.
    required_ingredients: list[str] = []
    # Candidate ingredients that may be added to the required ones.
    available_ingredients: list[IngredientDetail] = []
    # Upper bound on the number of ingredients handed to the generator.
    max_ingredients: int = 7
    # Maximum number of generation attempts.
    max_retries: int = 5
    # Used as the required ingredients when required_ingredients is empty.
    ingredients: list[str] = []
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm that it is registered on `app`.
async def generate_recipe_api(request_data: RecipeRequest):
    """REST endpoint for the Flutter app: takes a JSON body and returns JSON.

    ``request_data.ingredients`` is used as the required ingredients when
    ``required_ingredients`` is empty. The result (or error) dict produced by
    ``process_recipe_request_logic`` is returned as a ``JSONResponse``.
    """
    required = request_data.required_ingredients
    legacy_ingredients = request_data.ingredients
    if legacy_ingredients and not required:
        required = legacy_ingredients

    payload = process_recipe_request_logic(
        required,
        request_data.available_ingredients,
        request_data.max_ingredients,
        request_data.max_retries
    )
    return JSONResponse(content=payload)
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm that it is registered on `app`.
async def read_root():
    """Return a static status message indicating that the API is running."""
    status = {"message": "AI Recipe Generator API is running (FastAPI only)!"}
    return status