import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load the model and tokenizer from the Hugging Face Hub.
model_name = "Samurai719214/gptneo-mythology-storyteller"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Use GPU if available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
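# Ensure the model is in evaluation mode (disables dropout) for inference.
# from_pretrained already loads models in eval mode, so this is a defensive no-op.
model.eval()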
# Conversation history, stored as (user input, generated text) pairs.
conversation_history = []
def generate_full_story(excerpt: str) -> str:
    """
    Given an incomplete story excerpt, generate the complete story,
    including the Parv, key event, section, and continuation.
    """
    if not excerpt.strip():
        return "Please enter a valid story excerpt."

    # Tokenize the user-provided excerpt, truncating anything beyond the
    # model's maximum context length, and move the tensors to the device.
    encoded_input = tokenizer(excerpt, return_tensors="pt", truncation=True)
    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}

    # Generate text with controlled sampling parameters. Inference needs no
    # gradients, so run generation under torch.no_grad().
    with torch.no_grad():
        output = model.generate(
            encoded_input["input_ids"],
            attention_mask=encoded_input["attention_mask"],
            max_new_tokens=200,
            do_sample=True,
            temperature=0.1,
            top_p=0.95,
            no_repeat_ngram_size=2,
            return_dict_in_generate=True,
            # GPT-Neo has no pad token; fall back to EOS to avoid the
            # generate() warning about a missing pad_token_id.
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode the generated sequence (prompt plus continuation).
    generated_text = tokenizer.decode(output.sequences[0], skip_special_tokens=True)

    # Record the exchange in the conversation history.
    conversation_history.append((excerpt, generated_text))
    return generated_text
def get_conversation_history():
    """Return the conversation history formatted for display."""
    if not conversation_history:
        return "No conversations yet."
    return "\n\n".join(f"**User:** {inp}\n**AI:** {out}" for inp, out in conversation_history)
def clear_conversation():
    """Clear the conversation history."""
    if not conversation_history:
        return "No conversations to clear."
    conversation_history.clear()
    return "Conversation history cleared!"
# Build the Gradio interface.
with gr.Blocks() as interface:
    gr.Markdown("# 🏺 Mythology Storyteller")
    gr.Markdown(
        "Enter a phrase from a chapter of your choice (please include the Parv, "
        "key event, and section for a better answer). The model will generate "
        "a summary of the corresponding chapter."
    )
    with gr.Row():
        user_input = gr.Textbox(
            lines=5,
            label="Incomplete story excerpt",
            placeholder="Enter an excerpt from the Mahabharata here...",
        )
        output_text = gr.Textbox(label="Chapter summary")

    generate_btn = gr.Button("Generate Story")
    generate_btn.click(fn=generate_full_story, inputs=user_input, outputs=output_text)

    with gr.Row():
        history_display = gr.Textbox(label="Conversation History", interactive=False)

    show_history_btn = gr.Button("Show Conversation History")
    show_history_btn.click(fn=get_conversation_history, outputs=history_display)

    clear_btn = gr.Button("Clear Conversation")
    clear_btn.click(fn=clear_conversation, outputs=history_display)
# Launch the Gradio app.
interface.launch()
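# Note: when running locally, interface.launch(share=True) additionally creates
# a temporary public link; on Hugging Face Spaces the plain launch() suffices.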