import gradio as gr
from huggingface_hub import snapshot_download
from pathlib import Path
import spaces
import subprocess
import os

# Install required packages at startup. The *_SKIP_CUDA_BUILD flags skip compiling
# CUDA kernels at install time (no GPU is attached during build on ZeroGPU Spaces),
# and the full environment is preserved so pip and PATH stay visible to the shell.
subprocess.run('pip install causal-conv1d --no-build-isolation', env={**os.environ, 'CAUSAL_CONV1D_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
subprocess.run('pip install mamba-ssm --no-build-isolation', env={**os.environ, 'MAMBA_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
subprocess.run('pip install mistral_inference --no-build-isolation', env={**os.environ, 'MISTRAL_INFERENCE_SKIP_CUDA_BUILD': "TRUE"}, shell=True)



# Import after installation
from mistral_inference.mamba import Mamba
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage, AssistantMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest

# Download the model
mistral_models_path = Path.home().joinpath('mistral_models', 'mamba-codestral-7B-v0.1')
mistral_models_path.mkdir(parents=True, exist_ok=True)

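# Only fetch the files needed for inference: the weights, params.json, and the v3 tokenizer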
snapshot_download(repo_id="mistralai/mamba-codestral-7B-v0.1", 
                  allow_patterns=["params.json", "consolidated.safetensors", "tokenizer.model.v3"], 
                  local_dir=mistral_models_path)

MODEL_PATH = str(mistral_models_path)

# Load model and tokenizer
tokenizer = MistralTokenizer.from_file(os.path.join(MODEL_PATH, "tokenizer.model.v3"))
model = Mamba.from_folder(MODEL_PATH)


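# On ZeroGPU Spaces, @spaces.GPU() attaches a GPU to the worker only while this function runs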
@spaces.GPU()
def generate_response(message, history):
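    # history is assumed to arrive as (user, assistant) pairs (Gradio's tuple-style chat format)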
    # Convert history to the format expected by the model
    messages = []
    for human, assistant in history:
        messages.append(UserMessage(content=human))
        messages.append(AssistantMessage(content=assistant))
    messages.append(UserMessage(content=message))

    # Create chat completion request
    completion_request = ChatCompletionRequest(messages=messages)
    
    # Tokenize input
    tokens = tokenizer.encode_chat_completion(completion_request).tokens
    
    # Generate response; generate() returns (generated_tokens, logprobs)
    out_tokens, _ = generate([tokens], model, max_tokens=256, temperature=0.7, eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)

    # Decode the generated tokens back to text
    result = tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])
    
    return result

    
# Gradio interface
iface = gr.ChatInterface(
    generate_response,
    title="Mamba Codestral Chat (ZeroGPU)",
    description="Chat with the Mamba Codestral 7B model using Hugging Face Spaces ZeroGPU feature.",
)

if __name__ == "__main__":
    iface.launch()