import streamlit as st
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer

# Hugging Face Hub id (or local path) of the recipe-generation checkpoint.
MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"


@st.cache_resource
def _load_model_and_tokenizer(name_or_path):
    """Load the tokenizer and Flax seq2seq model for *name_or_path*.

    Streamlit re-executes this whole script on every widget interaction;
    caching the pair means the weights are loaded once per server process
    instead of on every button click.
    """
    tok = AutoTokenizer.from_pretrained(name_or_path, use_fast=True)
    mdl = FlaxAutoModelForSeq2SeqLM.from_pretrained(name_or_path)
    return tok, mdl


# Load model and tokenizer
tokenizer, model = _load_model_and_tokenizer(MODEL_NAME_OR_PATH)

# Task prefix prepended to every ingredient list before tokenization.
prefix = "items: "
generation_kwargs = {
    "max_length": 512,
    "min_length": 64,
    "no_repeat_ngram_size": 3,
    "do_sample": True,
    "top_k": 60,
    "top_p": 0.95,
}

special_tokens = tokenizer.all_special_tokens
# Structural markers in the generated text (per the model card):
# "<sep>" separates list items and "<section>" separates the
# title / ingredients / directions sections. The rendering code below
# splits on the "--" and "\n" these are mapped to.
tokens_map = {
    "<sep>": "--",
    "<section>": "\n",
}


def skip_special_tokens(text, special_tokens):
    """Return *text* with every string in *special_tokens* removed."""
    for token in special_tokens:
        text = text.replace(token, "")
    return text


def target_postprocessing(texts, special_tokens):
    """Clean decoded model output into readable recipe text.

    Args:
        texts: A decoded string, or a list of decoded strings.
        special_tokens: Tokens to delete outright (e.g. padding / EOS).

    Returns:
        A list of cleaned strings, one per input, with the section/item
        markers from ``tokens_map`` replaced by "\\n" and "--".
    """
    if not isinstance(texts, list):
        texts = [texts]

    new_texts = []
    for text in texts:
        text = skip_special_tokens(text, special_tokens)
        for marker, replacement in tokens_map.items():
            text = text.replace(marker, replacement)
        new_texts.append(text)
    return new_texts


def generation_function(texts):
    """Generate recipe text for one or more ingredient lists.

    Args:
        texts: A single comma-separated ingredient string, or a list of them.

    Returns:
        A list of post-processed recipe strings, one per input, in which
        sections are separated by newlines and list items by "--".
    """
    _inputs = texts if isinstance(texts, list) else [texts]
    inputs = [prefix + inp for inp in _inputs]
    # Padding every input to a fixed 256 tokens presumably keeps input
    # shapes constant so the JAX-compiled generate is not retraced per
    # request -- NOTE(review): confirm this was the intent.
    inputs = tokenizer(
        inputs,
        max_length=256,
        padding="max_length",
        truncation=True,
        return_tensors="jax",
    )

    # NOTE(review): Flax `generate` appears to fall back to a fixed PRNG key
    # when `prng_key` is not passed, which would make do_sample=True return
    # the same recipe for the same input -- verify and pass a fresh
    # `prng_key` if varied samples are wanted.
    output_ids = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        **generation_kwargs,
    )
    generated = output_ids.sequences
    # Decode with special tokens kept so the "<sep>"/"<section>" markers
    # survive long enough for target_postprocessing to map them.
    generated_recipe = target_postprocessing(
        tokenizer.batch_decode(generated, skip_special_tokens=False),
        special_tokens,
    )
    return generated_recipe


# (line prefix, displayed headline) for each section the model writes.
_SECTION_HEADERS = (
    ("title:", "TITLE"),
    ("ingredients:", "INGREDIENTS"),
    ("directions:", "DIRECTIONS"),
)


def render_recipe(text):
    """Render one post-processed recipe string onto the Streamlit page.

    Each newline-separated section that starts with a known header is shown
    under that headline; the TITLE as a subheader, the other sections as a
    numbered list split on "--". A non-blank line without a header is shown
    under the most recent headline. Blank lines, and lines before any header
    has been seen, are skipped.
    """
    headline = None
    for raw_section in text.split("\n"):
        section = raw_section.strip()
        if not section:
            # Blank section: nothing to show (previously rendered as an
            # empty " - 1: " item under a stale heading).
            continue

        for marker, name in _SECTION_HEADERS:
            if section.startswith(marker):
                headline = name
                # Drop only the leading marker, not later occurrences.
                section = section[len(marker):]
                break

        if headline is None:
            # No heading to place this under yet (previously a NameError).
            continue

        if headline == "TITLE":
            st.subheader(f"[{headline}]: {section.strip().capitalize()}")
        else:
            section_info = [
                f" - {i + 1}: {info.strip().capitalize()}"
                for i, info in enumerate(section.split("--"))
            ]
            st.write(f"[{headline}]:")
            st.write("\n".join(section_info))

    # Visual separator after each recipe.
    st.write("-" * 130)


# Streamlit app interface
st.title("Recipe Generation from Ingredients")

# User input for ingredients
ingredients = st.text_area(
    "Enter ingredients (comma separated):",
    "macaroni, butter, salt, bacon, milk, flour, pepper, cream corn",
)

# Button to generate recipe
if st.button("Generate Recipe"):
    # Whitespace-only input is treated the same as empty input.
    if ingredients.strip():
        items = [ingredients]
        generated = generation_function(items)
        for text in generated:
            render_recipe(text)
    else:
        st.warning("Please enter ingredients.")