import json

import torch
import gradio as gr
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    StoppingCriteria,
    StoppingCriteriaList,
    GenerationConfig,
)
model_id = "somosnlp/Sam_Diagnostic"

# 4-bit NF4 quantization config. Currently unused: quantization_config is
# commented out in from_pretrained below.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
max_seq_length = 2048

# Flash Attention 2 needs a GPU with compute capability >= 8 (e.g. A100/H100):
# if torch.cuda.get_device_capability()[0] >= 8:
#     attn_implementation = "flash_attention_2"
# else:
#     attn_implementation = None
attn_implementation = None

tokenizer = AutoTokenizer.from_pretrained(model_id, max_length=max_seq_length)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    # quantization_config=bnb_config,  # uncomment to load the model in 4-bit
    device_map={"": 0},
    attn_implementation=attn_implementation,
).eval()
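
# Optional smoke test (a hypothetical snippet for local debugging, not part of
# the Space): checks that the model loads and emits a few tokens.
# with torch.no_grad():
#     ids = tokenizer.encode("<bos><start_of_turn>user\nhola<end_of_turn>\n"
#                            "<start_of_turn>model\n",
#                            return_tensors="pt", add_special_tokens=False).to("cuda:0")
#     print(tokenizer.decode(model.generate(ids, max_new_tokens=8)[0]))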
class ListOfTokensStoppingCriteria(StoppingCriteria):
    """Stopping criterion that triggers when the most recently generated
    tokens match any of a list of stop sequences."""

    def __init__(self, tokenizer, stop_tokens):
        self.tokenizer = tokenizer
        # Encode each stop sequence and keep its token IDs.
        self.stop_token_ids_list = [
            tokenizer.encode(stop_token, add_special_tokens=False)
            for stop_token in stop_tokens
        ]

    def __call__(self, input_ids, scores, **kwargs):
        # Check whether the tail of the sequence matches any stop sequence.
        for stop_token_ids in self.stop_token_ids_list:
            len_stop_tokens = len(stop_token_ids)
            if len(input_ids[0]) >= len_stop_tokens:
                if input_ids[0, -len_stop_tokens:].tolist() == stop_token_ids:
                    return True
        return False
# Stop generation at Gemma's end-of-turn marker.
stop_tokens = ["<end_of_turn>"]
stopping_criteria = ListOfTokensStoppingCriteria(tokenizer, stop_tokens)
stopping_criteria_list = StoppingCriteriaList([stopping_criteria])
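
# Minimal sketch of what the criterion checks (illustrative only; the exact
# token IDs for "<end_of_turn>" depend on the tokenizer's vocabulary):
# ids = tokenizer.encode("<end_of_turn>", add_special_tokens=False)
# stopping_criteria(torch.tensor([ids]), scores=None)  # -> True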
def generate_text(prompt, idioma_entrada, idioma_salida, max_length=2100):
    prompt = prompt.replace(". ", ".\n").strip()
    # System prompt in Gemma chat format; the Spanish instructions ("answer in
    # JSON", "you are a medical expert agent") are left verbatim for the model.
    input_text = f'''<bos><start_of_turn>system
You are a helpful AI assistant.
Responde en formato json.
Eres un agente experto en medicina.
Lista de codigos linguisticos disponibles: ["{idioma_entrada}", "{idioma_salida}"]<end_of_turn>
<start_of_turn>user
{prompt}<end_of_turn>
<start_of_turn>model
'''
    inputs = tokenizer.encode(
        input_text, return_tensors="pt", add_special_tokens=False
    ).to("cuda:0")

    generation_config = GenerationConfig(
        max_new_tokens=max_length,
        temperature=0.35,
        top_k=50,
        repetition_penalty=1.0,
        do_sample=True,
    )
    outputs = model.generate(
        input_ids=inputs,
        generation_config=generation_config,
        stopping_criteria=stopping_criteria_list,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=False)
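
# Example call (hypothetical input). The raw string still contains the prompt
# and the chat special tokens; mostrar_respuesta() extracts the JSON below.
# raw = generate_text("CHIEF COMPLAINT:, Left wrist pain.", "en", "es")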
def mostrar_respuesta(pregunta, idioma_entrada, idioma_salida):
    try:
        # Map the dropdown labels to language codes.
        lista_codigo_lin = {
            "español": "es",
            "ingles": "en",
        }
        codigo_lin_entrada = lista_codigo_lin[idioma_entrada.lower()]
        codigo_lin_salida = lista_codigo_lin[idioma_salida.lower()]
        res = generate_text(pregunta, codigo_lin_entrada, codigo_lin_salida,
                            max_length=1500)
        # Extract the first {...} span from the raw output and parse it.
        inicio_json = res.find('{')
        fin_json = res.rfind('}') + 1
        json_obj = json.loads(res[inicio_json:fin_json])
        return (json_obj["description"],
                json_obj["medical_specialty"],
                json_obj["principal_diagnostic"])
    except Exception:
        # Generation or JSON parsing failed; show a fixed error message.
        error = 'Error diagnostico'
        return error, error, error
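
# Expected shape of the model's JSON answer (keys used above):
# {"description": "...", "medical_specialty": "...", "principal_diagnostic": "..."}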
# Example prompts shown in the UI.
ejemplos = [
    ["CHIEF COMPLAINT:, Left wrist pain.,HISTORY OF PRESENT PROBLEM"],
    ["INDICATIONS: ,Chest pain.,STRESS TECHNIQUE:,"],
    ["MOTIVO DE CONSULTA: Una niña de 2 meses"],
]
idiomas = ["español", "ingles"]
iface = gr.Interface(
    fn=mostrar_respuesta,
    inputs=[
        gr.Textbox(label="Pregunta", placeholder="Introduce tu consulta médica aquí..."),
        # Gradio 3+/4 sets the initial selection with `value`; `default=` is
        # not a valid argument and raises a TypeError at startup.
        gr.Dropdown(label="Idioma de Entrada", choices=idiomas, value="español"),
        gr.Dropdown(label="Idioma de Salida", choices=idiomas, value="español"),
    ],
    outputs=[
        gr.Textbox(label="Description", lines=2),
        gr.Textbox(label="Medical specialty", lines=1),
        gr.Textbox(label="Principal diagnostic", lines=1),
    ],
    title="Consultas medicas",
    description="Introduce tu diagnostico.",
    examples=ejemplos,
    concurrency_limit=20,
)
iface.queue(max_size=14).launch(share=True, debug=True)