import gradio as gr
from llm_inference import LLMInferenceNode
import random
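
# `llm_inference` is a sibling module in this Space and is not shown here.
# Judging from the calls below, it is assumed to expose roughly the following
# interface; this is a hypothetical sketch, not the module's actual code:
#
#   class LLMInferenceNode:
#       def generate_prompt(self, seed: int, prompt_type: str, custom_input: str) -> str: ...
#       def generate(self, input_text, long_talk, compress, compression_level,
#                    prompt_type, custom_base_prompt, provider, api_key, model) -> str: ...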
title = """<h1 align="center">Random Prompt Generator</h1>
<p><center>
<a href="https://x.com/gokayfem" target="_blank">[X gokaygokay]</a>
<a href="https://github.com/gokayfem" target="_blank">[Github gokayfem]</a>
<a href="https://github.com/dagthomas/comfyui_dagthomas" target="_blank">[comfyui_dagthomas]</a>
<p align="center">Generate random prompts using powerful LLMs from Hugging Face, Groq, and SambaNova.</p>
</center></p>
"""
# Global variable to store selected prompt type
selected_prompt_type = "Long" # Default value

def create_interface():
    llm_node = LLMInferenceNode()

    with gr.Blocks(theme='bethecloud/storj_theme') as demo:
        gr.HTML(title)

        with gr.Row():
            with gr.Column(scale=2):
                with gr.Accordion("Basic Settings"):
                    custom = gr.Textbox(label="Custom Input Prompt (optional)")

                with gr.Accordion("Prompt Generation Options", open=False):
                    prompt_type = gr.Dropdown(
                        choices=["Long", "Short", "Medium"],
                        label="Prompt Type",
                        value="Long",
                        interactive=True,
                    )

                    # Function to update the selected prompt type
                    def update_prompt_type(value):
                        global selected_prompt_type
                        selected_prompt_type = value
                        print(f"Updated prompt type: {selected_prompt_type}")
                        return value

                    # Connect update_prompt_type to the prompt_type dropdown
                    prompt_type.change(update_prompt_type, inputs=[prompt_type], outputs=[prompt_type])

            with gr.Column(scale=2):
                generate_button = gr.Button("Generate Prompt")

                with gr.Accordion("Generated Prompt", open=True):
                    output = gr.Textbox(label="Generated Prompt", lines=4, show_copy_button=True)
                    text_output = gr.Textbox(label="LLM Generated Text", lines=10, show_copy_button=True)

            with gr.Column(scale=2):
with gr.Accordion("""LLM Prompt Generation""", open=False):
long_talk = gr.Checkbox(label="Long Talk", value=True)
compress = gr.Checkbox(label="Compress", value=True)
compression_level = gr.Dropdown(
choices=["soft", "medium", "hard"],
label="Compression Level",
value="hard"
)
custom_base_prompt = gr.Textbox(label="Custom Base Prompt", lines=5)

                    # LLM provider selection
                    llm_provider = gr.Dropdown(
                        choices=["Hugging Face", "Groq", "SambaNova"],
                        label="LLM Provider",
                        value="Hugging Face",
                    )
                    api_key = gr.Textbox(label="API Key", type="password", visible=False)
                    model = gr.Dropdown(label="Model", choices=[], value="")
                    generate_text_button = gr.Button("Generate Prompt with LLM")
                    # The LLM result is displayed in the "LLM Generated Text"
                    # box in the middle column (`text_output` above).

        # Populate the model dropdown based on the selected provider
        def update_model_choices(provider):
            provider_models = {
                "Hugging Face": ["meta-llama/Meta-Llama-3.1-70B-Instruct", "another-model-hf"],
                "Groq": ["llama-3.1-70b-versatile", "mixtral-8x7b-32768", "gemma2-9b-it"],
                "SambaNova": ["Meta-Llama-3.1-70B-Instruct", "Meta-Llama-3.1-405B-Instruct", "Meta-Llama-3.1-8B-Instruct"],
            }
            models = provider_models.get(provider, [])
            # gr.update returns a property patch that Gradio applies to the output component
            return gr.update(choices=models, value=models[0] if models else "")

        def update_api_key_visibility(provider):
            return gr.update(visible=False)  # No API key required for the selected providers

        llm_provider.change(update_model_choices, inputs=[llm_provider], outputs=[model])
        llm_provider.change(update_api_key_visibility, inputs=[llm_provider], outputs=[api_key])

        # Generate the base prompt with a fresh random seed
        def generate_prompt(prompt_type, custom_input):
            dynamic_seed = random.randint(0, 1000000)
            return llm_node.generate_prompt(dynamic_seed, prompt_type, custom_input)

        generate_button.click(
            generate_prompt,
            inputs=[prompt_type, custom],
            outputs=[output],
        )

        # Expand/compress the generated prompt with the chosen LLM
        def generate_text_with_llm(output_prompt, long_talk, compress, compression_level,
                                   custom_base_prompt, provider, api_key, model_selected):
            global selected_prompt_type
            result = llm_node.generate(
                input_text=output_prompt,
                long_talk=long_talk,
                compress=compress,
                compression_level=compression_level,
                prompt_type=selected_prompt_type,
                custom_base_prompt=custom_base_prompt,
                provider=provider,
                api_key=api_key,
                model=model_selected,
            )
            # Reset to the default prompt type after each LLM generation
            selected_prompt_type = "Long"
            return result

        generate_text_button.click(
            generate_text_with_llm,
            inputs=[output, long_talk, compress, compression_level, custom_base_prompt, llm_provider, api_key, model],
            outputs=[text_output],
            api_name="generate_text",
        )

    return demo

if __name__ == "__main__":
    demo = create_interface()
    demo.launch()
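
# Running locally (assuming gradio is installed and llm_inference.py sits next
# to this file): `python app.py`, then open the printed local URL. Passing
# share=True to demo.launch() would additionally create a temporary public
# link; on Hugging Face Spaces, the plain launch() above is all that's needed.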