# Spaces:
# Sleeping
# Sleeping
import argparse | |
import requests | |
import streamlit as st | |
def parse_arguments(argv=None):
    """Parse command-line arguments.

    Args:
        argv: List of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, matching the previous
            behavior. Passing an explicit list lets callers (e.g. tests)
            parse arguments without touching the process argv.

    Returns:
        An ``argparse.Namespace`` whose ``api_url`` attribute (str) is the
        base URL of the FastAPI service.
    """
    parser = argparse.ArgumentParser(description="Prompt Similarity Finder")
    parser.add_argument(
        "--api_url",
        type=str,
        default="https://lazarr19-prompt-engine.hf.space",
        help="The URL of the FastAPI service",
    )
    return parser.parse_args(argv)
def get_similar_prompts(api_url, query, n, timeout=60):
    """Fetch similar prompts from the FastAPI service.

    POSTs ``{"query": query, "n": n}`` as JSON to ``<api_url>/most_similar``
    and returns the decoded JSON body.

    Args:
        api_url: Base URL of the service. A trailing slash is tolerated.
        query: Text to find similar prompts for.
        n: Number of similar prompts to request.
        timeout: Seconds to wait for the service, forwarded to ``requests``.
            Without a timeout a stalled connection would block the app
            forever.

    Returns:
        The parsed JSON response, or ``None`` if the request failed, the
        service returned an HTTP error status, or the body was not valid
        JSON. In those cases the error is shown in the app via ``st.error``.
    """
    # Strip a trailing slash so "https://host/" does not become "https://host//most_similar".
    url = f"{api_url.rstrip('/')}/most_similar"
    try:
        response = requests.post(url, json={"query": query, "n": n}, timeout=timeout)
        response.raise_for_status()  # Raise an exception for HTTP errors
        return response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body on requests versions whose JSON
        # decode error is not a RequestException subclass.
        st.error(f"Error: {e}")
        return None
def get_color(score):
    """Map a similarity score to the color used to display it.

    Scores of 0.8 and above map to "green". Scores from 0.5 up to (but not
    including) 0.8 map to "orange". Anything else maps to "red", including
    NaN, which compares false against both thresholds.
    """
    # Checked from the highest threshold down; the first match wins.
    for threshold, shade in ((0.8, "green"), (0.5, "orange")):
        if score >= threshold:
            return shade
    return "red"
def _render_similar_prompts(similar_prompts):
    """Render each result as a color-coded score followed by its prompt text."""
    import html

    st.subheader("Similar Prompts:")
    for item in similar_prompts:
        score = item["score"]
        color = get_color(score)
        # The prompt text comes from the remote service. Escape it before it
        # is placed in HTML rendered with unsafe_allow_html=True, so any
        # markup inside a prompt is shown literally instead of being
        # injected into the page.
        prompt = html.escape(str(item["prompt"]))
        # Apply color only to the score part
        st.markdown(
            f"<p><strong>Score:</strong> <span style='color:{color};'>{score:.2f}</span> <br> <strong>Prompt:</strong> {prompt}</p>",
            unsafe_allow_html=True,
        )
        st.write("---")


def main(api_url):
    """Main function to run the Streamlit app.

    Shows a query box, a slider for the number of results (1-40, default 5)
    and a search button. When the button is clicked with a non-blank query,
    the app fetches similar prompts from the service at ``api_url`` and
    renders them with color-coded scores.

    Args:
        api_url: Base URL of the FastAPI service, passed to
            ``get_similar_prompts``.
    """
    st.title("Prompt Similarity Finder")

    # User input for query
    query = st.text_input("Enter your query:", "")
    n = st.slider(
        "Number of similar prompts to retrieve:", min_value=1, max_value=40, value=5
    )

    if not st.button("Find Similar Prompts"):
        return
    # A whitespace-only query is not a meaningful search; treat it as empty.
    if not query.strip():
        st.warning("Please enter a query.")
        return

    with st.spinner("Fetching similar prompts..."):
        result = get_similar_prompts(api_url, query, n)
    if not result:
        # On failure get_similar_prompts has already displayed the error.
        return

    similar_prompts = result.get("similar_prompts", [])
    if not similar_prompts:
        st.write("No similar prompts found.")
        return
    _render_similar_prompts(similar_prompts)
if __name__ == "__main__":
    # Entry point: read --api_url from the command line and start the app.
    main(parse_arguments().api_url)