# Spaces: Runtime error
import gradio as gr | |
from llm_inference import LLMInferenceNode | |
import random | |
# HTML banner shown at the top of the app (passed to gr.HTML in create_interface).
# A centered <div> wraps the links; the previous markup nested a <p> inside a
# <p> (invalid HTML, leaving an orphaned closing tag) and used obsolete <center>.
title = """<h1 align="center">Random Prompt Generator</h1>
<div style="text-align: center;">
<a href="https://x.com/gokayfem" target="_blank">[X gokaygokay]</a>
<a href="https://github.com/gokayfem" target="_blank">[Github gokayfem]</a>
<a href="https://github.com/dagthomas/comfyui_dagthomas" target="_blank">[comfyui_dagthomas]</a>
<p align="center">Generate random prompts using powerful LLMs from Hugging Face, Groq, and SambaNova.</p>
</div>
"""

# Module-level (process-wide) prompt type, updated by the "Prompt Type"
# dropdown handler in create_interface.
# NOTE(review): as a module global it is shared by every session of the app.
selected_prompt_type = "Long"  # Default value
# Choices for the "Prompt Type" dropdown (each entry listed once).
PROMPT_TYPES = ["Long", "Short", "Medium"]
DEFAULT_PROMPT_TYPE = "Long"

# Models offered for each LLM provider; the first model is preselected.
# NOTE(review): "another-model-hf" looks like a placeholder id — confirm or replace.
PROVIDER_MODELS = {
    "Hugging Face": ["meta-llama/Meta-Llama-3.1-70B-Instruct", "another-model-hf"],
    "Groq": ["llama-3.1-70b-versatile", "mixtral-8x7b-32768", "gemma2-9b-it"],
    "SambaNova": [
        "Meta-Llama-3.1-70B-Instruct",
        "Meta-Llama-3.1-405B-Instruct",
        "Meta-Llama-3.1-8B-Instruct",
    ],
}
DEFAULT_PROVIDER = "Hugging Face"


def create_interface():
    """Build the Random Prompt Generator Gradio app.

    The UI has three columns: basic settings (custom input and prompt type),
    the "Generate Prompt" button with its output, and LLM settings whose
    button passes the generated prompt to ``LLMInferenceNode.generate``.

    Returns:
        The assembled ``gr.Blocks`` app; the caller launches it.
    """
    llm_node = LLMInferenceNode()
    default_models = PROVIDER_MODELS[DEFAULT_PROVIDER]

    with gr.Blocks(theme='bethecloud/storj_theme') as demo:
        gr.HTML(title)

        # Per-session prompt type. Unlike a module global, gr.State is private
        # to each browser session, and State inputs are not exposed as
        # parameters of the "generate_text" API endpoint.
        prompt_type_state = gr.State(DEFAULT_PROMPT_TYPE)

        # NOTE(review): component nesting below was reconstructed because the
        # original indentation was lost; verify against the deployed layout.
        with gr.Row():
            with gr.Column(scale=2):
                with gr.Accordion("Basic Settings"):
                    custom = gr.Textbox(label="Custom Input Prompt (optional)")
                with gr.Accordion("Prompt Generation Options", open=False):
                    prompt_type = gr.Dropdown(
                        choices=PROMPT_TYPES,
                        label="Prompt Type",
                        value=DEFAULT_PROMPT_TYPE,
                        interactive=True,
                    )

            with gr.Column(scale=2):
                generate_button = gr.Button("Generate Prompt")
                with gr.Accordion("Generated Prompt", open=True):
                    output = gr.Textbox(label="Generated Prompt", lines=4, show_copy_button=True)

            with gr.Column(scale=2):
                with gr.Accordion("LLM Prompt Generation", open=False):
                    long_talk = gr.Checkbox(label="Long Talk", value=True)
                    compress = gr.Checkbox(label="Compress", value=True)
                    compression_level = gr.Dropdown(
                        choices=["soft", "medium", "hard"],
                        label="Compression Level",
                        value="hard",
                    )
                    custom_base_prompt = gr.Textbox(label="Custom Base Prompt", lines=5)
                    # LLM Provider Selection
                    llm_provider = gr.Dropdown(
                        choices=list(PROVIDER_MODELS),
                        label="LLM Provider",
                        value=DEFAULT_PROVIDER,
                    )
                    api_key = gr.Textbox(label="API Key", type="password", visible=False)
                    # Pre-populate with the default provider's models so a request
                    # made before touching the provider dropdown has a model name.
                    model = gr.Dropdown(
                        label="Model", choices=default_models, value=default_models[0]
                    )
                generate_text_button = gr.Button("Generate Prompt with LLM")
                text_output = gr.Textbox(label="Generated Text", lines=10, show_copy_button=True)

        def update_prompt_type(value):
            """Store the newly selected prompt type in this session's state."""
            global selected_prompt_type
            # Mirrored into the module global for backward compatibility only;
            # generation reads the per-session state instead.
            selected_prompt_type = value
            print(f"Updated prompt type: {selected_prompt_type}")
            return value

        # Write into the session state rather than back into the dropdown itself.
        prompt_type.change(update_prompt_type, inputs=[prompt_type], outputs=[prompt_type_state])

        def update_model_choices(provider):
            """Return a dropdown update listing ``provider``'s models, first one selected."""
            models = PROVIDER_MODELS.get(provider, [])
            # gr.update works on Gradio 3.x and 4.x; gr.Dropdown.update was removed in 4.0.
            return gr.update(choices=models, value=models[0] if models else "")

        def update_api_key_visibility(provider):
            """Keep the API key box hidden (no API key required for the listed providers)."""
            return gr.update(visible=False)

        llm_provider.change(update_model_choices, inputs=[llm_provider], outputs=[model])
        llm_provider.change(update_api_key_visibility, inputs=[llm_provider], outputs=[api_key])

        def generate_prompt(prompt_type, custom_input):
            """Generate a prompt, drawing a fresh random seed on every call."""
            dynamic_seed = random.randint(0, 1000000)
            return llm_node.generate_prompt(dynamic_seed, prompt_type, custom_input)

        generate_button.click(generate_prompt, inputs=[prompt_type, custom], outputs=[output])

        def generate_text_with_llm(output_prompt, long_talk, compress, compression_level,
                                   custom_base_prompt, provider, api_key, model_selected,
                                   session_prompt_type=DEFAULT_PROMPT_TYPE):
            """Forward the generated prompt and LLM settings to ``llm_node.generate``.

            ``session_prompt_type`` comes from the per-session gr.State, so the
            value matches what the dropdown shows and is never reset behind the
            user's back.
            """
            return llm_node.generate(
                input_text=output_prompt,
                long_talk=long_talk,
                compress=compress,
                compression_level=compression_level,
                prompt_type=session_prompt_type,
                custom_base_prompt=custom_base_prompt,
                provider=provider,
                api_key=api_key,
                model=model_selected,
            )

        generate_text_button.click(
            generate_text_with_llm,
            inputs=[output, long_talk, compress, compression_level, custom_base_prompt,
                    llm_provider, api_key, model, prompt_type_state],
            outputs=[text_output],
            api_name="generate_text",
        )

    return demo
if __name__ == "__main__":
    # Script entry point: assemble the Blocks app, keep a module-level handle
    # to it under the name `demo`, then start serving it.
    demo = create_interface()
    demo.launch()