shreyanshjha0709 commited on
Commit
3da548a
·
verified ·
1 Parent(s): 2ecd770

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -11
app.py CHANGED
@@ -47,9 +47,10 @@ if selected_brand != "Select":
47
 
48
  if watch_data:
49
  # Generation parameters
50
- max_length = st.slider("Max Length", min_value=50, max_value=300, value=150)
51
  temperature = st.slider("Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.1)
52
-
 
53
  # Generate description based on attributes
54
  if st.button("Generate Description"):
55
  attributes = {
@@ -61,7 +62,7 @@ if selected_brand != "Select":
61
  "movement": watch_data.get("movement", "Unknown Movement"),
62
  "gender": watch_data.get("gender", "Unknown Gender"),
63
  }
64
-
65
  # Format input similar to training data (adjust this based on your training data format)
66
  input_text = f"""Generate a detailed description for the following watch:
67
  Brand: {attributes['brand']}
@@ -73,26 +74,27 @@ Movement: {attributes['movement']}
73
  Gender: {attributes['gender']}
74
 
75
  Description:"""
76
-
77
  # Tokenize input and generate description
78
  inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
79
  outputs = model.generate(
80
- **inputs,
81
- max_length=max_length,
82
  num_return_sequences=1,
83
  temperature=temperature,
84
  top_k=50,
85
  top_p=0.95,
86
- do_sample=True
 
87
  )
88
-
89
  # Decode generated text
90
  description = tokenizer.decode(outputs[0], skip_special_tokens=True)
91
-
92
  # Display the result
93
  st.write("### Generated Description")
94
  st.write(description)
95
-
96
  # Display watch details
97
  st.write("### Watch Details")
98
  st.json(json.dumps(watch_data, indent=2))
@@ -130,4 +132,4 @@ st.markdown(
130
  </div>
131
  """,
132
  unsafe_allow_html=True
133
- )
 
47
 
48
  if watch_data:
49
  # Generation parameters
50
+ max_length = st.slider("Max Length", min_value=100, max_value=500, value=300) # Increase max length to allow for more text
51
  temperature = st.slider("Temperature", min_value=0.1, max_value=1.0, value=0.7, step=0.1)
52
+ repetition_penalty = st.slider("Repetition Penalty", min_value=1.0, max_value=2.0, value=1.2, step=0.1) # Penalize repetitive text
53
+
54
  # Generate description based on attributes
55
  if st.button("Generate Description"):
56
  attributes = {
 
62
  "movement": watch_data.get("movement", "Unknown Movement"),
63
  "gender": watch_data.get("gender", "Unknown Gender"),
64
  }
65
+
66
  # Format input similar to training data (adjust this based on your training data format)
67
  input_text = f"""Generate a detailed description for the following watch:
68
  Brand: {attributes['brand']}
 
74
  Gender: {attributes['gender']}
75
 
76
  Description:"""
77
+
78
  # Tokenize input and generate description
79
  inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
80
  outputs = model.generate(
81
+ **inputs,
82
+ max_length=max_length,
83
  num_return_sequences=1,
84
  temperature=temperature,
85
  top_k=50,
86
  top_p=0.95,
87
+ do_sample=True,
88
+ repetition_penalty=repetition_penalty # Prevent repetitive descriptions
89
  )
90
+
91
  # Decode generated text
92
  description = tokenizer.decode(outputs[0], skip_special_tokens=True)
93
+
94
  # Display the result
95
  st.write("### Generated Description")
96
  st.write(description)
97
+
98
  # Display watch details
99
  st.write("### Watch Details")
100
  st.json(json.dumps(watch_data, indent=2))
 
132
  </div>
133
  """,
134
  unsafe_allow_html=True
135
+ )