EmTpro01 committed on
Commit 2ab4365 · verified · 1 Parent(s): 6b918d2

Update app.py

Files changed (1):
  1. app.py +63 -92
app.py CHANGED
@@ -1,100 +1,71 @@
  import streamlit as st
- from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer
 
- # Load model and tokenizer
- MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
- model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
-
- prefix = "items: "
- generation_kwargs = {
-     "max_length": 512,
-     "min_length": 64,
-     "no_repeat_ngram_size": 3,
-     "do_sample": True,
-     "top_k": 60,
-     "top_p": 0.95
- }
-
- special_tokens = tokenizer.all_special_tokens
- tokens_map = {
-     "<sep>": "--",
-     "<section>": "\n"
- }
-
- def skip_special_tokens(text, special_tokens):
-     for token in special_tokens:
-         text = text.replace(token, "")
-     return text
-
- def target_postprocessing(texts, special_tokens):
-     if not isinstance(texts, list):
-         texts = [texts]
 
-     new_texts = []
-     for text in texts:
-         text = skip_special_tokens(text, special_tokens)
-         for k, v in tokens_map.items():
-             text = text.replace(k, v)
-         new_texts.append(text)
-     return new_texts
 
- def generation_function(texts):
-     _inputs = texts if isinstance(texts, list) else [texts]
-     inputs = [prefix + inp for inp in _inputs]
-     inputs = tokenizer(
-         inputs,
-         max_length=256,
-         padding="max_length",
-         truncation=True,
-         return_tensors="jax"
-     )
-     input_ids = inputs.input_ids
-     attention_mask = inputs.attention_mask
 
-     output_ids = model.generate(
-         input_ids=input_ids,
-         attention_mask=attention_mask,
-         **generation_kwargs
-     )
-     generated = output_ids.sequences
-     generated_recipe = target_postprocessing(
-         tokenizer.batch_decode(generated, skip_special_tokens=False),
-         special_tokens
      )
-     return generated_recipe
-
- # Streamlit app interface
- st.title("Recipe Generation from Ingredients")
 
- # User input for ingredients
- ingredients = st.text_area("Enter ingredients (comma separated):", "macaroni, butter, salt, bacon, milk, flour, pepper, cream corn")
 
- # Button to generate recipe
- if st.button("Generate Recipe"):
-     if ingredients:
-         items = [ingredients]
-         generated = generation_function(items)
-         for text in generated:
-             sections = text.split("\n")
-             for section in sections:
-                 section = section.strip()
-                 if section.startswith("title:"):
-                     section = section.replace("title:", "")
-                     headline = "TITLE"
-                 elif section.startswith("ingredients:"):
-                     section = section.replace("ingredients:", "")
-                     headline = "INGREDIENTS"
-                 elif section.startswith("directions:"):
-                     section = section.replace("directions:", "")
-                     headline = "DIRECTIONS"
-
-                 if headline == "TITLE":
-                     st.subheader(f"[{headline}]: {section.strip().capitalize()}")
-                 else:
-                     section_info = [f" - {i+1}: {info.strip().capitalize()}" for i, info in enumerate(section.split("--"))]
-                     st.write(f"[{headline}]:")
-                     st.write("\n".join(section_info))
-             st.write("-" * 130)
-     else:
-         st.warning("Please enter ingredients.")
 
  import streamlit as st
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
+ import torch
 
+ # Load the model and tokenizer with CPU optimization
+ @st.cache_resource
+ def load_model():
+     model_name = "flax-community/t5-recipe-generation"
+     tokenizer = T5Tokenizer.from_pretrained(model_name)
+     model = T5ForConditionalGeneration.from_pretrained(model_name)
 
+     # Explicitly set to CPU and use float32 to reduce memory usage
+     model = model.to('cpu').float()
+
+     return tokenizer, model
 
+ # Generate recipe function with CPU-friendly generation
+ def generate_recipe(ingredients, tokenizer, model, max_length=512):
+     # Prepare input
+     input_text = f"Generate recipe with: {ingredients}"
+
+     # Use torch no_grad to reduce memory consumption
+     with torch.no_grad():
+         input_ids = tokenizer.encode(input_text, return_tensors="pt", max_length=max_length, truncation=True)
+
+         # Adjust generation parameters for faster CPU inference
+         output_ids = model.generate(
+             input_ids,
+             max_length=max_length,
+             num_return_sequences=1,
+             no_repeat_ngram_size=2,
+             num_beams=4,  # Reduced beam search for faster CPU processing
+             early_stopping=True
+         )
+
+     # Decode and clean the output
+     recipe = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+     return recipe
 
+ # Streamlit app
+ def main():
+     st.title("🍳 AI Recipe Generator")
+
+     # Sidebar for input
+     st.sidebar.header("Ingredient Input")
+     ingredients_input = st.sidebar.text_area(
+         "Enter ingredients (comma-separated):",
+         placeholder="e.g. chicken, tomatoes, onions, garlic"
      )
+
+     # Load model
+     tokenizer, model = load_model()
+
+     # Generate button
+     if st.sidebar.button("Generate Recipe"):
+         if ingredients_input:
+             with st.spinner("Generating recipe..."):
+                 recipe = generate_recipe(ingredients_input, tokenizer, model)
+
+             # Display recipe sections
+             st.subheader("🥘 Generated Recipe")
+             st.write(recipe)
+         else:
+             st.warning("Please enter some ingredients!")
 
+     # Additional UI elements
+     st.sidebar.markdown("---")
+     st.sidebar.info("Enter ingredients and click 'Generate Recipe'")
 
+ if __name__ == "__main__":
+     main()
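A caveat worth checking on this change: the removed Flax version prompted the checkpoint with the "items: " prefix and mapped its "<sep>"/"<section>" output tokens into readable sections, while the new code prompts with "Generate recipe with: " and decodes with skip_special_tokens=True, which drops the section markers instead of formatting them. Below is a minimal standalone sketch (not part of the commit) that exercises the new PyTorch path while carrying over those two details from the removed code, assuming they still match how the checkpoint was trained:

# Standalone sketch: new PyTorch path, with the prompt prefix and the
# post-processing reused from the removed Flax code (assumed to match
# the model's training format).
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch

model_name = "flax-community/t5-recipe-generation"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name).to("cpu").float()

ingredients = "macaroni, butter, salt, bacon, milk, flour, pepper, cream corn"
input_ids = tokenizer.encode(
    "items: " + ingredients,  # prefix used by the removed Flax version
    return_tensors="pt", max_length=256, truncation=True
)

with torch.no_grad():  # inference only, no gradients needed
    output_ids = model.generate(
        input_ids,
        max_length=512,
        num_beams=4,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )

# Decode without skipping special tokens so the section markers survive,
# then strip/map them, mirroring target_postprocessing() from the removed code.
text = tokenizer.decode(output_ids[0], skip_special_tokens=False)
for token in tokenizer.all_special_tokens:
    text = text.replace(token, "")
text = text.replace("<sep>", "--").replace("<section>", "\n")
print(text.strip())

If the output still splits into "title:", "ingredients:", and "directions:" sections with this prompt, the same prefix and mapping can be dropped into generate_recipe() with only a few line changes.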