import gradio as gr
from huggingface_hub import snapshot_download
from pathlib import Path
import spaces
import subprocess
import os

# Install required packages
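# The *_SKIP_CUDA_BUILD flags skip compiling the CUDA extensions at install
# time, likely because no GPU is attached while a ZeroGPU Space starts up.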
# Merge with os.environ so PATH and other variables survive the env override.
subprocess.run('pip install mistral_inference mamba-ssm --no-build-isolation', env={**os.environ, 'MAMBA_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
subprocess.run('pip install causal-conv1d --no-build-isolation', env={**os.environ, 'CAUSAL_CONV1D_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

# Import after installation
from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage, AssistantMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest

# Download the model
mistral_models_path = Path.home().joinpath('mistral_models', 'mamba-codestral-7B-v0.1')
mistral_models_path.mkdir(parents=True, exist_ok=True)

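# allow_patterns restricts the download to just the weights, config, and tokenizer.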
snapshot_download(repo_id="mistralai/mamba-codestral-7B-v0.1", 
                  allow_patterns=["params.json", "consolidated.safetensors", "tokenizer.model.v3"], 
                  local_dir=mistral_models_path)

MODEL_PATH = str(mistral_models_path)

# Load model and tokenizer
tokenizer = MistralTokenizer.from_file(os.path.join(MODEL_PATH, "tokenizer.model.v3"))
model = Transformer.from_folder(MODEL_PATH)

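# On ZeroGPU Spaces, spaces.GPU() attaches a GPU for the duration of each call
# to this function and releases it when the call returns.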
@spaces.GPU()
def generate_response(message, history):
    # Convert history to the format expected by the model
    messages = []
    for human, assistant in history:
        messages.append(UserMessage(content=human))
        messages.append(AssistantMessage(content=assistant))
    messages.append(UserMessage(content=message))

    # Create chat completion request
    completion_request = ChatCompletionRequest(messages=messages)
    
    # Tokenize input
    tokens = tokenizer.encode_chat_completion(completion_request).tokens
    
    # Generate response
    out_tokens, _ = generate([tokens], model, max_tokens=256, temperature=0.7, eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
    
    # Decode response
    result = tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])
    
    return result

# Gradio interface
iface = gr.ChatInterface(
    generate_response,
    title="Mamba Codestral Chat (ZeroGPU)",
    description="Chat with the Mamba Codestral 7B model using the Hugging Face Spaces ZeroGPU feature.",
)

if __name__ == "__main__":
    iface.launch()