|
import html
import re

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
|
|
|
|
# Load the chatbot model once, at import time. If loading fails for any
# reason, report the error and leave ``chat_pipeline`` set to None so the UI
# can still start; ai_vote() checks for None and returns an error message.
try:
    tokenizer = AutoTokenizer.from_pretrained("satvikag/chatbot")
    model = AutoModelForCausalLM.from_pretrained("satvikag/chatbot")
    chat_pipeline = pipeline(task="text-generation", model=model, tokenizer=tokenizer)
except Exception as load_error:
    print("Error initializing model: " + str(load_error))
    chat_pipeline = None
|
|
|
def _first_mentioned_choice(text, patterns):
    """Return the choice whose pattern matches earliest in *text*, or None.

    Ties at the same offset go to the longer choice, so that e.g.
    "New York City" beats "New York".
    """
    hits = []
    for choice, pattern in patterns.items():
        match = pattern.search(text)
        if match is not None:
            hits.append((match.start(), -len(choice), choice))
    return min(hits)[2] if hits else None


def ai_vote(poll_title, choices, num_ais, generator=None):
    """Ask the chatbot ``num_ais`` times to vote on a poll and tally the votes.

    Each "AI" is one sampled generation. A generation counts as a vote for
    the choice it mentions first. Matching is case-insensitive and whole-word,
    and it tolerates a plural "s"/"es" suffix. Generations that mention no
    choice are not counted.

    Args:
        poll_title: Poll question shown to the model and in the results.
        choices: Candidate answers (strings). Blank entries are ignored.
        num_ais: Number of generations (votes) to request. Coerced with
            ``int()`` because Gradio sliders may deliver floats.
        generator: Optional callable using the transformers text-generation
            pipeline calling convention. Defaults to the module-level
            ``chat_pipeline``.

    Returns:
        ``(html, explanations)``. ``html`` is an HTML summary in which all
        user-supplied and model-generated text is escaped. ``explanations``
        is a list of ``(choice, explanation)`` tuples. On failure returns
        ``(error_message, "")``.
    """
    runner = generator if generator is not None else chat_pipeline
    if runner is None:
        return "Error: Model not initialized.", ""

    # An empty/blank choice would match every response and win every vote.
    choices = [choice for choice in choices if choice.strip()]
    if not choices:
        return "Error: Please provide at least one choice.", ""

    results = {choice: 0 for choice in choices}
    explanations = []

    # Whole-word matching, so "cat" does not match inside "concatenated".
    # Lookarounds are used instead of \b so that choices with symbols at the
    # edges (e.g. "C++") still match.
    patterns = {
        choice: re.compile(
            r"(?<!\w)" + re.escape(choice) + r"(?:e?s)?(?!\w)", re.IGNORECASE
        )
        for choice in choices
    }

    # The prompt is the same for every AI, so build it once.
    input_text = (
        f"Poll Title: {poll_title}\n"
        f"Choices: {', '.join(choices)}\n"
        "Choose the best option and explain why."
    )

    for _ in range(int(num_ais)):
        try:
            generated = runner(
                input_text,
                # Count only generated tokens, not prompt tokens.
                max_new_tokens=150,
                num_return_sequences=1,
                # Sample, so the N "AIs" do not all produce the same answer.
                do_sample=True,
                # Without this the prompt (which lists every choice) is
                # echoed back, and the first listed choice always "wins".
                return_full_text=False,
            )[0]["generated_text"]
        except Exception as e:
            return f"Error: {str(e)}", ""

        response = generated.strip()
        choice = _first_mentioned_choice(response, patterns)
        if choice is None:
            continue  # The model did not name any choice; no vote.
        results[choice] += 1
        # The first line is treated as the pick and the rest as the reasoning.
        # A single-line answer is kept whole.
        explanations.append((choice, response.split("\n", 1)[-1].strip()))

    escape = html.escape
    parts = [f"<h2>{escape(str(poll_title))}</h2>", "<ul>"]
    parts.extend(
        f"<li><strong>{escape(choice)}</strong>: {votes} votes</li>"
        for choice, votes in results.items()
    )
    parts.append("</ul>")
    parts.append("<h3>AI Explanations:</h3><ul>")
    parts.extend(
        f"<li><strong>{escape(choice)}:</strong> {escape(explanation)}</li>"
        for choice, explanation in explanations
    )
    parts.append("</ul>")
    return "".join(parts), explanations
|
|
|
def gradio_interface(title, choices, num_ais):
    """Gradio callback: parse the comma-separated choices and run ai_vote().

    Args:
        title: Poll title from the first textbox.
        choices: Comma-separated choices from the second textbox. Blank and
            duplicate entries are dropped; the first occurrence is kept.
        num_ais: Number of AI voters from the slider.

    Returns:
        ``(html, raw_text)`` for the HTML and Textbox outputs. On success,
        ``raw_text`` holds one ``"choice: explanation"`` line per vote. On
        error, ``raw_text`` is an empty string.
    """
    try:
        # dict.fromkeys removes duplicates while keeping the user's order.
        parsed = list(dict.fromkeys(
            part.strip() for part in (choices or "").split(",") if part.strip()
        ))
        if not parsed:
            # Do not spend model calls on a poll with nothing to vote for.
            return "Error: Please provide at least one choice.", ""
        styled_results, explanations = ai_vote(title, parsed, num_ais)
    except Exception as e:
        return f"Error: {str(e)}", ""

    # ai_vote returns a list of (choice, explanation) tuples on success and a
    # string on error. The Textbox needs plain text, not a Python repr.
    if isinstance(explanations, str):
        return styled_results, explanations
    raw_text = "\n".join(
        f"{choice}: {explanation}" for choice, explanation in explanations
    )
    return styled_results, raw_text
|
|
|
|
|
# Web UI: the poll title, the comma-separated choices and the number of AI
# voters go into gradio_interface. Its two return values feed the HTML
# results panel and the raw-explanations textbox, in that order.
interface = gr.Interface(
    gradio_interface,
    [
        gr.Textbox(label="Poll Title"),
        gr.Textbox(label="Choices (comma-separated)"),
        gr.Slider(minimum=1, maximum=10, step=1, label="Number of AIs"),
    ],
    [
        gr.HTML(label="Poll Results"),
        gr.Textbox(label="Raw AI Explanations"),
    ],
)
|
|
|
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed directly,
    # not when it is imported.
    interface.launch()