Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import time | |
import requests | |
from typing import Optional, Dict, Any, List | |
import json | |
import tempfile | |
from PIL import Image | |
from groq import Groq | |
from openai import OpenAI | |
import spaces | |
class VideoLLMInferenceNode:
    """Image/video analysis and LLM-backed video prompt generation.

    Image and video analysis are delegated to the captioner returned by
    ``app.get_vlm_captioner``. Prompt generation goes to SambaNova or Groq,
    through the SDK client built in ``__init__`` when an API key was found
    in the environment, or through a raw HTTP request otherwise.

    Every public method reports failures as a human-readable string rather
    than raising.
    """

    # Upper bound (seconds) for the raw-HTTP fallbacks so a stalled
    # connection cannot block the caller forever.
    _HTTP_TIMEOUT_SECONDS = 60

    # Returned whenever a provider answers successfully but without text.
    _NO_CONTENT = "No content returned"

    _SYSTEM_MESSAGE = (
        "You are a professional video prompt generator. Your task is to create detailed, technical, and creative video prompts based on user inputs.\n"
        "The prompts should be suitable for text-to-video AI models and include specific technical details that match the requested style, camera movement, pacing, and effects.\n"
        "Focus on creating high-quality, cohesive prompts that could be used to generate impressive AI videos."
    )

    # Sentence-count guidance keyed by the ``prompt_length`` argument.
    _LENGTH_GUIDE = {
        "Short": "Create a concise prompt of 2-3 sentences.",
        "Medium": "Create a detailed prompt of 4-6 sentences.",
        "Long": "Create an extensive prompt with 7-10 sentences covering all details.",
    }

    def __init__(self):
        """
        Initialize the VideoLLMInferenceNode without VLM captioning dependency

        API keys are read from ``SAMBANOVA_API_KEY`` and ``GROQ_API_KEY``;
        a client attribute is ``None`` when its key is missing or empty.
        """
        self.sambanova_api_key = os.environ.get("SAMBANOVA_API_KEY", "")
        self.groq_api_key = os.environ.get("GROQ_API_KEY", "")

        # Initialize API clients if keys are available
        self.groq_client = (
            Groq(api_key=self.groq_api_key) if self.groq_api_key else None
        )
        self.sambanova_client = (
            OpenAI(
                api_key=self.sambanova_api_key,
                base_url="https://api.sambanova.ai/v1",
            )
            if self.sambanova_api_key
            else None
        )

    def analyze_image(self, image_path: str, question: Optional[str] = None) -> str:
        """
        Analyze an image using VLM model directly

        Args:
            image_path: Path to the image file
            question: Optional question to ask about the image; a missing or
                blank question is replaced by a generic "describe" request

        Returns:
            str: Analysis result, or an error/help message on failure
        """
        if not image_path:
            return "Please upload an image."
        if not question or not question.strip():
            question = "Describe this image in detail."
        try:
            # Imported at call time, as in the original, which described this
            # as a GPU-scoped function.
            from app import get_vlm_captioner

            vlm = get_vlm_captioner()
            return vlm.describe_image(image_path, question)
        except Exception as e:  # UI boundary: report instead of raising
            return f"Error analyzing image: {str(e)}"

    def analyze_video(self, video_path: str) -> str:
        """
        Analyze a video using VLM model directly

        Args:
            video_path: Path to the video file

        Returns:
            str: Analysis result, or an error/help message on failure
        """
        if not video_path:
            return "Please upload a video."
        try:
            # Imported at call time, as in the original, which described this
            # as a GPU-scoped function.
            from app import get_vlm_captioner

            vlm = get_vlm_captioner()
            return vlm.describe_video(video_path)
        except Exception as e:  # UI boundary: report instead of raising
            return f"Error analyzing video: {str(e)}"

    def generate_video_prompt(
        self,
        concept: str,
        style: str = "Simple",
        camera_style: str = "None",
        camera_direction: str = "None",
        pacing: str = "None",
        special_effects: str = "None",
        custom_elements: str = "",
        provider: str = "SambaNova",
        model: str = "Meta-Llama-3.1-70B-Instruct",
        prompt_length: str = "Medium"
    ) -> str:
        """
        Generate a video prompt using the specified LLM provider

        Args:
            concept: Core concept for the video (must not be empty or blank)
            style: Video style
            camera_style: Camera style
            camera_direction: Camera direction
            pacing: Pacing rhythm
            special_effects: Special effects approach
            custom_elements: Custom technical elements
            provider: LLM provider (SambaNova or Groq)
            model: Model name
            prompt_length: Desired prompt length ("Short", "Medium" or
                "Long"); any other value falls back to "Medium"

        Returns:
            str: Generated video prompt, or an error/help message
        """
        # A whitespace-only concept carries no information either.
        if not concept or not concept.strip():
            return "Please enter a concept for the video."

        # Options whose value is empty or the literal "None" are omitted.
        labelled_options = [
            ("Style", style),
            ("Camera Movement Style", camera_style),
            ("Camera Direction", camera_direction),
            ("Pacing Rhythm", pacing),
            ("Special Effects", special_effects),
        ]
        options = [
            f"{label}: {value}"
            for label, value in labelled_options
            if value and value != "None"
        ]
        if custom_elements:
            options.append(f"Custom Elements: {custom_elements}")
        options_text = "\n".join(options)

        length_instruction = self._LENGTH_GUIDE.get(
            prompt_length, self._LENGTH_GUIDE["Medium"]
        )
        user_message = (
            "Create a video prompt based on the following concept and specifications:\n"
            f"CONCEPT: {concept}\n"
            "SPECIFICATIONS:\n"
            f"{options_text}\n"
            f"{length_instruction}\n"
            "The prompt should be detailed and technical, specifically mentioning camera angles, movements, lighting, transitions, and other visual elements that would create an impressive AI-generated video.\n"
        )
        system_message = self._SYSTEM_MESSAGE

        # Call the appropriate API based on provider; prefer the SDK client
        # and fall back to raw HTTP when no client was created.
        try:
            if provider == "SambaNova":
                if self.sambanova_client:
                    return self._call_sambanova_client(system_message, user_message, model)
                return self._call_sambanova_api(system_message, user_message, model)
            if provider == "Groq":
                if self.groq_client:
                    return self._call_groq_client(system_message, user_message, model)
                return self._call_groq_api(system_message, user_message, model)
            return "Unsupported provider. Please select SambaNova or Groq."
        except Exception as e:  # UI boundary: report instead of raising
            return f"Error generating prompt: {str(e)}"

    def _call_sambanova_client(self, system_message: str, user_message: str, model: str) -> str:
        """Call the SambaNova API using the client library"""
        return self._chat_via_client(
            self.sambanova_client, "SambaNova", system_message, user_message, model
        )

    def _call_sambanova_api(self, system_message: str, user_message: str, model: str) -> str:
        """Call the SambaNova API using direct HTTP requests"""
        # NOTE(review): this path ("/api/v1/...") differs from the "/v1" base
        # URL given to the SDK client in __init__ -- confirm which is current.
        return self._chat_via_http(
            "SambaNova",
            "https://api.sambanova.ai/api/v1/chat/completions",
            self.sambanova_api_key,
            "SAMBANOVA_API_KEY",
            system_message,
            user_message,
            model,
        )

    def _call_groq_client(self, system_message: str, user_message: str, model: str) -> str:
        """Call the Groq API using the client library"""
        return self._chat_via_client(
            self.groq_client, "Groq", system_message, user_message, model
        )

    def _call_groq_api(self, system_message: str, user_message: str, model: str) -> str:
        """Call the Groq API using direct HTTP requests"""
        return self._chat_via_http(
            "Groq",
            "https://api.groq.com/openai/v1/chat/completions",
            self.groq_api_key,
            "GROQ_API_KEY",
            system_message,
            user_message,
            model,
        )

    @staticmethod
    def _build_messages(system_message: str, user_message: str) -> List[Dict[str, str]]:
        """Return the two-message chat payload shared by every call path."""
        return [
            {"role": "system", "content": system_message},
            {"role": "user", "content": user_message},
        ]

    def _chat_via_client(
        self,
        client: Any,
        provider_label: str,
        system_message: str,
        user_message: str,
        model: str,
    ) -> str:
        """Run a chat completion through an SDK client.

        Any exception (including ``client`` being ``None``) is turned into
        an ``"Error from <provider> API: ..."`` string.
        """
        try:
            chat_completion = client.chat.completions.create(
                model=model,
                messages=self._build_messages(system_message, user_message),
            )
            choices = chat_completion.choices
            if not choices:
                return self._NO_CONTENT
            content = choices[0].message.content
            # The SDK may hand back None for content; keep the str contract.
            return content if content is not None else self._NO_CONTENT
        except Exception as e:
            return f"Error from {provider_label} API: {str(e)}"

    def _chat_via_http(
        self,
        provider_label: str,
        api_url: str,
        api_key: str,
        env_var: str,
        system_message: str,
        user_message: str,
        model: str,
    ) -> str:
        """POST a chat completion request directly and return its text.

        Returns a configuration hint when ``api_key`` is empty, and an
        ``"Error from <provider> API: ..."`` string on transport errors,
        non-200 responses, or a body that is not valid JSON.
        """
        if not api_key:
            return (
                f"{provider_label} API key not configured. "
                f"Please set the {env_var} environment variable."
            )
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }
        payload = {
            "model": model,
            "messages": self._build_messages(system_message, user_message),
        }
        try:
            response = requests.post(
                api_url,
                headers=headers,
                json=payload,
                timeout=self._HTTP_TIMEOUT_SECONDS,
            )
        except requests.RequestException as e:
            return f"Error from {provider_label} API: {str(e)}"
        if response.status_code != 200:
            return f"Error from {provider_label} API: {response.status_code} - {response.text}"
        try:
            result = response.json()
        except ValueError:
            return f"Error from {provider_label} API: response was not valid JSON"
        return self._extract_content(result)

    @staticmethod
    def _extract_content(result: Any) -> str:
        """Pull the first choice's message text out of a decoded response.

        Tolerates a missing or empty ``choices`` list, a missing ``message``
        and a null ``content``; all of those yield "No content returned".
        """
        fallback = VideoLLMInferenceNode._NO_CONTENT
        choices = result.get("choices") if isinstance(result, dict) else None
        if not choices:
            return fallback
        first = choices[0]
        message = first.get("message") if isinstance(first, dict) else None
        content = message.get("content") if isinstance(message, dict) else None
        return content if content is not None else fallback