File size: 4,185 Bytes
d903275 0a2c880 2ff483e 727c299 2ff483e d903275 2ff483e 7b6b4a2 727c299 7b6b4a2 727c299 7b6b4a2 727c299 97b7fcf dc5958a 5bd8a1a dc5958a 727c299 ed2673b 727c299 ed2673b 5bd8a1a 727c299 ed2673b 7b6b4a2 0a2c880 3b25749 727c299 dc5958a 7b6b4a2 ed2673b 7b6b4a2 ed2673b 7b6b4a2 0a2c880 5bf06ec 0a2c880 b99a30a 1592e82 0a2c880 8139f95 ed2673b 5bd8a1a 8139f95 0cdb0a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import requests
# Checkpoint identifier; the tokenizer and the model weights are both loaded from it.
model_name = "Writer/palmyra-small"

# Run on CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the causal-LM weights, then move them to the selected device.
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to(device)
def get_movie_info(movie_title, timeout=10):
    """Look up a movie on TMDb and return a short summary of the first match.

    Two requests are made: a search for ``movie_title``, then a details
    request for the id of the first search result.

    Args:
        movie_title: Free-text title to search for.
        timeout: Seconds to wait on each HTTP request before giving up.

    Returns:
        On success, a dict with the keys ``"title"``, ``"year"``, ``"genre"``
        (comma-separated genre names) and ``"tmdb_link"``.
        On failure, a dict with the single key ``"error"``: ``"Movie not
        found"`` when there is nothing to search for or the search has no
        results, otherwise ``"Error: <details>"``.
    """
    # Function-scope import keeps this block self-contained.
    import os

    # An empty or whitespace-only title cannot match anything; skip the network call.
    if not movie_title or not str(movie_title).strip():
        return {"error": "Movie not found"}

    # NOTE(review): a credential is committed in source. The TMDB_API_KEY
    # environment variable takes precedence; the literal remains only as a
    # backward-compatible fallback and should be rotated/removed.
    api_key = os.environ.get("TMDB_API_KEY", "20e959f0f28e6b3e3de49c50f358538a")
    search_url = "https://api.themoviedb.org/3/search/movie"
    search_params = {
        "api_key": api_key,
        "query": movie_title,
        "language": "en-US",
        "page": 1,
    }

    try:
        # Without a timeout, requests can block indefinitely on a stalled connection.
        search_response = requests.get(search_url, params=search_params, timeout=timeout)
        # Report HTTP failures (bad key, rate limit, ...) as errors instead of
        # letting them masquerade as "Movie not found".
        search_response.raise_for_status()
        results = search_response.json().get("results")
        if not results:
            return {"error": "Movie not found"}

        movie_id = results[0]["id"]

        # Fetch detailed information using the movie ID.
        details_url = f"https://api.themoviedb.org/3/movie/{movie_id}"
        details_params = {
            "api_key": api_key,
            "language": "en-US",
        }
        details_response = requests.get(details_url, params=details_params, timeout=timeout)
        details_response.raise_for_status()
        details_data = details_response.json()

        # ``or`` fallbacks cover both a missing key and an explicit null/empty value.
        title = details_data.get("title") or "Unknown Title"
        release_date = details_data.get("release_date") or ""
        # Slice only a real date; slicing the placeholder would yield "Unkn".
        year = release_date[:4] or "Unknown Year"
        genre = ", ".join(entry["name"] for entry in details_data.get("genres") or [])
        tmdb_link = f"https://www.themoviedb.org/movie/{movie_id}"
    except (
        requests.RequestException,  # transport and HTTP status failures
        ValueError,  # response body is not valid JSON
        KeyError,  # payload is missing an expected field
        IndexError,
        TypeError,  # payload has an unexpected shape
        AttributeError,
    ) as e:
        return {"error": f"Error: {e}"}

    return {
        "title": title,
        "year": year,
        "genre": genre,
        "tmdb_link": tmdb_link,
    }
def process_image(movie_info):
    """Return the URL of the image to display for a movie.

    This is currently a stub: ``movie_info`` is not inspected and the same
    placeholder URL is returned for every call.
    """
    placeholder_url = "https://via.placeholder.com/150"
    return placeholder_url
def generate_response(prompt):
    """Build a movie-aware reply for ``prompt``.

    The prompt is used as a movie title for ``get_movie_info``; the resulting
    facts are placed in the model prompt and a short continuation is sampled.

    Args:
        prompt: The user's message, treated as the movie title to look up.

    Returns:
        A ``(text, image_url)`` tuple. When the lookup fails, ``text`` is
        ``"Error: <reason>"`` and ``image_url`` is ``None``. Otherwise
        ``text`` holds the movie facts followed by the generated reply
        (truncated to 100 characters) and ``image_url`` is whatever
        ``process_image`` returns.
    """
    # Look the movie up first so a failed lookup skips prompt construction entirely.
    movie_info = get_movie_info(prompt)
    if "error" in movie_info:
        return f"Error: {movie_info['error']}", None

    # Process the image separately.
    image_url = process_image(movie_info)

    # The movie facts go before the "ASSISTANT:" cue so that the prompt ends
    # on the cue and the model continues in the assistant's voice.
    input_text = (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions. "
        f"USER: {prompt} "
        f"Movie Info: Title: {movie_info['title']}, Year: {movie_info['year']}, "
        f"Genre: {movie_info['genre']}\nFind more info here: {movie_info['tmdb_link']} "
        "ASSISTANT:"
    )

    model_inputs = tokenizer(input_text, return_tensors="pt").to(device)
    prompt_length = model_inputs["input_ids"].shape[1]

    gen_conf = {
        "top_k": 20,
        # max_length would also count the prompt tokens, which already exceed
        # 20 here; max_new_tokens bounds only the generated continuation.
        "max_new_tokens": 20,
        "temperature": 0.6,
        "do_sample": True,
        "eos_token_id": tokenizer.eos_token_id,
    }
    output = model.generate(**model_inputs, **gen_conf)

    # The returned sequence starts with the prompt tokens; decode only what
    # was generated after them so the reply is not an echo of the prompt.
    new_tokens = output[0][prompt_length:]
    generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    return (
        f"Movie Info: {movie_info['title']}, {movie_info['year']}, {movie_info['genre']}\n\n"
        f"Generated Response: {generated_text[:100]}",  # Truncate to 100 characters
        image_url,
    )
# Define chat function for gr.ChatInterface
def chat_function(message, history):
    """Produce the chat reply for ``message``.

    Args:
        message: The latest user message.
        history: Previous turns supplied by the chat interface. It is
            accepted for signature compatibility and left untouched, since
            the interface records each turn itself.

    Returns:
        A single Markdown string: the response text, followed by an embedded
        image when ``generate_response`` supplies an image URL.
    """
    response_text, response_image = generate_response(message)

    # No image is available on the error path; reply with the text alone.
    if not response_image:
        return response_text

    # Embed the image as Markdown so the reply stays a single string.
    return f"{response_text}\n\n![Movie image]({response_image})"
# Create Gradio Chat Interface
chat_interface = gr.ChatInterface(chat_function)

# Start serving the interface; share=True is passed through to Gradio's launch().
chat_interface.launch(share=True)