|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
import gradio as gr |
|
import requests |
|
|
|
# Hugging Face model id used for chat-style text generation.
model_name = "Writer/palmyra-small"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Prefer GPU when available; all tensors and the model live on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
|
|
|
def get_movie_info(movie_title):
    """Look up a movie on TMDB and return its basic details.

    Performs a title search, then fetches full details for the first hit.

    Args:
        movie_title: Free-text title to search for.

    Returns:
        On success, a dict with keys "title", "year", "genre", "tmdb_link".
        On failure, a dict with a single "error" key.
    """
    import os

    # Prefer an environment variable; fall back to the embedded key so
    # existing deployments keep working.
    # NOTE(review): this key is committed to source control and should be
    # rotated and removed.
    api_key = os.environ.get("TMDB_API_KEY", "20e959f0f28e6b3e3de49c50f358538a")
    search_url = "https://api.themoviedb.org/3/search/movie"

    params = {
        "api_key": api_key,
        "query": movie_title,
        "language": "en-US",
        "page": 1,
    }

    try:
        # Timeouts keep a stalled TMDB request from hanging the UI worker.
        search_response = requests.get(search_url, params=params, timeout=10)
        search_response.raise_for_status()
        search_data = search_response.json()

        if not search_data.get("results"):
            return {"error": "Movie not found"}

        movie_id = search_data["results"][0]["id"]

        details_url = f"https://api.themoviedb.org/3/movie/{movie_id}"
        details_params = {
            "api_key": api_key,
            "language": "en-US",
        }

        details_response = requests.get(details_url, params=details_params, timeout=10)
        details_response.raise_for_status()
        details_data = details_response.json()

        title = details_data.get("title", "Unknown Title")
        # TMDB returns "" (key present) for unknown release dates, so a
        # plain .get() default never fires — and slicing the default with
        # [:4] would have produced "Unkn". Guard explicitly instead.
        release_date = details_data.get("release_date") or ""
        year = release_date[:4] if release_date else "Unknown Year"
        genre = ", ".join(g["name"] for g in details_data.get("genres", []))
        tmdb_link = f"https://www.themoviedb.org/movie/{movie_id}"

        return {
            "title": title,
            "year": year,
            "genre": genre,
            "tmdb_link": tmdb_link,
        }

    except Exception as e:
        # Broad catch is deliberate: any failure (network, HTTP, JSON,
        # schema) degrades to an error payload the chat UI can display.
        return {"error": f"Error: {e}"}
|
|
|
def process_image(movie_info):
    """Return a poster/image URL for the given movie.

    Currently a stub: *movie_info* is accepted only to keep the call-site
    interface stable and is otherwise ignored; a fixed placeholder image
    URL is always returned.
    """
    placeholder_url = "https://via.placeholder.com/150"
    return placeholder_url
|
|
|
def generate_response(prompt):
    """Generate a chat reply about the movie named in *prompt*.

    Looks the title up on TMDB, appends the movie facts to a chat-style
    prompt template, and samples a short continuation from the model.

    Args:
        prompt: The user's message, treated as a movie title to search.

    Returns:
        A ``(text, image_url)`` tuple: the formatted reply string and a
        poster image URL — or an error string and ``None`` when the TMDB
        lookup fails.
    """
    input_text_template = (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions. "
        f"USER: {prompt} "
        "ASSISTANT:"
    )

    movie_info = get_movie_info(prompt)

    if "error" in movie_info:
        return f"Error: {movie_info['error']}", None

    image_url = process_image(movie_info)

    input_text_template += (
        f" Movie Info: Title: {movie_info['title']}, Year: {movie_info['year']}, "
        f"Genre: {movie_info['genre']}\nFind more info here: {movie_info['tmdb_link']}"
    )

    model_inputs = tokenizer(input_text_template, return_tensors="pt").to(device)

    gen_conf = {
        "top_k": 20,
        # BUG FIX: the original used max_length=20, which is smaller than
        # the tokenized prompt (template + movie info) and makes generate()
        # fail or emit nothing new. max_new_tokens bounds only the
        # continuation, independent of prompt length.
        "max_new_tokens": 64,
        "temperature": 0.6,
        "do_sample": True,
        "eos_token_id": tokenizer.eos_token_id,
    }

    output = model.generate(**model_inputs, **gen_conf)

    # Decode only the newly generated tokens so the echoed prompt does not
    # consume the 100-character reply window below.
    prompt_len = model_inputs["input_ids"].shape[-1]
    generated_text = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

    return (
        f"Movie Info: {movie_info['title']}, {movie_info['year']}, {movie_info['genre']}\n\n"
        f"Generated Response: {generated_text[:100]}",
        image_url,
    )
|
|
|
|
|
def chat_function(message, history):
    """Gradio ChatInterface callback: answer *message* with movie info.

    Args:
        message: The user's latest chat message (a movie title).
        history: Prior turns, managed by Gradio. Deliberately not mutated:
            ChatInterface rebuilds the history itself, so the original
            ``history.append`` calls had no effect beyond duplicating the
            user message locally.

    Returns:
        A single markdown string — ``gr.ChatInterface`` expects its callback
        to return the bot reply as a string, not a ``(text, image)`` tuple.
    """
    response_text, response_image = generate_response(message)
    if response_image is None:
        # Lookup failed; there is no poster to show.
        return response_text
    # Embed the poster as markdown so it renders inline in the chat bubble.
    return f"{response_text}\n\n![poster]({response_image})"
|
|
|
|
|
# Wire the chat callback into a Gradio chat UI and launch it; share=True
# additionally exposes a temporary public URL.
demo = gr.ChatInterface(chat_function)
demo.launch(share=True)