\n",
"Labels shape: torch.Size([2, 512])\n"
]
},
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
"
\n",
" [39/39 02:15, Epoch 3/3]\n",
"
\n",
" \n",
" \n",
" \n",
" Step | \n",
" Training Loss | \n",
"
\n",
" \n",
" \n",
" \n",
" 5 | \n",
" 12.130200 | \n",
"
\n",
" \n",
" 10 | \n",
" 3.432800 | \n",
"
\n",
" \n",
" 15 | \n",
" 0.502100 | \n",
"
\n",
" \n",
" 20 | \n",
" 0.297100 | \n",
"
\n",
" \n",
" 25 | \n",
" 0.232200 | \n",
"
\n",
" \n",
" 30 | \n",
" 0.199000 | \n",
"
\n",
" \n",
" 35 | \n",
" 0.174700 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"TrainOutput(global_step=39, training_loss=2.1929761950786295, metrics={'train_runtime': 140.2841, 'train_samples_per_second': 1.112, 'train_steps_per_second': 0.278, 'total_flos': 3409289020440576.0, 'train_loss': 2.1929761950786295, 'epoch': 3.0})"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from transformers import TrainingArguments, Trainer\n",
"\n",
"# 1. Configurar formato del dataset como tensores\n",
"tokenized_dataset.set_format(\"torch\", columns=[\"input_ids\", \"attention_mask\"])\n",
"\n",
"# 2. Data collator mejorado\n",
"def custom_collator(features):\n",
" return {\n",
" \"input_ids\": torch.stack([torch.tensor(f[\"input_ids\"]) for f in features]),\n",
" \"attention_mask\": torch.stack([torch.tensor(f[\"attention_mask\"]) for f in features]),\n",
" \"labels\": torch.stack([torch.tensor(f[\"input_ids\"]) for f in features])\n",
" }\n",
"\n",
"# 3. Configurar argumentos con parámetros faltantes\n",
"training_args = TrainingArguments(\n",
" output_dir=\"./html5-lora\",\n",
" per_device_train_batch_size=2,\n",
" gradient_accumulation_steps=2, # Reducir para ahorrar memoria\n",
" num_train_epochs=3,\n",
" learning_rate=3e-4,\n",
" fp16=torch.cuda.is_available(),\n",
" logging_steps=5,\n",
" report_to=\"none\",\n",
" remove_unused_columns=False, # Necesario para LoRA\n",
" label_names=[\"labels\"] # Añadir parámetro faltante\n",
")\n",
"\n",
"# 4. Crear Trainer con parámetros actualizados\n",
"trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" train_dataset=tokenized_dataset[\"train\"],\n",
" eval_dataset=tokenized_dataset[\"test\"],\n",
" data_collator=custom_collator\n",
")\n",
"\n",
"# 5. Verificación adicional\n",
"sample_batch = next(iter(trainer.get_train_dataloader()))\n",
"print(\"\\nVerificación de batch:\")\n",
"print(f\"Input ids type: {type(sample_batch['input_ids'][0])}\")\n",
"print(f\"Labels shape: {sample_batch['labels'].shape}\")\n",
"\n",
"# 6. Iniciar entrenamiento\n",
"trainer.train()"
]
},
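{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "After training, the LoRA adapter weights can be saved separately from the base model so they can be reloaded later with PEFT. The sketch below is an assumption, not part of the original notebook; the `./html5-lora-adapter` path is illustrative."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Minimal sketch (assumed step): persist only the LoRA adapter and the tokenizer.\n",
  "# `model` is the PEFT-wrapped model trained above; save_pretrained on a PeftModel\n",
  "# writes just the adapter weights (a few MB), not the full base model.\n",
  "model.save_pretrained(\"./html5-lora-adapter\")\n",
  "tokenizer.save_pretrained(\"./html5-lora-adapter\")"
 ]
},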
{
"cell_type": "markdown",
"metadata": {
"id": "hm89m0JCtYnY"
},
"source": [
"### Generación de Respuestas"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "rukjNbmftfCv",
"outputId": "e7c3781f-1a33-4a43-9c8c-4eb1e68589e2"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Device set to use cuda:0\n",
"The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['AriaTextForCausalLM', 'BambaForCausalLM', 'BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'Cohere2ForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'DiffLlamaForCausalLM', 'ElectraForCausalLM', 'Emu3ForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FalconMambaForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'Gemma2ForCausalLM', 'GitForCausalLM', 'GlmForCausalLM', 'GotOcr2ForConditionalGeneration', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'GraniteForCausalLM', 'GraniteMoeForCausalLM', 'GraniteMoeSharedForCausalLM', 'HeliumForCausalLM', 'JambaForCausalLM', 'JetMoeForCausalLM', 'LlamaForCausalLM', 'MambaForCausalLM', 'Mamba2ForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'MllamaForCausalLM', 'MoshiForCausalLM', 'MptForCausalLM', 'MusicgenForCausalLM', 'MusicgenMelodyForCausalLM', 'MvpForCausalLM', 'NemotronForCausalLM', 'OlmoForCausalLM', 'Olmo2ForCausalLM', 'OlmoeForCausalLM', 'OpenLlamaForCausalLM', 'OpenAIGPTLMHeadModel', 'OPTForCausalLM', 'PegasusForCausalLM', 'PersimmonForCausalLM', 'PhiForCausalLM', 'Phi3ForCausalLM', 'PhimoeForCausalLM', 'PLBartForCausalLM', 'ProphetNetForCausalLM', 'QDQBertLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'RecurrentGemmaForCausalLM', 'ReformerModelWithLMHead', 'RemBertForCausalLM', 'RobertaForCausalLM', 'RobertaPreLayerNormForCausalLM', 'RoCBertForCausalLM', 'RoFormerForCausalLM', 'RwkvForCausalLM', 'Speech2Text2ForCausalLM', 'StableLmForCausalLM', 'Starcoder2ForCausalLM', 'TransfoXLLMHeadModel', 'TrOCRForCausalLM', 'WhisperForCausalLM', 'XGLMForCausalLM', 'XLMWithLMHeadModel', 'XLMProphetNetForCausalLM', 'XLMRobertaForCausalLM', 'XLMRobertaXLForCausalLM', 'XLNetLMHeadModel', 'XmodForCausalLM', 'ZambaForCausalLM', 'Zamba2ForCausalLM'].\n"
]
}
],
"source": [
"from transformers import pipeline\n",
"chatbot = pipeline(\n",
" \"text-generation\",\n",
" model = model,\n",
" tokenizer = tokenizer,\n",
" torch_dtype = torch.float16\n",
")\n",
"\n",
"def generate_response(query):\n",
" prompt = f\"[INST] Pregunta HTML5: {query} [/INST]\"\n",
" response = chatbot(\n",
" prompt,\n",
" max_new_tokens = 200,\n",
" temperature = 0.3,\n",
" do_sample = True,\n",
" pad_token_id = tokenizer.eos_token_id\n",
" )\n",
" return response[0]['generated_text'].split(\"[/INST]\")[-1].strip()\n",
"\n",
"\n",
"\n",
"def generate_response_gradio(query):\n",
" try:\n",
" # Manejar casos no técnicos primero\n",
" if query.lower().strip() in [\"hola\", \"hi\", \"ayuda\"]:\n",
" return \"¡Hola! Soy un asistente de HTML5. Ejemplo: '¿Cómo usar