|
60 | 60 | }, |
61 | 61 | { |
62 | 62 | "cell_type": "code", |
| 63 | + "execution_count": null, |
| 64 | + "metadata": { |
| 65 | + "id": "R4FrgYBWGdsJ" |
| 66 | + }, |
| 67 | + "outputs": [], |
63 | 68 | "source": [ |
64 | 69 | "#@title Install dependencies\n", |
65 | 70 | "! pip install --no-deps -U flax\n", |
66 | 71 | "! pip install jaxtyping kagglehub treescope" |
67 | | - ], |
68 | | - "metadata": { |
69 | | - "id": "R4FrgYBWGdsJ" |
70 | | - }, |
71 | | - "execution_count": null, |
72 | | - "outputs": [] |
| 72 | + ] |
73 | 73 | }, |
74 | 74 | { |
75 | 75 | "cell_type": "markdown", |
| 76 | + "metadata": { |
| 77 | + "id": "A0nmk8NwG4fK" |
| 78 | + }, |
76 | 79 | "source": [ |
77 | 80 | "To interact with the Gemma model, you will use the Flax NNX gemma code from\n", |
78 | 81 | "google/flax examples on GitHub. Since it is not exposed as a package, you need\n", |
79 | 82 | "to use the following workaround to import from the Flax NNX examples/gemma on\n", |
80 | 83 | "GitHub." |
81 | | - ], |
82 | | - "metadata": { |
83 | | - "id": "A0nmk8NwG4fK" |
84 | | - } |
| 84 | + ] |
85 | 85 | }, |
86 | 86 | { |
87 | 87 | "cell_type": "code", |
|
121 | 121 | }, |
122 | 122 | { |
123 | 123 | "cell_type": "markdown", |
| 124 | + "metadata": { |
| 125 | + "id": "8xUHvwdHH0O1" |
| 126 | + }, |
124 | 127 | "source": [ |
125 | 128 | "To use Gemma model, you’ll need a Kaggle account and API key:\n", |
126 | 129 | "\n", |
|
135 | 138 | "\n", |
136 | 139 | "4. Request access to the model here:\n", |
137 | 140 | " https://www.kaggle.com/models/google/gemma-3" |
138 | | - ], |
139 | | - "metadata": { |
140 | | - "id": "8xUHvwdHH0O1" |
141 | | - } |
| 141 | + ] |
142 | 142 | }, |
143 | 143 | { |
144 | 144 | "cell_type": "code", |
|
275 | 275 | }, |
276 | 276 | { |
277 | 277 | "cell_type": "code", |
278 | | - "source": [ |
279 | | - "prompt_length, out_length, out_data = run_model(\"What is the capital of Switzerland? Answer:\")" |
280 | | - ], |
| 278 | + "execution_count": null, |
281 | 279 | "metadata": { |
282 | 280 | "id": "i8UVzaiia1pR" |
283 | 281 | }, |
284 | | - "execution_count": null, |
285 | | - "outputs": [] |
| 282 | + "outputs": [], |
| 283 | + "source": [ |
| 284 | + "prompt_length, out_length, out_data = run_model(\"What is the capital of Switzerland? Answer:\")" |
| 285 | + ] |
286 | 286 | }, |
287 | 287 | { |
288 | 288 | "cell_type": "markdown", |
|
350 | 350 | "cell_type": "code", |
351 | 351 | "execution_count": null, |
352 | 352 | "metadata": { |
353 | | - "id": "N7CMfuIwnR1B", |
354 | | - "cellView": "form" |
| 353 | + "cellView": "form", |
| 354 | + "id": "N7CMfuIwnR1B" |
355 | 355 | }, |
356 | 356 | "outputs": [], |
357 | 357 | "source": [ |
|
435 | 435 | }, |
436 | 436 | { |
437 | 437 | "cell_type": "markdown", |
438 | | - "source": [ |
439 | | - "The next cell will visualize the top activated neurons in a given layer. You can hover over the colored blocks to get the neuron id and value. For the last feedfordward layer _25_ you should see that the top activated neurons are *1937* and *4422*." |
440 | | - ], |
441 | 438 | "metadata": { |
442 | 439 | "id": "c9hYsqK-CTVW" |
443 | | - } |
| 440 | + }, |
| 441 | + "source": [ |
| 442 | + "The next cell will visualize the top activated neurons in a given layer. You can hover over the colored blocks to get the neuron id and value. For the last feedfordward layer _25_ you should see that the top activated neurons are *1937* and *4422*." |
| 443 | + ] |
444 | 444 | }, |
445 | 445 | { |
446 | 446 | "cell_type": "code", |
|
487 | 487 | }, |
488 | 488 | { |
489 | 489 | "cell_type": "markdown", |
| 490 | + "metadata": { |
| 491 | + "id": "30A0KOHTCiVB" |
| 492 | + }, |
490 | 493 | "source": [ |
491 | 494 | "After identifiying a neuron of interest we can deactive or boost it. You can play around with the bias values to achieve different behaviours.\n", |
492 | 495 | "1. **Deactivate** neuron 1937 in layer 25 and issue the same prompt again.\n", |
493 | | - "2. **Boost** neuron 1937 in layer 25 and your model should repsonse with *Switzerland* significantly more often.\n", |
494 | | - "\n" |
495 | | - ], |
496 | | - "metadata": { |
497 | | - "id": "30A0KOHTCiVB" |
498 | | - } |
| 496 | + "2. **Boost** neuron 1937 in layer 25 and your model should repsonse with *Switzerland* significantly more often.\n" |
| 497 | + ] |
499 | 498 | }, |
500 | 499 | { |
501 | 500 | "cell_type": "code", |
|
531 | 530 | "sampler = sampler_lib.Sampler(\n", |
532 | 531 | " transformer=transformer,\n", |
533 | 532 | " vocab=vocab,\n", |
534 | | - ")\n", |
535 | | - "\n" |
| 533 | + ")\n" |
536 | 534 | ] |
537 | 535 | }, |
538 | 536 | { |
539 | 537 | "cell_type": "markdown", |
540 | | - "source": [ |
541 | | - "We can also apply the `premature_decode` functions to the value of a neuron to investigate the effect of a neuron. Check neuron *1937* of layer *25* to verify why the model behaviour changed by deactivating or boosting the neuron." |
542 | | - ], |
543 | 538 | "metadata": { |
544 | 539 | "id": "CNWJ-8YzDCxw" |
545 | | - } |
| 540 | + }, |
| 541 | + "source": [ |
| 542 | + "We can also apply the `premature_decode` functions to the value of a neuron to investigate the effect of a neuron. Check neuron *1937* of layer *25* to verify why the model behaviour changed by deactivating or boosting the neuron." |
| 543 | + ] |
546 | 544 | }, |
547 | 545 | { |
548 | 546 | "cell_type": "code", |
|
585 | 583 | "cell_type": "code", |
586 | 584 | "execution_count": null, |
587 | 585 | "metadata": { |
588 | | - "id": "2Cjc8YMM-fov", |
589 | | - "cellView": "form" |
| 586 | + "cellView": "form", |
| 587 | + "id": "2Cjc8YMM-fov" |
590 | 588 | }, |
591 | 589 | "outputs": [], |
592 | 590 | "source": [ |
|
699 | 697 | } |
700 | 698 | ], |
701 | 699 | "metadata": { |
| 700 | + "accelerator": "GPU", |
702 | 701 | "colab": { |
703 | | - "private_outputs": true, |
704 | | - "provenance": [], |
705 | | - "gpuType": "A100", |
706 | | - "machine_shape": "hm" |
| 702 | + "name": "[Gemma_3]Activation_Hacking.ipynb", |
| 703 | + "toc_visible": true |
707 | 704 | }, |
708 | 705 | "kernelspec": { |
709 | 706 | "display_name": "Python 3", |
710 | 707 | "name": "python3" |
711 | | - }, |
712 | | - "language_info": { |
713 | | - "name": "python" |
714 | | - }, |
715 | | - "accelerator": "GPU" |
| 708 | + } |
716 | 709 | }, |
717 | 710 | "nbformat": 4, |
718 | 711 | "nbformat_minor": 0 |
|
0 commit comments