Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -47,9 +47,10 @@ if selected_brand != "Select":
|
|
47 |
|
48 |
if watch_data:
|
49 |
# Generation parameters
|
50 |
-
max_length = st.slider("Max Length", min_value=
|
51 |
temperature = st.slider("Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.1)
|
52 |
-
|
|
|
53 |
# Generate description based on attributes
|
54 |
if st.button("Generate Description"):
|
55 |
attributes = {
|
@@ -61,7 +62,7 @@ if selected_brand != "Select":
|
|
61 |
"movement": watch_data.get("movement", "Unknown Movement"),
|
62 |
"gender": watch_data.get("gender", "Unknown Gender"),
|
63 |
}
|
64 |
-
|
65 |
# Format input similar to training data (adjust this based on your training data format)
|
66 |
input_text = f"""Generate a detailed description for the following watch:
|
67 |
Brand: {attributes['brand']}
|
@@ -73,26 +74,27 @@ Movement: {attributes['movement']}
|
|
73 |
Gender: {attributes['gender']}
|
74 |
|
75 |
Description:"""
|
76 |
-
|
77 |
# Tokenize input and generate description
|
78 |
inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
|
79 |
outputs = model.generate(
|
80 |
-
**inputs,
|
81 |
-
max_length=max_length,
|
82 |
num_return_sequences=1,
|
83 |
temperature=temperature,
|
84 |
top_k=50,
|
85 |
top_p=0.95,
|
86 |
-
do_sample=True
|
|
|
87 |
)
|
88 |
-
|
89 |
# Decode generated text
|
90 |
description = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
91 |
-
|
92 |
# Display the result
|
93 |
st.write("### Generated Description")
|
94 |
st.write(description)
|
95 |
-
|
96 |
# Display watch details
|
97 |
st.write("### Watch Details")
|
98 |
st.json(json.dumps(watch_data, indent=2))
|
@@ -130,4 +132,4 @@ st.markdown(
|
|
130 |
</div>
|
131 |
""",
|
132 |
unsafe_allow_html=True
|
133 |
-
)
|
|
|
47 |
|
48 |
if watch_data:
|
49 |
# Generation parameters
|
50 |
+
max_length = st.slider("Max Length", min_value=100, max_value=500, value=300) # Increase max length to allow for more text
|
51 |
temperature = st.slider("Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.1)
|
52 |
+
repetition_penalty = st.slider("Repetition Penalty", min_value=1.0, max_value=2.0, value=1.2, step=0.1) # Penalize repetitive text
|
53 |
+
|
54 |
# Generate description based on attributes
|
55 |
if st.button("Generate Description"):
|
56 |
attributes = {
|
|
|
62 |
"movement": watch_data.get("movement", "Unknown Movement"),
|
63 |
"gender": watch_data.get("gender", "Unknown Gender"),
|
64 |
}
|
65 |
+
|
66 |
# Format input similar to training data (adjust this based on your training data format)
|
67 |
input_text = f"""Generate a detailed description for the following watch:
|
68 |
Brand: {attributes['brand']}
|
|
|
74 |
Gender: {attributes['gender']}
|
75 |
|
76 |
Description:"""
|
77 |
+
|
78 |
# Tokenize input and generate description
|
79 |
inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
|
80 |
outputs = model.generate(
|
81 |
+
**inputs,
|
82 |
+
max_length=max_length,
|
83 |
num_return_sequences=1,
|
84 |
temperature=temperature,
|
85 |
top_k=50,
|
86 |
top_p=0.95,
|
87 |
+
do_sample=True,
|
88 |
+
repetition_penalty=repetition_penalty # Prevent repetitive descriptions
|
89 |
)
|
90 |
+
|
91 |
# Decode generated text
|
92 |
description = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
93 |
+
|
94 |
# Display the result
|
95 |
st.write("### Generated Description")
|
96 |
st.write(description)
|
97 |
+
|
98 |
# Display watch details
|
99 |
st.write("### Watch Details")
|
100 |
st.json(json.dumps(watch_data, indent=2))
|
|
|
132 |
</div>
|
133 |
""",
|
134 |
unsafe_allow_html=True
|
135 |
+
)
|