File size: 3,356 Bytes
d903275 0a2c880 2ff483e 727c299 2ff483e d903275 2ff483e 7b6b4a2 727c299 7b6b4a2 727c299 7b6b4a2 727c299 97b7fcf 8139f95 5bd8a1a 8139f95 727c299 8139f95 727c299 5bd8a1a 727c299 5bd8a1a 7b6b4a2 0a2c880 3b25749 727c299 5bd8a1a 7b6b4a2 0a2c880 5bf06ec 0a2c880 8139f95 0a2c880 8139f95 5bd8a1a 8139f95 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import requests
# Model and tokenizer are loaded once at import time and shared by all
# request handlers below.
model_name = "Writer/palmyra-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Prefer GPU when available; inputs are moved to this same device before generate().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
def get_movie_info(movie_title):
    """Look up *movie_title* on TMDb and return a ``(summary, image_url)`` pair.

    Every code path returns a 2-tuple so callers can safely unpack:
    on success, a one-line summary plus an empty image URL (poster lookup
    is currently disabled); on failure, an error message and ``""``.
    """
    # SECURITY NOTE: API key is hard-coded; move it to an environment
    # variable before deploying or sharing this script.
    api_key = "20e959f0f28e6b3e3de49c50f358538a"
    search_url = "https://api.themoviedb.org/3/search/movie"
    # Search query parameters for TMDb
    params = {
        "api_key": api_key,
        "query": movie_title,
        "language": "en-US",
        "page": 1,
    }
    try:
        # timeout= keeps a stalled TMDb request from hanging the chat UI
        search_response = requests.get(search_url, params=params, timeout=10)
        search_data = search_response.json()
        # Proceed only if the search produced at least one hit
        if search_data.get("results"):
            movie_id = search_data["results"][0]["id"]
            # Fetch detailed information using the movie ID
            details_url = f"https://api.themoviedb.org/3/movie/{movie_id}"
            details_params = {
                "api_key": api_key,
                "language": "en-US",
            }
            details_response = requests.get(
                details_url, params=details_params, timeout=10
            )
            details_data = details_response.json()
            # Extract relevant fields, tolerating missing keys
            title = details_data.get("title", "Unknown Title")
            year = details_data.get("release_date", "Unknown Year")[:4]
            genre = ", ".join(g["name"] for g in details_data.get("genres", []))
            tmdb_link = f"https://www.themoviedb.org/movie/{movie_id}"
            # poster_path = details_data.get("poster_path")
            # Convert poster_path to a complete image URL
            # image_url = f"https://image.tmdb.org/t/p/w500{poster_path}" if poster_path else ""
            # BUG FIX: the success path previously returned a bare string while
            # the failure paths returned 2-tuples, so the caller's
            # `movie_info, image_url = get_movie_info(...)` raised ValueError
            # on every successful lookup. Return a 2-tuple on all paths.
            return (
                f"Title: {title}, Year: {year}, Genre: {genre}\nFind more info here: {tmdb_link}",
                "",
            )
        else:
            return "Movie not found", ""
    except Exception as e:
        # Best-effort enrichment: surface the error text instead of crashing the chat
        return f"Error: {e}", ""
def generate_response(prompt):
    """Generate an assistant reply for *prompt*, enriched with TMDb movie info.

    Builds a chat-style prompt, appends movie metadata fetched via
    ``get_movie_info``, samples a continuation from the model, and returns
    the movie info plus the decoded text as one display string.
    """
    input_text_template = (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions. "
        f"USER: {prompt} "
        "ASSISTANT:"
    )
    # get_movie_info may return either a bare string or an (info, image_url)
    # tuple depending on the code path; normalize so both shapes work here.
    info_result = get_movie_info(prompt)
    movie_info = info_result[0] if isinstance(info_result, tuple) else info_result
    # Concatenate the movie info with the input template
    input_text_template += f" Movie Info: {movie_info}"
    model_inputs = tokenizer(input_text_template, return_tensors="pt").to(device)
    gen_conf = {
        "top_k": 20,
        # BUG FIX: max_length counts prompt tokens too; the chat template above
        # is already longer than 20 tokens, so generation was truncated to
        # (effectively) nothing. max_new_tokens bounds only generated tokens.
        "max_new_tokens": 64,
        "temperature": 0.6,
        "do_sample": True,
        "eos_token_id": tokenizer.eos_token_id,
    }
    output = model.generate(**model_inputs, **gen_conf)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return f"Movie Info:\n{movie_info}\n\nGenerated Response:\n{generated_text}"
# Define chat function for gr.ChatInterface
def chat_function(message, history):
    """Adapter between gr.ChatInterface and generate_response.

    gr.ChatInterface manages the conversation history itself, so the
    handler must only return the reply; the previous in-place
    ``history.append`` was redundant and could duplicate displayed turns.
    """
    return generate_response(message)
# Create Gradio Chat Interface and serve it (blocks until the app stops).
# FIX: removed a stray trailing "|" character (extraction residue) that made
# the launch() line a syntax error.
chat_interface = gr.ChatInterface(chat_function)
chat_interface.launch()