File size: 5,499 Bytes
69acc93
7e2c859
b38913b
03ab4d3
69acc93
 
 
b38913b
69acc93
 
 
2d49f29
6ffc1bc
69acc93
 
 
 
7e2c859
69acc93
 
 
 
 
 
 
 
 
 
 
 
 
 
a3e812d
03ab4d3
 
 
 
 
 
 
 
 
 
 
 
69acc93
 
 
 
 
 
3ba09e8
 
69acc93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03ab4d3
69acc93
 
 
 
 
 
 
 
 
 
 
 
 
 
03ab4d3
 
69acc93
 
 
 
 
 
 
 
 
 
 
3ba09e8
69acc93
 
 
 
3ba09e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69acc93
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import gradio as gr
from huggingface_hub import hf_hub_download, snapshot_download
import subprocess
import tempfile, time
import shutil
import os
import spaces

from transformers import T5ForConditionalGeneration, T5Tokenizer
import os

# Emit a marker as soon as the module starts executing.
print("starting the app.")

def download_t5_model(model_id, save_directory):
    """Download a full Hugging Face repo snapshot into ``save_directory``.

    Fetches the model's tokenizer and weights as regular files in
    ``save_directory`` (``local_dir_use_symlinks=False``) instead of
    symlinks into the shared HF cache.

    Args:
        model_id: Hugging Face repo id to download, e.g. "DeepFloyd/t5-v1_1-xxl".
        save_directory: Local directory to place the snapshot in; created if missing.

    Returns:
        The local snapshot path returned by ``snapshot_download``.
    """
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(save_directory, exist_ok=True)
    # Download the repo the caller asked for rather than a hard-coded one.
    return snapshot_download(
        repo_id=model_id,
        local_dir=save_directory,
        local_dir_use_symlinks=False,
    )

# Hugging Face repo of the T5 text encoder and the local directory it is mirrored into.
model_id = "DeepFloyd/t5-v1_1-xxl"
save_directory = "pretrained_models/t5_ckpts/t5-v1_1-xxl"

# Fetch the encoder at import time, before the Gradio UI is built.
download_t5_model(model_id=model_id, save_directory=save_directory)

def download_model(repo_id, model_name):
    """Fetch one file from a Hugging Face repo and return its local path.

    Args:
        repo_id: Hugging Face repo id to download from.
        model_name: File name inside the repo (here, a checkpoint such as
            "OpenSora-v1-16x256x256.pth").

    Returns:
        The local file path returned by ``hf_hub_download``.
    """
    return hf_hub_download(repo_id=repo_id, filename=model_name)

import glob

@spaces.GPU(duration=500)
def run_model(temp_config_path, ckpt_path):
    """Run Open-Sora's ``scripts/inference.py`` under ``torchrun`` with one process.

    The wall-clock duration is printed whether the run succeeds or fails.

    Args:
        temp_config_path: Path to the inference config file passed to the script.
        ckpt_path: Path to the model checkpoint, passed as ``--ckpt-path``.

    Raises:
        subprocess.CalledProcessError: if the inference process exits with a
            non-zero status. Previously this was ignored, so callers could not
            tell a failed run from a successful one.
    """
    cmd = [
        "torchrun", "--standalone", "--nproc_per_node", "1",
        "scripts/inference.py", temp_config_path,
        "--ckpt-path", ckpt_path,
    ]
    # perf_counter is monotonic, so the measured duration is unaffected by clock adjustments.
    start_time = time.perf_counter()
    try:
        subprocess.run(cmd, check=True)
    finally:
        execution_time = time.perf_counter() - start_time
        print(f"Model Execution time: {execution_time} seconds")

def run_inference(model_name, prompt_text):
    """Generate a video for ``prompt_text`` using the selected Open-Sora checkpoint.

    Downloads the checkpoint from ``hpcai-tech/Open-Sora``. Writes the prompt to
    a temporary text file and a patched copy of the matching inference config
    pointing at it. Runs inference, then returns the newest file that the run
    wrote into ``./outputs/samples/``. Both temporary files are removed before
    returning, even if inference raises.

    Args:
        model_name: Checkpoint file name; must be a key of ``config_mapping``.
        prompt_text: Prompt written verbatim to the prompt file.

    Returns:
        Path of the most recently created output file produced by this run, or
        ``None`` if the run produced no output.

    Raises:
        KeyError: if ``model_name`` is not a known checkpoint.
    """
    repo_id = "hpcai-tech/Open-Sora"

    # Map model names to their respective configuration files
    config_mapping = {
        "OpenSora-v1-16x256x256.pth": "configs/opensora/inference/16x256x256.py",
        "OpenSora-v1-HQ-16x256x256.pth": "configs/opensora/inference/16x256x256.py",
        "OpenSora-v1-HQ-16x512x512.pth": "configs/opensora/inference/16x512x512.py"
    }

    config_path = config_mapping[model_name]
    ckpt_path = download_model(repo_id, model_name)

    # Line in the stock config that points at the bundled sample prompts.
    default_prompt_line = 'prompt_path = "./assets/texts/t2v_samples.txt"'
    # Output directory (e.g. the save directory used by inference.py).
    save_dir = "./outputs/samples/"

    prompt_path = None
    temp_config_path = None
    try:
        # Explicit UTF-8: prompts may contain non-ASCII characters.
        with tempfile.NamedTemporaryFile(
            "w", delete=False, suffix=".txt", encoding="utf-8"
        ) as prompt_file:
            prompt_path = prompt_file.name
            prompt_file.write(prompt_text)

        with open(config_path, "r", encoding="utf-8") as config_file:
            config_content = config_file.read()
        if default_prompt_line not in config_content:
            print(f"Warning: {default_prompt_line!r} not found in {config_path}; "
                  "the custom prompt will not be used.")
        # repr() yields a valid Python string literal for any path.
        config_content = config_content.replace(
            default_prompt_line, f"prompt_path = {prompt_path!r}"
        )

        with tempfile.NamedTemporaryFile(
            "w", delete=False, suffix=".py", encoding="utf-8"
        ) as temp_file:
            temp_config_path = temp_file.name
            temp_file.write(config_content)

        started_at = time.time()
        run_model(temp_config_path, ckpt_path)

        # Ignore files left by earlier runs so a run that produced nothing
        # does not silently return a stale video.
        new_files = [
            path for path in glob.glob(os.path.join(save_dir, "*"))
            if os.path.getmtime(path) >= started_at
        ]
        if not new_files:
            print("No new files found in the output directory.")
            return None
        return max(new_files, key=os.path.getctime)
    finally:
        # Clean up the temporary files (previously unreachable after the return).
        for path in (prompt_path, temp_config_path):
            if path is None:
                continue
            try:
                os.remove(path)
            except FileNotFoundError:
                pass



def main():
    """Build the Open-Sora Gradio interface and start serving it.

    The UI collects a checkpoint name and a prompt and passes them to
    ``run_inference``. It displays the returned file as a video.
    """
    model_choices = [
        "OpenSora-v1-16x256x256.pth",
        "OpenSora-v1-HQ-16x256x256.pth",
        "OpenSora-v1-HQ-16x512x512.pth",
    ]
    model_dropdown = gr.Dropdown(
        choices=model_choices,
        value="OpenSora-v1-16x256x256.pth",
        label="Model Selection",
    )
    prompt_box = gr.Textbox(
        label="Prompt Text",
        value="iron man riding a skateboard in new york city",
    )
    video_output = gr.Video(label="Output Video")

    # Example prompts are shown as a Markdown table with links to sample
    # outputs, rather than through Interface's ``examples=`` argument.
    article_markdown = """
# Examples

| Model                        | Description                                                                                                          | Video Player Embedding                                  |
|------------------------------|----------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------|
| OpenSora-v1-HQ-16x256x256.pth | Iron Man riding a skateboard in New York City                                                                        | ![ironman](https://github.com/sandeshrajbhandari/open-sora-examples/assets/12326258/8173e37f-6405-44f3-aaaa-fafc88187933) |
| OpenSora-v1-16x256x256.pth   | A man is skiing down a snowy mountain. A drone shot from above. An avalanche is chasing him from behind.            | ![skiing](https://github.com/sandeshrajbhandari/open-sora-examples/assets/12326258/d2cab73a-a77e-4e0b-a80e-668e252b6b6a) |
| OpenSora-v1-16x256x256.pth   | Extreme close-up of a 24-year-old woman’s eye blinking, standing in Marrakech during magic hour, cinematic film shot in 70mm, depth of field, vivid colors, cinematic | ![woman](https://github.com/sandeshrajbhandari/open-sora-examples/assets/12326258/38322939-f7bf-4f72-8a5e-ccc427970afc) |

        """

    demo = gr.Interface(
        fn=run_inference,
        inputs=[model_dropdown, prompt_box],
        outputs=video_output,
        title="Open-Sora Inference",
        description="Run Open-Sora Inference with Custom Parameters",
        article=article_markdown,
    )
    demo.launch()

if __name__ == "__main__":
    # Launch the Gradio app only when this file is executed directly, not on import.
    main()