import os
import gradio as gr
import torch
import torch.distributed as dist
import transformers
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import warnings

# disable some warnings
transformers.logging.set_verbosity_error()
transformers.logging.disable_progress_bar()
warnings.filterwarnings('ignore')

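# Initialise the NCCL process group; all worker processes rendezvous at localhost:12355.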
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

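# Build an explicit device_map so each GPU holds a contiguous block of transformer layers.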
def load_model_on_gpus(model_name, num_gpus):
    # Read the layer count from the config instead of materialising the full model in CPU memory first
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    num_layers = config.num_hidden_layers
    layers_per_gpu = (num_layers + num_gpus - 1) // num_gpus  # ceiling division so layer indices never exceed num_gpus - 1

    # Assign transformer layers to GPUs in contiguous blocks
    device_map = {}
    for i in range(num_layers):
        device_map[f'model.layers.{i}'] = min(i // layers_per_gpu, num_gpus - 1)
    
    # Keep the embeddings on the first GPU and the final norm / LM head on the last
    device_map['model.embed_tokens'] = 0
    device_map['model.norm'] = num_gpus - 1
    device_map['lm_head'] = num_gpus - 1

    return AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map=device_map,
        torch_dtype=torch.float16,
        trust_remote_code=True
    )

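# Per-process entry point: each rank loads the model; rank 0 additionally owns the tokenizer and the Gradio UI.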
def run_distributed(rank, world_size, model_name):
    setup(rank, world_size)
    
    if rank == 0:
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    
    model = load_model_on_gpus(model_name, world_size)
    
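    # Gradio callback: build the multimodal prompt on rank 0, broadcast the inputs, then generate.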
    def inference(prompt, image, temperature, beam_size):
        if rank == 0:
            messages = [{"role": "user", "content": f'<image>\n{prompt}'}]
            text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
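            # -200 is the image placeholder token id; the model replaces it with the encoded image features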
            input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0).to(rank)
            image_tensor = model.process_images([image], model.config).to(rank)
        else:
            input_ids = torch.zeros(1, 1, dtype=torch.long).to(rank)
            image_tensor = torch.zeros(1, 3, 224, 224).to(rank)

        dist.broadcast(input_ids, src=0)
        dist.broadcast(image_tensor, src=0)

        with torch.autocast('cuda', dtype=torch.float16):
            output_ids = model.generate(
                input_ids,
                images=image_tensor,
                max_new_tokens=1024,
                temperature=temperature,
                num_beams=beam_size,
                use_cache=True
            )[0]

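        # Only rank 0 decodes the completion (everything after the prompt tokens); other ranks return an empty string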
        if rank == 0:
            return tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()
        else:
            return ""

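    # Two-column UI: prompt, image, and sampling controls on the left; generated text on the right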
    if rank == 0:
        with gr.Blocks() as demo:
            with gr.Row():
                with gr.Column():
                    prompt_input = gr.Textbox(label="Prompt", placeholder="Describe this image in detail")
                    image_input = gr.Image(label="Image", type="pil")
                    temperature_input = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
                    beam_size_input = gr.Slider(minimum=1, maximum=10, value=4, step=1, label="Beam Size")
                    submit_button = gr.Button("Submit")
                with gr.Column():
                    output_text = gr.Textbox(label="Output")

            submit_button.click(
                fn=inference, 
                inputs=[prompt_input, image_input, temperature_input, beam_size_input], 
                outputs=output_text
            )

        demo.launch(share=True)

    cleanup()

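# Spawn one worker process per visible GPU.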
if __name__ == "__main__":
    model_name = 'cognitivecomputations/dolphin-vision-72b'
    world_size = torch.cuda.device_count()
    print(f"Running on {world_size} GPUs")
    torch.multiprocessing.spawn(run_distributed, args=(world_size, model_name), nprocs=world_size, join=True)
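
# Example launch (script name is illustrative):
#   python dolphin_vision_gradio.py
# torch.cuda.device_count() determines how many worker processes are spawned.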