import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
from datasets import load_dataset
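
# Load the Palmyra-small causal LM and its tokenizer, placing the model on
# GPU when one is available.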
model_name = "Writer/palmyra-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

def get_movie_info(movie_title):
    # Load the IMDb dataset. Note: the stock "imdb" dataset on the Hub holds
    # review text and sentiment labels only; it has no title/year/genre
    # fields, so we search the review text itself for mentions of the movie.
    imdb = load_dataset("imdb", split="train")

    # Keep only the reviews that mention the requested title
    results = imdb.filter(lambda x: movie_title.lower() in x["text"].lower())

    # Check if any results are found
    if len(results) > 0:
        review = results[0]
        return f"Sample review mentioning the movie: {review['text'][:200]}"
    else:
        return "Movie not found"

def generate_response(prompt):
    # Retrieve movie info first so it can be placed inside the user turn,
    # before the "ASSISTANT:" marker, rather than appended after it
    movie_info = get_movie_info(prompt)

    input_text_template = (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions. "
        f"USER: {prompt} "
        f"Movie Info: {movie_info} "
        "ASSISTANT:"
    )

    model_inputs = tokenizer(input_text_template, return_tensors="pt").to(device)

    # Sampling configuration: top-k filtering with a moderately low temperature.
    # max_new_tokens caps only the generated continuation, so a long prompt
    # (which here includes the retrieved review snippet) cannot exhaust the budget.
    gen_conf = {
        "top_k": 20,
        "max_new_tokens": 200,
        "temperature": 0.6,
        "do_sample": True,
        "eos_token_id": tokenizer.eos_token_id,
    }

    output = model.generate(**model_inputs, **gen_conf)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text
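
# Minimal Gradio wiring. Note that live=True re-runs generation on every
# keystroke, which is expensive for an LLM; remove it to generate only on submit.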
iface = gr.Interface(fn=generate_response, inputs="text", outputs="text", live=True)
iface.launch()