# writerUI / app.py
# Hugging Face Space by ctn8176 — commit 1592e82 ("Update app.py"), 4.19 kB
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import requests
# Hugging Face model id for Writer's small Palmyra causal language model.
model_name = "Writer/palmyra-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Prefer GPU when available; inputs must later be moved to this same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
def get_movie_info(movie_title):
    """Look up *movie_title* on TMDb and return basic details.

    Parameters
    ----------
    movie_title : str
        Title to search for (first search hit is used).

    Returns
    -------
    dict
        On success: keys ``title``, ``year``, ``genre``, ``tmdb_link``.
        On failure: a single ``error`` key with a human-readable message.
    """
    # SECURITY NOTE(review): hard-coded API key is exposed to anyone who can
    # read this file — move it to an environment variable / Space secret.
    api_key = "20e959f0f28e6b3e3de49c50f358538a"
    search_url = "https://api.themoviedb.org/3/search/movie"
    params = {
        "api_key": api_key,
        "query": movie_title,
        "language": "en-US",
        "page": 1,
    }
    try:
        # timeout keeps the Gradio worker from hanging forever on a stalled request
        search_response = requests.get(search_url, params=params, timeout=10)
        search_data = search_response.json()

        # No hits -> report "not found" rather than raising on index [0]
        if not search_data.get("results"):
            return {"error": "Movie not found"}

        movie_id = search_data["results"][0]["id"]

        # Fetch detailed information using the movie ID
        details_url = f"https://api.themoviedb.org/3/movie/{movie_id}"
        details_params = {
            "api_key": api_key,
            "language": "en-US",
        }
        details_response = requests.get(details_url, params=details_params, timeout=10)
        details_data = details_response.json()

        title = details_data.get("title", "Unknown Title")
        # Bug fix: the original sliced the fallback string too, yielding
        # "Unkn"; only take the 4-char year prefix of a real release date.
        release_date = details_data.get("release_date") or ""
        year = release_date[:4] if release_date else "Unknown Year"
        genre = ", ".join(g["name"] for g in details_data.get("genres", []))
        tmdb_link = f"https://www.themoviedb.org/movie/{movie_id}"

        return {
            "title": title,
            "year": year,
            "genre": genre,
            "tmdb_link": tmdb_link,
        }
    except Exception as e:
        # Best-effort API wrapper: surface any failure as an error payload
        # instead of crashing the chat callback.
        return {"error": f"Error: {e}"}
def process_image(movie_info):
    """Produce an image URL for *movie_info*.

    Currently a stub: the TMDb poster lookup is not wired up yet, so a
    fixed placeholder image URL is returned regardless of the input.
    """
    placeholder_url = "https://via.placeholder.com/150"
    return placeholder_url
def generate_response(prompt):
    """Answer *prompt* with the LM, enriched with TMDb movie info.

    Parameters
    ----------
    prompt : str
        User message; also used as the TMDb search query.

    Returns
    -------
    tuple[str, str | None]
        ``(response_text, image_url)``; ``image_url`` is ``None`` when the
        movie lookup fails.
    """
    input_text_template = (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions. "
        f"USER: {prompt} "
        "ASSISTANT:"
    )
    # Enrich the response with movie metadata looked up from the prompt
    movie_info = get_movie_info(prompt)
    if "error" in movie_info:
        return f"Error: {movie_info['error']}", None
    # Image URL is produced separately (currently a placeholder)
    image_url = process_image(movie_info)
    # Append the movie info to the prompt so the model can use it
    input_text_template += (
        f" Movie Info: Title: {movie_info['title']}, Year: {movie_info['year']}, "
        f"Genre: {movie_info['genre']}\nFind more info here: {movie_info['tmdb_link']}"
    )
    model_inputs = tokenizer(input_text_template, return_tensors="pt").to(device)
    # Bug fix: the original used max_length=20, which is smaller than the
    # prompt itself, so generate() could not produce any new tokens.
    # max_new_tokens bounds only the generated continuation.
    gen_conf = {
        "top_k": 20,
        "max_new_tokens": 100,
        "temperature": 0.6,
        "do_sample": True,
        "eos_token_id": tokenizer.eos_token_id,
    }
    output = model.generate(**model_inputs, **gen_conf)
    # Decode only the newly generated tokens — the original decoded the
    # whole sequence and then truncated, returning mostly the echoed prompt.
    prompt_len = model_inputs["input_ids"].shape[1]
    generated_text = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
    return (
        f"Movie Info: {movie_info['title']}, {movie_info['year']}, {movie_info['genre']}\n\n"
        f"Generated Response: {generated_text[:100]}",  # Truncate to 100 characters
        image_url,
    )
# Define chat function for gr.ChatInterface
def chat_function(message, history):
response_text, response_image = generate_response(message)
history.append([message, response_text])
history.append([message, response_image]) # Separate history for image
return response_text, response_image
# Create Gradio Chat Interface
chat_interface = gr.ChatInterface(chat_function)
# share=True additionally exposes a public tunnel URL alongside the local server
chat_interface.launch(share=True)