EmTpro01 commited on
Commit
3eceaf9
·
verified ·
1 Parent(s): 7028e54

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -0
app.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import streamlit as st
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer

# Checkpoint fine-tuned for ingredient-list -> recipe generation.
MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"


@st.cache_resource
def _load_model_and_tokenizer():
    """Load the tokenizer and Flax seq2seq model once per server process.

    Streamlit re-executes this entire script on every user interaction;
    without caching, the tokenizer and model weights would be re-downloaded
    / re-loaded on each rerun. ``st.cache_resource`` keeps a single shared
    instance alive across reruns and sessions.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
    model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
    return tokenizer, model


# Module-level names preserved so the functions below keep working unchanged.
tokenizer, model = _load_model_and_tokenizer()

# Prompt prefix the checkpoint was trained with (T5-style task prefix).
prefix = "items: "

# Sampling configuration forwarded verbatim to model.generate().
generation_kwargs = {
    "max_length": 512,
    "min_length": 64,
    "no_repeat_ngram_size": 3,
    "do_sample": True,
    "top_k": 60,
    "top_p": 0.95,
}

special_tokens = tokenizer.all_special_tokens

# Sentinel tokens the model emits, mapped to display-friendly separators:
# "<sep>" separates list items, "<section>" separates recipe sections.
tokens_map = {
    "<sep>": "--",
    "<section>": "\n",
}
def skip_special_tokens(text, special_tokens):
    """Return *text* with every string in *special_tokens* removed.

    Args:
        text: Decoded model output.
        special_tokens: Iterable of token strings (e.g. "<pad>", "</s>").

    Returns:
        The input string with all occurrences of each token stripped out.
    """
    for marker in special_tokens:
        text = text.replace(marker, "")
    return text
def target_postprocessing(texts, special_tokens):
    """Clean decoded model output for display.

    Strips special tokens, then rewrites the model's sentinel markers
    (via the module-level ``tokens_map``) into human-readable separators.

    Args:
        texts: A decoded string or a list of decoded strings.
        special_tokens: Token strings to remove before mapping sentinels.

    Returns:
        list[str]: One cleaned string per input (a bare string input is
        treated as a one-element batch).
    """
    batch = texts if isinstance(texts, list) else [texts]

    cleaned = []
    for raw in batch:
        stripped = skip_special_tokens(raw, special_tokens)
        for marker, replacement in tokens_map.items():
            stripped = stripped.replace(marker, replacement)
        cleaned.append(stripped)
    return cleaned
def generation_function(texts):
    """Generate recipe text for one or more ingredient lists.

    Args:
        texts: A single ingredient string or a list of them
            (e.g. "macaroni, butter, salt").

    Returns:
        list[str]: Post-processed recipe strings, one per input, with
        "--" item separators and newline-delimited sections.
    """
    # Normalise to a batch and prepend the task prefix the model expects.
    batch = texts if isinstance(texts, list) else [texts]
    prompts = [prefix + item for item in batch]

    encoded = tokenizer(
        prompts,
        max_length=256,
        padding="max_length",
        truncation=True,
        return_tensors="jax",
    )

    outputs = model.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        **generation_kwargs,
    )

    # Decode WITHOUT skipping special tokens: target_postprocessing needs
    # the raw "<sep>"/"<section>" sentinels to build the display layout.
    decoded = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)
    return target_postprocessing(decoded, special_tokens)
# ---------------------------------------------------------------------------
# Streamlit app interface
# ---------------------------------------------------------------------------
st.title("Recipe Generation from Ingredients")

# User input for ingredients
ingredients = st.text_area(
    "Enter ingredients (comma separated):",
    "macaroni, butter, salt, bacon, milk, flour, pepper, cream corn",
)

# Button to generate recipe
if st.button("Generate Recipe"):
    if ingredients:
        generated = generation_function([ingredients])
        for text in generated:
            for section in text.split("\n"):
                section = section.strip()

                # BUG FIX: `headline` was never initialised. A section line
                # matching none of the known prefixes raised a NameError on
                # its first occurrence, or silently reused the previous
                # section's headline afterwards. Initialise it and skip
                # unlabelled lines instead.
                headline = None
                if section.startswith("title:"):
                    headline = "TITLE"
                    section = section.replace("title:", "")
                elif section.startswith("ingredients:"):
                    headline = "INGREDIENTS"
                    section = section.replace("ingredients:", "")
                elif section.startswith("directions:"):
                    headline = "DIRECTIONS"
                    section = section.replace("directions:", "")

                if headline is None:
                    # Unlabelled line (e.g. an empty section) — nothing to show.
                    continue

                if headline == "TITLE":
                    st.subheader(f"[{headline}]: {section.strip().capitalize()}")
                else:
                    # Items within a section are "--"-separated (see tokens_map).
                    section_info = [
                        f" - {i+1}: {info.strip().capitalize()}"
                        for i, info in enumerate(section.split("--"))
                    ]
                    st.write(f"[{headline}]:")
                    st.write("\n".join(section_info))
            # NOTE(review): indentation was lost in the pasted source; the
            # separator is assumed to print once per generated recipe — confirm.
            st.write("-" * 130)
    else:
        st.warning("Please enter ingredients.")