File size: 2,494 Bytes
3556e6f
 
be043a6
3556e6f
be043a6
3556e6f
 
 
 
 
 
 
 
 
 
 
be043a6
 
3556e6f
 
be043a6
3556e6f
 
 
be043a6
 
 
 
 
 
 
45462fb
3556e6f
45462fb
 
 
 
 
 
 
 
3556e6f
 
be043a6
 
 
 
 
45462fb
be043a6
 
 
 
 
3556e6f
be043a6
 
 
 
 
45462fb
 
 
 
 
 
 
be043a6
 
 
 
 
 
 
 
3556e6f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import argparse
import html

import requests
import streamlit as st


def parse_arguments():
    """Read the app's command-line options.

    Returns:
        The ``argparse.Namespace`` produced by parsing the process
        arguments; its ``api_url`` attribute holds the service URL.
    """
    cli = argparse.ArgumentParser(description="Prompt Similarity Finder")
    # Settings for the single supported option, kept together in one place.
    api_url_option = {
        "type": str,
        "default": "https://lazarr19-prompt-engine.hf.space",
        "help": "The URL of the FastAPI service",
    }
    cli.add_argument("--api_url", **api_url_option)
    namespace = cli.parse_args()
    return namespace


def get_similar_prompts(api_url, query, n, timeout=30):
    """Fetch similar prompts from the FastAPI service.

    Sends a POST request to ``{api_url}/most_similar`` with the JSON body
    ``{"query": query, "n": n}``.

    Args:
        api_url: Base URL of the service; a trailing slash is tolerated.
        query: Text to find similar prompts for.
        n: Number of similar prompts to request.
        timeout: Seconds to wait for the service before giving up. Without
            a timeout the request could block the app indefinitely if the
            service stops responding.

    Returns:
        The decoded JSON response body, or ``None`` when the request fails,
        the server answers with an HTTP error status, or the body is not
        valid JSON. In those cases the error is shown via ``st.error``.
    """
    # Avoid a double slash when the base URL already ends with "/".
    endpoint = f"{api_url.rstrip('/')}/most_similar"
    try:
        response = requests.post(
            endpoint, json={"query": query, "n": n}, timeout=timeout
        )
        response.raise_for_status()  # Raise an exception for HTTP errors
        return response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body on requests versions whose JSON
        # decode error is not a RequestException subclass.
        st.error(f"Error: {e}")
        return None


def get_color(score):
    """Map a similarity score to a display color name.

    Returns ``"green"`` for scores of at least 0.8, ``"orange"`` for scores
    of at least 0.5, and ``"red"`` for everything else.
    """
    # Thresholds are ordered from highest to lowest; the first match wins.
    color_floors = ((0.8, "green"), (0.5, "orange"))
    for floor, color_name in color_floors:
        if score >= floor:
            return color_name
    return "red"


def main(api_url):
    """Main function to run the Streamlit app.

    Renders a query text box, a slider for the number of results, and a
    button. When the button is pressed with a non-blank query, the similar
    prompts returned by the service are listed with color-coded scores.

    Args:
        api_url: Base URL of the FastAPI service to query.
    """
    st.title("Prompt Similarity Finder")

    # User input for query
    query = st.text_input("Enter your query:", "")
    n = st.slider(
        "Number of similar prompts to retrieve:", min_value=1, max_value=40, value=5
    )

    if not st.button("Find Similar Prompts"):
        return

    # Treat a whitespace-only query the same as an empty one.
    if not query.strip():
        st.warning("Please enter a query.")
        return

    with st.spinner("Fetching similar prompts..."):
        result = get_similar_prompts(api_url, query, n)
        if not result:
            # The failure (if any) was already reported by get_similar_prompts.
            return

        similar_prompts = result.get("similar_prompts", [])
        if not similar_prompts:
            st.write("No similar prompts found.")
            return

        st.subheader("Similar Prompts:")
        for item in similar_prompts:
            score = item["score"]
            color = get_color(score)
            # The prompt text comes from the remote service and is rendered
            # with unsafe_allow_html=True, so escape it to keep any markup
            # in it from being interpreted by the browser.
            prompt_text = html.escape(str(item["prompt"]))
            # Apply color only to the score part
            st.markdown(
                f"<p><strong>Score:</strong> <span style='color:{color};'>{score:.2f}</span> <br> <strong>Prompt:</strong> {prompt_text}</p>",
                unsafe_allow_html=True,
            )
            st.write("---")


if __name__ == "__main__":
    # Resolve the service URL from the command line and launch the app.
    main(parse_arguments().api_url)