# Captured from a Hugging Face Spaces page; the Space's status badge read "Runtime error".
import io
import random
import traceback

import gradio as gr
from PIL import Image

from llm_inference import LLMInferenceNode
# Header markup shown at the top of the page (passed to gr.HTML in create_interface).
# NOTE(review): the outer <p> wraps a <center> and another <p>, which HTML does not
# permit; browsers repair the nesting, so the header still renders.
title = (
    '<h1 align="center">SD 3.5 Prompt Generator</h1>\n'
    "<p><center>\n"
    '<a href="https://x.com/gokayfem" target="_blank">[X gokaygokay]</a>\n'
    '<a href="https://github.com/gokayfem" target="_blank">[Github gokayfem]</a>\n'
    '<p align="center">Generate random prompts using powerful LLMs from Hugging Face and SambaNova.</p>\n'
    "</center></p>\n"
)
# Prompt categories offered by the "Prompt Type" dropdown. "Random" is a meta-choice
# that gets resolved to one of the concrete types that follow it.
_PROMPT_TYPES = ["Random", "Long", "Short", "Medium", "OnlyObjects", "NoFigure", "Landscape", "Fantasy"]

# Single source of truth for the models each provider offers. It feeds the provider
# dropdown, the initial "Model" dropdown and the provider-change handler. The first
# model in each list is that provider's default selection.
_PROVIDER_MODELS = {
    "Hugging Face": [
        "Qwen/Qwen2.5-72B-Instruct",
        "meta-llama/Meta-Llama-3.1-70B-Instruct",
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "mistralai/Mistral-7B-Instruct-v0.3",
    ],
    "SambaNova": [
        "Meta-Llama-3.1-70B-Instruct",
        "Meta-Llama-3.1-405B-Instruct",
        "Meta-Llama-3.1-8B-Instruct",
    ],
}
_DEFAULT_PROVIDER = "Hugging Face"


def _pick_concrete_prompt_type(prompt_type):
    """Return *prompt_type* unchanged, or a random concrete type when it is "Random"."""
    if prompt_type != "Random":
        return prompt_type
    return random.choice([t for t in _PROMPT_TYPES if t != "Random"])


def create_interface():
    """Build the SD 3.5 prompt-generator Gradio app.

    The "Generate Prompt" button does two things in sequence. It first asks
    ``LLMInferenceNode.generate_prompt`` for a prompt, built from the custom
    input or the selected prompt type. It then passes that prompt to
    ``LLMInferenceNode.generate`` together with the generation options. The
    result is shown in the "LLM Generated Text" box. The same handler is also
    exposed as the ``generate_random_prompt_with_llm`` API endpoint.

    Returns:
        gr.Blocks: the assembled, not yet launched, app.
    """
    llm_node = LLMInferenceNode()
    # NOTE(review): the captured source had lost all indentation; the container
    # nesting below is reconstructed from syntax — confirm against the live layout.
    with gr.Blocks(theme='bethecloud/storj_theme') as demo:
        gr.HTML(title)
        with gr.Row():
            with gr.Column(scale=2):
                custom = gr.Textbox(label="Custom Input Prompt (optional)", lines=3)
                prompt_type = gr.Dropdown(
                    choices=list(_PROMPT_TYPES),
                    label="Prompt Type",
                    value="Random",
                    interactive=True
                )
                # Hidden per-session value holding the concrete type resolved for the
                # current dropdown selection.
                prompt_type_state = gr.State("Random")

                def update_prompt_type(value):
                    """Resolve the dropdown selection to a concrete type for the state."""
                    resolved = _pick_concrete_prompt_type(value)
                    if value == "Random":
                        print(f"Random prompt type selected: {resolved}")
                    else:
                        print(f"Updated prompt type: {value}")
                    return resolved

                # Only the State is an output. Writing the dropdown's own value back
                # from its .change handler was redundant, and Gradio fires .change on
                # programmatic updates as well as user input.
                prompt_type.change(update_prompt_type, inputs=[prompt_type], outputs=[prompt_type_state])
            with gr.Column(scale=2):
                with gr.Accordion("LLM Prompt Generation", open=False):
                    long_talk = gr.Checkbox(label="Long Talk", value=True)
                    compress = gr.Checkbox(label="Compress", value=True)
                    compression_level = gr.Dropdown(
                        choices=["soft", "medium", "hard"],
                        label="Compression Level",
                        value="hard"
                    )
                    custom_base_prompt = gr.Textbox(label="Custom Base Prompt", lines=5)
                    # LLM provider selection; choices come from the shared model table.
                    llm_provider = gr.Dropdown(
                        choices=list(_PROVIDER_MODELS),
                        label="LLM Provider",
                        value=_DEFAULT_PROVIDER
                    )
                    api_key = gr.Textbox(label="API Key", type="password", visible=False)
                    default_models = _PROVIDER_MODELS[_DEFAULT_PROVIDER]
                    model = gr.Dropdown(label="Model", choices=list(default_models), value=default_models[0])
                with gr.Row():
                    # Single button that generates the prompt and then the text.
                    generate_button = gr.Button("Generate Prompt")
                with gr.Row():
                    text_output = gr.Textbox(label="LLM Generated Text", lines=10, show_copy_button=True)

        def update_model_choices(provider):
            """Return a Model dropdown listing *provider*'s models, preselecting the first."""
            models = list(_PROVIDER_MODELS.get(provider, []))
            # None (not "") clears the selection for an unknown provider, because ""
            # is not one of the dropdown's choices.
            return gr.Dropdown(choices=models, value=models[0] if models else None)

        def update_api_key_visibility(provider):
            """Keep the API key box hidden; neither listed provider asks for a key here."""
            return gr.update(visible=False)

        llm_provider.change(
            update_model_choices,
            inputs=[llm_provider],
            outputs=[model]
        )
        llm_provider.change(
            update_api_key_visibility,
            inputs=[llm_provider],
            outputs=[api_key]
        )

        def generate_random_prompt_with_llm(custom_input, prompt_type, long_talk, compress, compression_level, custom_base_prompt, provider, api_key, model_selected, prompt_type_state):
            """Generate a prompt, then LLM text from it; return the text or an error message.

            ``prompt_type_state`` is accepted so the wired inputs still line up, but it
            is not read. The prompt type is re-resolved on every click, so "Random"
            yields a fresh concrete type per generation.
            """
            try:
                # Step 1: generate the prompt.
                dynamic_seed = random.randint(0, 1000000)
                if prompt_type == "Random":
                    prompt_type = _pick_concrete_prompt_type(prompt_type)
                    print(f"Random prompt type selected: {prompt_type}")
                if custom_input and custom_input.strip():
                    prompt = llm_node.generate_prompt(dynamic_seed, prompt_type, custom_input)
                    print("Using Custom Input Prompt.")
                else:
                    prompt = llm_node.generate_prompt(dynamic_seed, prompt_type, f"Create a random prompt based on the '{prompt_type}' type.")
                    print(f"No Custom Input Prompt provided. Generated prompt based on prompt_type: {prompt_type}")
                print(f"Generated Prompt: {prompt}")
                # Step 2: generate text from that prompt. This UI never requests
                # poster-style output.
                result = llm_node.generate(
                    input_text=prompt,
                    long_talk=long_talk,
                    compress=compress,
                    compression_level=compression_level,
                    poster=False,
                    prompt_type=prompt_type,  # the resolved (never "Random") type
                    custom_base_prompt=custom_base_prompt,
                    provider=provider,
                    api_key=api_key,
                    model=model_selected,
                )
                print(f"Generated Text: {result}")
                return result
            except Exception as e:
                # UI boundary: write the full traceback to the logs for debugging,
                # and show a readable message in the output box instead of failing.
                print(f"An error occurred: {e}")
                traceback.print_exc()
                return f"Error occurred while processing the request: {str(e)}"

        # Connect the unified handler to the single button.
        generate_button.click(
            generate_random_prompt_with_llm,
            inputs=[custom, prompt_type, long_talk, compress, compression_level, custom_base_prompt, llm_provider, api_key, model, prompt_type_state],
            outputs=[text_output],
            api_name="generate_random_prompt_with_llm"
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks UI once, then hand it to Gradio's server.
    # The module-level name `demo` is kept so tooling that looks it up still finds it.
    demo = create_interface()
    launch_kwargs = {"share": True}
    demo.launch(**launch_kwargs)